ipoeyke committed
Commit a01368f · 1 Parent(s): 28f6140

reflect intended new api

Files changed (3)
  1. README.md +120 -0
  2. ragulator-deberta-v3-large.model +0 -3
  3. requirements.txt +5 -4
README.md CHANGED
@@ -1,3 +1,123 @@
  ---
  license: mit
  ---
+
+ # RAGulator-deberta-v3-large
+
+ This is the out-of-context detection model from our work:
+
+ [**RAGulator: Lightweight Out-of-Context Detectors for Grounded Text Generation**](https://arxiv.org/abs/2411.03920)
+
+ This repository contains the model files for the deberta-v3-large variant of RAGulator. Code can be found [here]().
+
+ ## Key Points
+ * RAGulator predicts whether a sentence is out-of-context (OOC) with respect to the text documents retrieved in a RAG setting.
+ * We preprocess a combination of summarisation and semantic textual similarity (STS) datasets to construct training data using minimal resources.
+ * We demonstrate 2 types of trained models: tree-based meta-models trained on features engineered from preprocessed text, and BERT-based classifiers fine-tuned directly on original text.
+ * We find that fine-tuned DeBERTa is not only the best-performing model under this pipeline, but is also fast and requires no additional text preprocessing or feature engineering.
+
+ ## Model Details
+
+ ### Dataset
+ Training data for RAGulator is adapted from a combination of summarisation and STS datasets to simulate RAG:
+ * [BBC](https://www.kaggle.com/datasets/pariza/bbc-news-summary)
+ * [CNN DailyMail ver. 3.0.0](https://huggingface.co/datasets/abisee/cnn_dailymail)
+ * [PubMed](https://huggingface.co/datasets/ccdv/pubmed-summarization)
+ * [MRPC from the GLUE dataset](https://huggingface.co/datasets/nyu-mll/glue/)
+ * [SNLI ver. 1.0](https://huggingface.co/datasets/stanfordnlp/snli)
+
+ The datasets were transformed before concatenation into the final dataset. Each row of the final dataset consists of \[`sentence`, `context`, `OOC label`\].
+ * For summarisation datasets, transformation was done by randomly pairing summary abstracts with unrelated articles to create OOC pairs, then sentencizing the abstracts to create one example for each abstract sentence (a sketch of this follows the list).
+ * For STS datasets, transformation was done by inserting random sentences from the datasets into one of the sentences in the pair to simulate a long "context". The original labels were mapped to our OOC definition: if the original pair was labelled as dissimilar, we consider the pair OOC.
+
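+ As an illustration, here is a minimal sketch of the summarisation-dataset transformation described above (the STS transformation is analogous); all names are our own and not taken from the RAGulator codebase:
+ ```python
+ import random
+ import spacy
+
+ nlp = spacy.load("en_core_web_sm")  # any sentence splitter works here
+
+ def make_examples(abstracts, articles, seed=0):
+     """Build (sentence, context, OOC label) rows from aligned
+     abstract/article pairs; label 1 = out-of-context (assumed convention)."""
+     rng = random.Random(seed)
+     rows = []
+     for i, abstract in enumerate(abstracts):
+         # pick any article except the one this abstract summarises
+         j = rng.choice([k for k in range(len(articles)) if k != i])
+         for sent in nlp(abstract).sents:
+             rows.append((sent.text, articles[j], 1))  # random pairing -> OOC
+             rows.append((sent.text, articles[i], 0))  # original pairing -> in-context (assumed)
+     return rows
+ ```
+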
+ To enable training of BERT-based classifiers, each training example was split into sub-sequences of at most 512 tokens. The OOC label for each sub-sequence was derived through a generative labelling process with Llama-3.1-70b-Instruct.
+
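+ The exact chunking procedure is not reproduced here, but a minimal sketch with a Hugging Face fast tokenizer is shown below; the `only_second` truncation strategy and the 128-token stride are illustrative assumptions:
+ ```python
+ from transformers import AutoTokenizer
+
+ # loads the fast tokenizer, whose overflow feature returns one encoding per chunk
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
+
+ encoded = tokenizer(
+     "This is the sentence to check.",    # sequence A: kept whole
+     "A very long retrieved context...",  # sequence B: chunked
+     truncation="only_second",            # assumed: truncate the context only
+     max_length=512,
+     stride=128,                          # assumed overlap between chunks
+     return_overflowing_tokens=True,
+ )
+ print(len(encoded["input_ids"]))  # number of <=512-token sub-sequences
+ ```
+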
+ ### Model Training
+ RAGulator is fine-tuned from `microsoft/deberta-v3-large` ([He et al., 2023](https://arxiv.org/pdf/2111.09543.pdf)).
+
+ ### Model Performance
+ <p align="center">
+ <img src="./model-performance.png" width="700">
+ </p>
+
+ We compare our models against an LLM-as-a-judge baseline (Llama-3.1-70b-Instruct). We evaluate on both a held-out split of our simulated RAG dataset and an out-of-distribution collection of private enterprise data, which consists of RAG responses from a real use case.
+
+ The deberta-v3-large variant is our best-performing model, showing a 19% increase in AUROC and a 17% increase in F1 score over the baseline despite being significantly smaller than Llama-3.1.
+
+ ## Basic Usage
+ ```python
+ import torch
+ from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification
+
+ model_path = "./ragulator-deberta-v3-large"  # assuming model folder is located here
+ tokenizer = DebertaV2Tokenizer.from_pretrained(model_path)
+ model = DebertaV2ForSequenceClassification.from_pretrained(
+     model_path,
+     num_labels=2
+ )
+ model.eval()
+
+ # input
+ sentences = ["This is the first sentence", "This is the second sentence"]
+ contexts = ["This is the first context", "This is the second context"]
+ inputs = tokenizer(
+     sentences,
+     contexts,
+     add_special_tokens=True,
+     return_token_type_ids=True,
+     return_attention_mask=True,
+     padding='max_length',
+     max_length=512,
+     truncation='longest_first',
+     return_tensors='pt'
+ )
+
+ # forward pass
+ with torch.no_grad():
+     outputs = model(**inputs)
+
+ # OOC score (probability of the out-of-context class)
+ fn = torch.nn.Softmax(dim=-1)
+ ooc_scores = fn(outputs.logits).cpu().numpy()[:, 1]
+ ```
+
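+ If binary labels are needed, the scores above can be thresholded; the 0.5 cut-off below is an illustrative default, not a value tuned in the paper:
+ ```python
+ # illustrative threshold; tune on validation data for your use case
+ ooc_labels = (ooc_scores >= 0.5).astype(int)  # 1 = out-of-context
+ ```
+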
+ ## Usage - batch and long-context inference
+ We provide a simple wrapper to demonstrate batch inference and accommodation of long-context examples. First, install the package:
+ ```bash
+ pip install "ragulator @ git+https://github.com/ipoeyke/RAGulator.git@main"
+ ```
+ ```python
+ from ragulator import RAGulator
+
+ model = RAGulator(
+     model_variant='deberta-v3-large',  # only value supported for now
+     batch_size=32,
+     device='cpu'
+ )
+
+ # input
+ sentences = ["This is the first sentence", "This is the second sentence"]
+ contexts = ["This is the first context", "This is the second context"]
+
+ # batch inference
+ model.infer_batch(
+     sentences,
+     contexts,
+     return_probas=True  # True for OOC probabilities, False for binary labels
+ )
+ ```
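+
+ As a rough illustration of how a long context can be accommodated within the 512-token limit, one plausible approach is to score the sentence against overlapping context windows and aggregate the per-window scores. The sketch below is our own and not necessarily what `infer_batch` does internally:
+ ```python
+ # `score_fn(sentence, context) -> float` is a hypothetical single-pair OOC scorer,
+ # e.g. the Basic Usage snippet wrapped in a function.
+ def score_long_context(score_fn, sentence, context_windows):
+     # a sentence supported by *any* window is in-context, so keep the
+     # minimum OOC score across windows (an assumed aggregation rule)
+     return min(score_fn(sentence, w) for w in context_windows)
+ ```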
+
+ ## Citation
+ ```
+ @misc{poey2024ragulatorlightweightoutofcontextdetectors,
+       title={RAGulator: Lightweight Out-of-Context Detectors for Grounded Text Generation},
+       author={Ian Poey and Jiajun Liu and Qishuai Zhong and Adrien Chenailler},
+       year={2024},
+       eprint={2411.03920},
+       archivePrefix={arXiv},
+       primaryClass={cs.CL},
+       url={https://arxiv.org/abs/2411.03920},
+ }
+ ```
ragulator-deberta-v3-large.model DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:87f0aa4b5fae329ef9f7090a8312d48b4a4bd7f15f7f498e937414f020bdeacf
- size 1740430495
requirements.txt CHANGED
@@ -1,4 +1,5 @@
- numpy>=1.25.2
- spacy>=3.7.4
- transformers>=4.29.2
- pytorch>=1.13.1
+ numpy==1.25.2
+ sentencepiece==0.2.0
+ spacy==3.7.4
+ torch==1.13.1
+ transformers==4.29.2