logasanjeev commited on
Commit
087f7cd
·
verified ·
1 Parent(s): ca4dd76

Add inference script

Browse files
Files changed (1) hide show
  1. inference.py +18 -0
inference.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertForSequenceClassification, BertTokenizer
2
+ import torch
3
+ import json
4
+ import requests
5
+
6
+ def predict(text):
7
+ repo_id = "logasanjeev/goemotions-bert"
8
+ model = BertForSequenceClassification.from_pretrained(repo_id)
9
+ tokenizer = BertTokenizer.from_pretrained(repo_id)
10
+ thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
11
+ thresholds_data = json.loads(requests.get(thresholds_url).text)
12
+ emotion_labels = thresholds_data["emotion_labels"]
13
+ thresholds = thresholds_data["thresholds"]
14
+ encodings = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
15
+ with torch.no_grad():
16
+ logits = torch.sigmoid(model(**encodings).logits).numpy()[0]
17
+ predictions = [{"label": emotion_labels[i], "score": float(logit)} for i, (logit, thresh) in enumerate(zip(logits, thresholds)) if logit >= thresh]
18
+ return sorted(predictions, key=lambda x: x["score"], reverse=True)