ONNX
English
Shing Yee committed on
Commit
ba6803f
·
unverified ·
1 Parent(s): e901653

feat: add files

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv
2
+ .DS_Store
README.md CHANGED
@@ -2,4 +2,49 @@
2
  license: other
3
  license_name: govtech-singapore
4
  license_link: LICENSE
5
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  license: other
3
  license_name: govtech-singapore
4
  license_link: LICENSE
5
+ ---
6
+
7
+ # Off-Topic Classification Model
8
+
9
+ This repository contains a fine-tuned **Jina Embeddings model** designed to perform binary classification. The model predicts whether a user prompt is **off-topic** based on the intended purpose defined in the system prompt.
10
+
11
+ ## Model Highlights
12
+
13
+ - **Base Model**: [`jina-embeddings-v2-small-en`](https://huggingface.co/jinaai/jina-embeddings-v2-small-en)
14
+ - **Maximum Context Length**: 1024 tokens
15
+ - **Task**: Binary classification (on-topic/off-topic)
16
+
17
+ ## Performance
18
+
19
+ | Approach | Model | ROC-AUC | F1 | Precision | Recall |
20
+ |---------------------------------------|--------------------------------|---------|------|-----------|--------|
21
+ | Fine-tuned bi-encoder classifier | jina-embeddings-v2-small-en | 0.99 | 0.97 | 0.99 | 0.95 |
22
+
23
+ ## Usage
24
+ 1. Clone this repository and install the required dependencies:
25
+
26
+ ```bash
27
+ pip install -r requirements.txt
28
+ ```
29
+
30
+ 2. You can run the model using two options:
31
+
32
+ **Option 1**: Using `inference_onnx.py` with the ONNX Model.
33
+
34
+ ```
35
+ python inference_onnx.py '[
36
+ ["System prompt example 1", "User prompt example 1"],
37
+ ["System prompt example 2", "User prompt example 2"]
38
+ ]'
39
+ ```
40
+
41
+ **Option 2**: Using `inference_safetensors.py` with PyTorch and SafeTensors.
42
+
43
+ ```
44
+ python inference_safetensors.py '[
45
+ ["System prompt example 1", "User prompt example 1"],
46
+ ["System prompt example 2", "User prompt example 2"]
47
+ ]'
48
+ ```
49
+
50
+ Read more about this model in our [technical report](https://arxiv.org/abs/2411.12946).
config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "description": "Off-Topic classifier designed to block user prompts that do not align with the intended purpose of the system, as determined by the system prompt.",
3
+ "classifier": {
4
+ "embedding": {
5
+ "model_name": "jinaai/jina-embeddings-v2-small-en",
6
+ "max_length": 1024,
7
+ "model_weights_fp": "models/off-topic-jinaai-jina-embeddings-v2-small-en-TwinEncoder.safetensors",
8
+ "model_fp": "models/off-topic-jinaai-jina-embeddings-v2-small-en-TwinEncoder.onnx"
9
+ }
10
+ }
11
+ }
inference_onnx.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference_onnx.py
3
+
4
+ This script leverages ONNX runtime to perform inference with a pre-trained model.
5
+ """
6
+ import json
7
+ import torch
8
+ import sys
9
+ import numpy as np
10
+ import onnxruntime as rt
11
+
12
+ from huggingface_hub import hf_hub_download
13
+ from transformers import AutoTokenizer
14
+
15
repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"

# Prefer a config.json in the current working directory (the cloned-repo
# workflow) and only fall back to the Hub when it is missing. Previously the
# downloaded path was immediately overwritten by the local one, so the
# download was wasted and a missing local file crashed the script.
config_path = "config.json"
try:
    with open(config_path, 'r', encoding="utf-8") as f:
        config = json.load(f)
except FileNotFoundError:
    config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
    with open(config_path, 'r', encoding="utf-8") as f:
        config = json.load(f)
21
+
22
# Lazily-populated cache so the tokenizer and the ONNX session are built once
# per process instead of once per sentence pair.
_ONNX_RESOURCES = {}


def _get_onnx_resources():
    """Return the cached ``(tokenizer, session, max_length)`` triple, building it on first use."""
    if 'session' not in _ONNX_RESOURCES:
        embedding_cfg = config['classifier']['embedding']
        tokenizer = AutoTokenizer.from_pretrained(embedding_cfg['model_name'])
        # Download (or reuse the locally cached copy of) the ONNX classifier.
        local_model_fp = hf_hub_download(repo_id=repo_path, filename=embedding_cfg['model_fp'])
        session = rt.InferenceSession(local_model_fp)
        # Store everything at once so a failure above never leaves a
        # half-initialised cache behind.
        _ONNX_RESOURCES.update(
            tokenizer=tokenizer,
            session=session,
            max_length=embedding_cfg['max_length'],
        )
    return (
        _ONNX_RESOURCES['tokenizer'],
        _ONNX_RESOURCES['session'],
        _ONNX_RESOURCES['max_length'],
    )


def predict(sentence1, sentence2):
    """
    Predicts the label for a pair of sentences using a fine-tuned ONNX model.

    Both sentences are tokenized (truncated and padded to the configured
    ``max_length``), fed to the ONNX session, and the first model output is
    turned into class probabilities with a softmax over axis 1.

    Args:
    - sentence1 (str): The first input sentence.
    - sentence2 (str): The second input sentence.

    Returns:
    tuple:
    - predicted_label (int): Index of the most probable class.
      NOTE(review): which index means "off-topic" is not visible here — confirm.
    - probabilities (numpy.ndarray): The probabilities for each class.
    """
    tokenizer, session, max_length = _get_onnx_resources()

    # Tokenize on CPU; ONNX Runtime consumes NumPy arrays, so there is no
    # benefit in moving the tensors to a GPU first.
    inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)
    inputs2 = tokenizer(sentence2, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)

    # The session inputs are bound by position: ids/mask of sentence 1, then
    # ids/mask of sentence 2.
    session_inputs = session.get_inputs()
    onnx_inputs = {
        session_inputs[0].name: inputs1['input_ids'].numpy(),
        session_inputs[1].name: inputs1['attention_mask'].numpy(),
        session_inputs[2].name: inputs2['input_ids'].numpy(),
        session_inputs[3].name: inputs2['attention_mask'].numpy(),
    }

    # Run inference
    outputs = session.run(None, onnx_inputs)
    probabilities = torch.softmax(torch.tensor(outputs[0]), dim=1)
    predicted_label = torch.argmax(probabilities, dim=1).item()

    return predicted_label, probabilities.cpu().numpy()
72
+
73
if __name__ == "__main__":
    # A missing argument used to surface as a bare IndexError; exit with a
    # usage hint instead.
    if len(sys.argv) < 2:
        sys.exit(
            "Usage: python inference_onnx.py "
            "'[[\"system prompt\", \"user prompt\"], ...]'"
        )

    # Load data
    sentence_pairs = json.loads(sys.argv[1])

    # Validate input data format: a JSON list of two-string lists. Checking
    # the container type and length first keeps malformed input from raising
    # IndexError/TypeError instead of the intended ValueError.
    if not isinstance(sentence_pairs, list):
        raise ValueError("Input must be a JSON list of [sentence1, sentence2] pairs.")
    for pair in sentence_pairs:
        is_valid_pair = (
            isinstance(pair, (list, tuple))
            and len(pair) == 2
            and all(isinstance(sentence, str) for sentence in pair)
        )
        if not is_valid_pair:
            raise ValueError("Each pair must contain two strings.")

    for pair_number, (sentence1, sentence2) in enumerate(sentence_pairs, start=1):
        # Generate prediction and scores
        predicted_label, probabilities = predict(sentence1, sentence2)

        # Print the results
        print(f"Pair {pair_number}:")
        print(f" Sentence 1: {sentence1}")
        print(f" Sentence 2: {sentence2}")
        print(f" Predicted Label: {predicted_label}")
        print(f" Probabilities: {probabilities}")
        print('-' * 50)
inference_safetensors.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference_safetensors.py
3
+
4
+ Defines the architecture of the fine-tuned embedding model used for Off-Topic classification.
5
+ """
6
+ import json
7
+ import torch
8
+ import sys
9
+ import torch.nn as nn
10
+
11
+ from huggingface_hub import hf_hub_download
12
+ from safetensors.torch import load_file
13
+ from transformers import AutoTokenizer, AutoModel
14
+
15
+ # Adapter for embeddings
16
class Adapter(nn.Module):
    """Bottleneck block added on top of the embeddings.

    Projects the input down to half its width, applies a ReLU, projects it
    back up to the original width, and adds the input back as a residual.
    """

    def __init__(self, hidden_size):
        super().__init__()
        bottleneck_size = hidden_size // 2
        self.down_project = nn.Linear(hidden_size, bottleneck_size)
        self.activation = nn.ReLU()
        self.up_project = nn.Linear(bottleneck_size, hidden_size)

    def forward(self, x):
        """Return ``up_project(relu(down_project(x))) + x``."""
        bottleneck = self.activation(self.down_project(x))
        return self.up_project(bottleneck) + x  # Residual connection
28
+
29
+ # Pool by attention score
30
class AttentionPooling(nn.Module):
    """Collapse the first axis of a tensor with a learned softmax weighting.

    A single learnable vector scores every position; the scores are
    normalised with a softmax over axis 0 and used for a weighted sum over
    that same axis.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attention_weights = nn.Parameter(torch.randn(hidden_size))

    def forward(self, hidden_states):
        """Pool ``hidden_states`` of shape [seq_len, batch_size, hidden_size] over axis 0."""
        position_scores = hidden_states @ self.attention_weights
        position_probs = position_scores.softmax(dim=0)
        return (position_probs.unsqueeze(-1) * hidden_states).sum(dim=0)
41
+
42
+ # Custom bi-encoder model with MLP layers for interaction
43
class CrossEncoderWithSharedBase(nn.Module):
    """Pair classifier built on one shared encoder.

    Both inputs go through the same pre-trained encoder, then through their
    own adapter. The two sequences attend to each other in both directions,
    each attended sequence is pooled, and the concatenated pooled vectors are
    projected and classified.

    Attribute names are part of the checkpoint format (they become the
    state-dict keys) and must not be renamed.
    """

    def __init__(self, base_model, num_labels=2, num_heads=8):
        super().__init__()
        # Shared pre-trained model
        self.shared_encoder = base_model
        width = self.shared_encoder.config.hidden_size
        # Sentence-specific adapters
        self.adapter1 = Adapter(width)
        self.adapter2 = Adapter(width)
        # Cross-attention layers, one per direction
        self.cross_attention_1_to_2 = nn.MultiheadAttention(width, num_heads)
        self.cross_attention_2_to_1 = nn.MultiheadAttention(width, num_heads)
        # Attention pooling layers
        self.attn_pooling_1_to_2 = AttentionPooling(width)
        self.attn_pooling_2_to_1 = AttentionPooling(width)
        # Projection layer with non-linearity: 2 * width -> width
        self.projection_layer = nn.Sequential(
            nn.Linear(width * 2, width),
            nn.ReLU(),
        )
        # Classifier head: three linear layers (two hidden, one output)
        half_width = width // 2
        quarter_width = width // 4
        self.classifier = nn.Sequential(
            nn.Linear(width, half_width),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(half_width, quarter_width),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(quarter_width, num_labels),
        )

    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        """Return the classification logits for a pair of tokenized inputs."""
        # Encode sentences with the shared encoder
        hidden1 = self.shared_encoder(input_ids1, attention_mask=attention_mask1).last_hidden_state
        hidden2 = self.shared_encoder(input_ids2, attention_mask=attention_mask2).last_hidden_state
        # Apply sentence-specific adapters, then swap the first two axes for
        # the attention layers (assumes the encoder output is batch-first)
        seq1 = self.adapter1(hidden1).transpose(0, 1)
        seq2 = self.adapter2(hidden2).transpose(0, 1)
        # Cross-attention: each sequence queries the other
        attended_1_to_2, _ = self.cross_attention_1_to_2(seq1, seq2, seq2)
        attended_2_to_1, _ = self.cross_attention_2_to_1(seq2, seq1, seq1)
        # Attention pooling, then concatenate along the feature axis
        fused = torch.cat(
            (
                self.attn_pooling_1_to_2(attended_1_to_2),
                self.attn_pooling_2_to_1(attended_2_to_1),
            ),
            dim=1,
        )
        # Project and classify
        return self.classifier(self.projection_layer(fused))
95
+
96
+ # Load configuration file
97
repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"

# Prefer a config.json in the current working directory (the cloned-repo
# workflow) and only fall back to the Hub when it is missing. Previously the
# downloaded path was immediately overwritten by the local one, so the
# download was wasted and a missing local file crashed the script.
config_path = "config.json"
try:
    with open(config_path, 'r', encoding="utf-8") as f:
        config = json.load(f)
except FileNotFoundError:
    config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
    with open(config_path, 'r', encoding="utf-8") as f:
        config = json.load(f)
103
+
104
# Lazily-populated cache so the tokenizer, base model and fine-tuned weights
# are loaded once per process instead of once per sentence pair.
_SAFETENSORS_RESOURCES = {}


def _get_safetensors_resources():
    """Return the cached ``(tokenizer, model, device, max_length)``, building it on first use."""
    if 'model' not in _SAFETENSORS_RESOURCES:
        embedding_cfg = config['classifier']['embedding']
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load tokenizer and base model
        tokenizer = AutoTokenizer.from_pretrained(embedding_cfg['model_name'])
        base_model = AutoModel.from_pretrained(embedding_cfg['model_name'])
        model = CrossEncoderWithSharedBase(base_model, num_labels=2)

        # Load the fine-tuned weights into the model
        model.load_state_dict(load_file(embedding_cfg['model_weights_fp']))
        model.to(device)
        model.eval()

        # Store everything at once so a failure above never leaves a
        # half-initialised cache behind.
        _SAFETENSORS_RESOURCES.update(
            tokenizer=tokenizer,
            model=model,
            device=device,
            max_length=embedding_cfg['max_length'],
        )
    return (
        _SAFETENSORS_RESOURCES['tokenizer'],
        _SAFETENSORS_RESOURCES['model'],
        _SAFETENSORS_RESOURCES['device'],
        _SAFETENSORS_RESOURCES['max_length'],
    )


def predict(sentence1, sentence2):
    """
    Predicts the label for a pair of sentences using a fine-tuned model with SafeTensors weights.

    Both sentences are tokenized (truncated and padded to the configured
    ``max_length``), run through the model without gradient tracking, and the
    logits are turned into class probabilities with a softmax over axis 1.

    Args:
    - sentence1 (str): The first input sentence.
    - sentence2 (str): The second input sentence.

    Returns:
    tuple:
    - predicted_label (int): Index of the most probable class.
      NOTE(review): which index means "off-topic" is not visible here — confirm.
    - probabilities (numpy.ndarray): The probabilities for each class.
    """
    tokenizer, model, device, max_length = _get_safetensors_resources()

    # Get inputs
    inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)
    inputs2 = tokenizer(sentence2, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)

    # Get outputs
    with torch.no_grad():
        outputs = model(
            input_ids1=inputs1['input_ids'].to(device),
            attention_mask1=inputs1['attention_mask'].to(device),
            input_ids2=inputs2['input_ids'].to(device),
            attention_mask2=inputs2['attention_mask'].to(device),
        )
        probabilities = torch.softmax(outputs, dim=1)
    predicted_label = torch.argmax(probabilities, dim=1).item()

    return predicted_label, probabilities.cpu().numpy()
150
+
151
if __name__ == "__main__":
    # A missing argument used to surface as a bare IndexError; exit with a
    # usage hint instead.
    if len(sys.argv) < 2:
        sys.exit(
            "Usage: python inference_safetensors.py "
            "'[[\"system prompt\", \"user prompt\"], ...]'"
        )

    # Load data
    sentence_pairs = json.loads(sys.argv[1])

    # Validate input data format: a JSON list of two-string lists. Checking
    # the container type and length first keeps malformed input from raising
    # IndexError/TypeError instead of the intended ValueError.
    if not isinstance(sentence_pairs, list):
        raise ValueError("Input must be a JSON list of [sentence1, sentence2] pairs.")
    for pair in sentence_pairs:
        is_valid_pair = (
            isinstance(pair, (list, tuple))
            and len(pair) == 2
            and all(isinstance(sentence, str) for sentence in pair)
        )
        if not is_valid_pair:
            raise ValueError("Each pair must contain two strings.")

    for pair_number, (sentence1, sentence2) in enumerate(sentence_pairs, start=1):
        # Generate prediction and scores
        predicted_label, probabilities = predict(sentence1, sentence2)

        # Print the results
        print(f"Pair {pair_number}:")
        print(f" Sentence 1: {sentence1}")
        print(f" Sentence 2: {sentence2}")
        print(f" Predicted Label: {predicted_label}")
        print(f" Probabilities: {probabilities}")
        print('-' * 50)
models/off-topic-jinaai-jina-embeddings-v2-small-en-TwinEncoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61f616f540ea408e918e9a5c30b770071bc473c75a240d831a71a7309724a890
3
+ size 126521473
govtech-jina-embeddings-v2-small-en-off-topic → models/off-topic-jinaai-jina-embeddings-v2-small-en-TwinEncoder.safetensors RENAMED
File without changes
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ huggingface_hub==0.26.2
2
+ numpy==2.1.3
3
+ onnxruntime==1.20.0
4
+ safetensors==0.4.5
5
+ torch==2.5.1
6
+ transformers==4.46.3