SCANSKY commited on
Commit
2d39da4
·
verified ·
1 Parent(s): 56328a9

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +80 -0
handler.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from bertopic import BERTopic
3
+
4
+ class EndpointHandler:
5
+ def __init__(self, model_path="SCANSKY/BERTopic-Tourism-Hindi"):
6
+ """
7
+ Initialize the handler. Load the BERTopic model from Hugging Face.
8
+ """
9
+ self.topic_model = BERTopic.load(model_path)
10
+
11
+ def preprocess(self, data):
12
+ """
13
+ Preprocess the incoming request data.
14
+ - Extract text input from the request.
15
+ """
16
+ try:
17
+ # Directly work with the incoming data dictionary
18
+ text_input = data.get("inputs", "")
19
+ return text_input
20
+ except Exception as e:
21
+ raise ValueError(f"Error during preprocessing: {str(e)}")
22
+
23
+ def inference(self, text_input):
24
+ """
25
+ Perform inference using the BERTopic model.
26
+ - Process the text input and generate topic predictions.
27
+ """
28
+ try:
29
+ # Split text into documents (assuming one document per line)
30
+ docs = text_input.strip().split('\n')
31
+
32
+ # Perform topic inference
33
+ topics, probabilities = self.topic_model.transform(docs)
34
+
35
+ # Prepare the results
36
+ results = []
37
+ for topic, prob in zip(topics, probabilities):
38
+ topic_info = self.topic_model.get_topic(topic)
39
+ topic_words = [word for word, _ in topic_info] if topic_info else []
40
+
41
+ # Get custom label for the topic (with fallback if custom_labels_ is not available)
42
+ if hasattr(self.topic_model, "custom_labels_") and self.topic_model.custom_labels_ is not None:
43
+ custom_label = self.topic_model.custom_labels_[topic + 1]
44
+ else:
45
+ custom_label = f"Topic {topic}" # Fallback label
46
+
47
+ results.append({
48
+ "topic": int(topic),
49
+ "probability": float(prob),
50
+ "top_words": topic_words[:5], # Top 5 words
51
+ "customLabel": custom_label # Add custom label
52
+ })
53
+
54
+ return results
55
+ except Exception as e:
56
+ raise ValueError(f"Error during inference: {str(e)}")
57
+
58
+ def postprocess(self, results):
59
+ """
60
+ Postprocess the inference results into a JSON-serializable list.
61
+ """
62
+ return results # Directly returning the list of results
63
+
64
+ def __call__(self, data):
65
+ """
66
+ Handle the incoming request.
67
+ """
68
+ try:
69
+ # Preprocess the data
70
+ text_input = self.preprocess(data)
71
+
72
+ # Perform inference
73
+ results = self.inference(text_input)
74
+
75
+ # Postprocess the results
76
+ response = self.postprocess(results)
77
+
78
+ return response
79
+ except Exception as e:
80
+ return [{"error": str(e)}] # Returning error as a list with a dictionary