---
license: mit
language:
- en
- pt
base_model:
- cnmoro/tangled-llama-33m-32k-instruct-v0.1-fix
pipeline_tag: text-ranking
library_name: sentence-transformers
---

```python
from pprint import pprint

import numpy as np
import onnxruntime as ort
from tokenizers import Tokenizer


# Load the tokenizer and the ONNX model once at module level; rerank() reuses both.
reranker_tokenizer = Tokenizer.from_file('./tokenizer.json')

# Name the execution provider explicitly: since onnxruntime 1.9, builds that ship
# more than one provider (e.g. onnxruntime-gpu) raise ValueError when `providers`
# is omitted. With onnxruntime-gpu, put "CUDAExecutionProvider" first to use the GPU.
reranker_session = ort.InferenceSession('./model.onnx', providers=['CPUExecutionProvider'])


def rerank(question, passages, normalize_scores=True, *, max_length=32768,
           tokenizer=None, session=None):
    """Score ``passages`` against ``question`` and return them best-first.

    Each passage is formatted as the two-line template
    ``Query: <question>`` / ``Sentence: <passage>``, tokenized, truncated to
    ``max_length`` tokens, right-padded into one batch and scored with a single
    ``session.run`` call.

    Args:
        question: Query text.
        passages: Iterable of candidate passage strings.
        normalize_scores: If True, rescale scores so they sum to 1 (skipped
            when the total is 0).
        max_length: Maximum number of tokens kept per query/passage pair;
            ``None`` disables truncation.
        tokenizer: Object with an ``encode_batch`` method returning items with
            an ``ids`` list; defaults to the module-level ``reranker_tokenizer``.
        session: Object with a ``run(output_names, feeds)`` method; defaults to
            the module-level ``reranker_session``.

    Returns:
        List of ``(passage, score)`` tuples sorted by descending score, or an
        empty list when ``passages`` is empty.
    """
    passages = list(passages)
    if not passages:
        # Nothing to rank; also avoids max() over an empty batch below.
        return []

    if tokenizer is None:
        tokenizer = reranker_tokenizer
    if session is None:
        session = reranker_session

    templates = [f"Query: {question}\nSentence: {passage}" for passage in passages]
    encodings = tokenizer.encode_batch(templates)
    token_ids = [enc.ids[:max_length] for enc in encodings]

    # Right-pad with id 0; padded positions are masked out via attention_mask.
    batch_max_length = max(len(ids) for ids in token_ids)
    input_ids = np.zeros((len(token_ids), batch_max_length), dtype=np.int64)
    attention_mask = np.zeros_like(input_ids)
    for row, ids in enumerate(token_ids):
        input_ids[row, :len(ids)] = ids
        attention_mask[row, :len(ids)] = 1

    outputs = session.run(None, {"input_ids": input_ids, "attention_mask": attention_mask})
    logits = outputs[0]

    # Softmax over the class axis; subtracting the row max keeps exp() from
    # overflowing to inf (which would turn the probabilities into nan).
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exp_logits = np.exp(shifted)
    probabilities = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)

    predicted_classes = np.argmax(probabilities, axis=1).tolist()
    confidences = np.max(probabilities, axis=1).tolist()

    # When class 0 wins, 1 - confidence is the relevance score; for a
    # two-logit output this equals the class-1 probability.
    scores = [
        1 - conf if pred == 0 else conf
        for pred, conf in zip(predicted_classes, confidences)
    ]

    ranked = sorted(zip(passages, scores), key=lambda item: item[1], reverse=True)

    if normalize_scores:
        total_score = sum(score for _, score in ranked)
        if total_score > 0:
            ranked = [(passage, score / total_score) for passage, score in ranked]

    return ranked


question = "O que é o Pantanal?"
passages = [
    "É um dos ecossistemas mais ricos em biodiversidade do mundo, abrigando uma grande variedade de espécies animais e vegetais.",
    "Sua beleza natural, com rios e lagos interligados, atrai turistas de todo o mundo.",
    "O Pantanal sofre com impactos ambientais, como a exploração mineral e o desmatamento.",
    "O Pantanal é uma extensa planície alagável localizada na América do Sul, principalmente no Brasil, mas também em partes da Bolívia e Paraguai.",
    "É um local com importância histórica e cultural para as populações locais.",
    "O Pantanal é um importante habitat para diversas espécies de animais, inclusive aves migratórias."
]

ranked_results = rerank(question, passages, normalize_scores=True)
# A bare expression only echoes in a REPL/notebook; pprint also prints when run as a script.
pprint(ranked_results)
# [('O Pantanal é uma extensa planície alagável localizada na América do Sul, principalmente no Brasil, mas também em partes da Bolívia e Paraguai.',
#   0.7105862286443647),
#  ('O Pantanal é um importante habitat para diversas espécies de animais, inclusive aves migratórias.',
#   0.22660008031497725),
#  ('O Pantanal sofre com impactos ambientais, como a exploração mineral e o desmatamento.',
#   0.043374300040060654),
#  ('É um local com importância histórica e cultural para as populações locais.',
#   0.0070428120274147726),
#  ('É um dos ecossistemas mais ricos em biodiversidade do mundo, abrigando uma grande variedade de espécies animais e vegetais.',
#   0.006359544027065005),
#  ('Sua beleza natural, com rios e lagos interligados, atrai turistas de todo o mundo.',
#   0.006037034946117598)]


question = "What is the speed of light?"
passages = [
    "Isaac Newton's laws of motion and gravity laid the groundwork for classical mechanics.",
    "The theory of relativity, proposed by Albert Einstein, has revolutionized our understanding of space, time, and gravity.",
    "The Earth orbits the Sun at an average distance of about 93 million miles, taking roughly 365.25 days to complete one revolution.",
    "The speed of light in a vacuum is approximately 299,792 kilometers per second (km/s), or about 186,282 miles per second.",
    "Light can be described as both a wave and a particle, a concept known as wave-particle duality."
]

ranked_results = rerank(question, passages, normalize_scores=True)
# A bare expression only echoes in a REPL/notebook; pprint also prints when run as a script.
pprint(ranked_results)
# [('The speed of light in a vacuum is approximately 299,792 kilometers per second (km/s), or about 186,282 miles per second.',
#   0.5686758878772575),
#  ('The theory of relativity, proposed by Albert Einstein, has revolutionized our understanding of space, time, and gravity.',
#   0.14584055128478327),
#  ('The Earth orbits the Sun at an average distance of about 93 million miles, taking roughly 365.25 days to complete one revolution.',
#   0.13790743024424898),
#  ("Isaac Newton's laws of motion and gravity laid the groundwork for classical mechanics.",
#   0.08071345159269593),
#  ('Light can be described as both a wave and a particle, a concept known as wave-particle duality.',
#   0.06686267900101434)]

```