import json

from bertopic import BERTopic


class EndpointHandler:
    """Request handler that assigns a BERTopic topic to the text of a request.

    A request is a mapping whose ``"inputs"`` entry holds the text. All lines
    of that text are merged into a single document, the model is asked for the
    topic of that document, and the response is a list of result dictionaries
    (or a single ``{"error": ...}`` dictionary when any step fails).
    """

    # Offset between a topic id and its position in ``custom_labels_``.
    # NOTE(review): this keeps the original ``topic + 1`` lookup, which
    # presumably accounts for the outlier topic (-1) sitting at index 0 --
    # confirm against the loaded model.
    _LABEL_OFFSET = 1

    def __init__(self, model_path="SCANSKY/BERTopic-Tourism-Chinese"):
        """Load the BERTopic model used for every request.

        Args:
            model_path: Identifier or path handed to ``BERTopic.load``.
        """
        self.topic_model = BERTopic.load(model_path)

    def preprocess(self, data):
        """Extract the text to analyse from the request payload.

        Args:
            data: Request mapping. The text is read from its ``"inputs"``
                key; a missing key is treated as an empty string.

        Returns:
            The input text as one string. A list or tuple of strings is
            joined with newlines (one sentence per line).

        Raises:
            ValueError: If ``"inputs"`` cannot be read from ``data``, or if it
                is neither a string nor a list/tuple of strings.
        """
        try:
            text_input = data.get("inputs", "")
        except Exception as e:
            raise ValueError(f"Error during preprocessing: {e}") from e

        if isinstance(text_input, str):
            return text_input

        is_string_sequence = isinstance(text_input, (list, tuple)) and all(
            isinstance(item, str) for item in text_input
        )
        if is_string_sequence:
            # Downstream treats every line as a sentence, so a sequence of
            # sentences maps naturally onto newline-separated text.
            return "\n".join(text_input)

        # Fail here with a clear message instead of letting a later
        # ``.strip()`` call produce an opaque AttributeError.
        raise ValueError(
            "Error during preprocessing: 'inputs' must be a string or a list "
            f"of strings, got {type(text_input).__name__}"
        )

    @staticmethod
    def _scalar_probability(prob):
        """Reduce one per-document probability entry to a float, or None."""
        if prob is None:
            return None
        try:
            return float(prob)
        except (TypeError, ValueError):
            # ``prob`` is not a scalar.
            # NOTE(review): presumably a per-topic distribution, as returned
            # when the model computes probabilities for every topic; the
            # largest entry is reported as the document's score -- confirm
            # this matches the assigned topic for the deployed model.
            values = list(prob)
            return float(max(values)) if values else None

    def inference(self, text_input):
        """Find the topic shared by all sentences of ``text_input``.

        The text is split into lines (one sentence per line), blank lines are
        dropped, and the remaining sentences are joined with spaces into one
        document that is passed to the model's ``transform``.

        Args:
            text_input: Text to analyse, one sentence per line.

        Returns:
            A list with one dictionary per predicted document, holding
            ``"topic"`` (int), ``"probability"`` (float, or None when the
            model returns no probabilities), ``"top_words"`` (up to five
            words) and ``"customLabel"``.

        Raises:
            ValueError: If any step of the inference fails; the original
                exception is chained as the cause.
        """
        try:
            # splitlines() also handles "\r\n", so no stray "\r" ends up in
            # the document; blank lines contribute nothing and are skipped.
            stripped_lines = (line.strip() for line in text_input.splitlines())
            combined_document = " ".join(line for line in stripped_lines if line)

            topics, probabilities = self.topic_model.transform([combined_document])
            if probabilities is None:
                # Keep zip() below working when no probabilities come back.
                probabilities = [None] * len(topics)

            custom_labels = getattr(self.topic_model, "custom_labels_", None)

            results = []
            for topic, prob in zip(topics, probabilities):
                topic_id = int(topic)

                # get_topic() may return a falsy value for an unknown topic.
                topic_info = self.topic_model.get_topic(topic_id)
                top_words = [word for word, _ in topic_info[:5]] if topic_info else []

                label_index = topic_id + self._LABEL_OFFSET
                if custom_labels is not None and 0 <= label_index < len(custom_labels):
                    custom_label = custom_labels[label_index]
                else:
                    # No custom label available for this topic.
                    custom_label = f"Topic {topic_id}"

                results.append({
                    "topic": topic_id,
                    "probability": self._scalar_probability(prob),
                    "top_words": top_words,
                    "customLabel": custom_label,
                })

            return results
        except Exception as e:
            raise ValueError(f"Error during inference: {e}") from e

    def postprocess(self, results):
        """Return the inference results unchanged.

        The results are already a list of plain dictionaries, so no further
        conversion is applied.
        """
        return results

    def __call__(self, data):
        """Handle one request end to end.

        Args:
            data: Request mapping, see ``preprocess``.

        Returns:
            The list produced by ``inference``, or ``[{"error": message}]``
            when any stage raises.
        """
        try:
            text_input = self.preprocess(data)
            results = self.inference(text_input)
            return self.postprocess(results)
        except Exception as e:
            # Request boundary: report the failure in the response body rather
            # than letting the exception escape to the caller.
            return [{"error": str(e)}]