import json
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

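# Custom handler for Hugging Face Inference Endpoints: the service instantiates
# EndpointHandler with the model repository path at startup and calls it with each
# deserialized request payload, expecting a JSON-serializable response.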
class EndpointHandler:

    def __init__(self, path=""):
        self.tokenizer = None
        self.model = None
        self.device = None
        self.load_model(path)

    def load_model(self, model_dir):
        # Prefer a CUDA device when available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_path = model_dir
        # trust_remote_code=True allows custom model/tokenizer code shipped with the checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(self.device)
        self.model.eval()
        print(f"Tokenizer and Model loaded from: {model_path} to device: {self.device}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # The request payload carries the chat history as a JSON string under "inputs".
        inputs = data.pop("inputs", data)
        print(f"Received inputs: {inputs}")
        if not inputs:
            raise ValueError("Input text is missing in the request. Please provide 'inputs' in your request.")

        history = json.loads(inputs)
        print(f"Parsed chat history: {history}")

        # Render the chat history with the model's chat template and tokenize it.
        conversation = self.tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False)
        encoding = self.tokenizer(conversation, return_tensors="pt").to(self.device)
        print("Encoding succeeded")

        with torch.no_grad():
            output = self.model.generate(
                **encoding,
                max_new_tokens=1024,
                temperature=1.5,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        print("Generation succeeded")

        # Decode the full sequence (prompt plus newly generated tokens).
        generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)

        return [{"response": generated_text}]
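# Local smoke test (an illustrative sketch, not part of the Inference Endpoints
# request flow). The model directory "./model" and the sample chat history below
# are assumptions made here for demonstration only.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    payload = {"inputs": json.dumps([{"role": "user", "content": "Hello, who are you?"}])}
    print(handler(payload))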