FrierenChatbotV1 / handler.py
homer7676's picture
Update handler.py
11ffb79 verified
raw
history blame
3.1 kB
import logging
from typing import Dict, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class EndpointHandler:
    """Inference handler that serves the FrierenChatbotV1 causal language model.

    Two usage styles are supported:

    * Inference Endpoints style: ``EndpointHandler(path)`` loads the model
      immediately, then ``handler(data)`` serves requests.
    * Two-step style: ``EndpointHandler()`` followed by ``initialize(context)``.

    If neither loading path has run, the model is loaded lazily on the first
    request instead of failing on a ``None`` tokenizer/model.
    """

    # Model id or local directory used when ``__init__`` gets no path.
    model_path = "homer7676/FrierenChatbotV1"
    # Prompt token budget. Longer prompts are truncated from the LEFT (see
    # ``initialize``) so the user message and the reply cue at the end survive.
    max_input_tokens = 2048
    # Default ``generate`` settings; a request's "parameters" dict overrides
    # individual keys. Treated as read-only (copied per request).
    default_generation_kwargs = {
        "max_new_tokens": 256,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 50,
        "do_sample": True,
        "repetition_penalty": 1.2,
    }
    _logger = logging.getLogger(__name__)
    # Persona prompt; ``{context}`` and ``{message}`` are filled per request.
    _PROMPT_TEMPLATE = (
        "你是芙莉蓮,需要遵守以下規則回答:\n"
        "1. 身份設定:\n"
        "- 千年精靈魔法師\n"
        "- 態度溫柔但帶著些許嘲諷\n"
        "- 說話優雅且有距離感\n"
        "2. 重要關係:\n"
        "- 弗蘭梅是我的師傅\n"
        "- 費倫是我的學生\n"
        "- 欣梅爾是我的摯友\n"
        "- 海塔是我的故友\n"
        "3. 回答規則:\n"
        "- 使用繁體中文\n"
        "- 必須提供具體詳細的內容\n"
        "- 保持回答的連貫性和完整性\n"
        "相關資訊:{context}\n"
        "用戶:{message}\n"
        "芙莉蓮:"
    )
    # Start of a user turn; if the model invents the next user turn, the
    # reply is cut right before it.
    _USER_TURN_MARKER = "用戶:"

    def __init__(self, path: str = ""):
        """Create the handler.

        Args:
            path: Optional model id or local model directory. When given (as
                Inference Endpoints does), the model is loaded immediately.
                When empty, loading is deferred to ``initialize`` or to the
                first request, and ``model_path`` is used.
        """
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if path:
            self.model_path = path
            self.initialize(None)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request: preprocess -> inference -> postprocess."""
        inputs = self.preprocess(data)
        outputs = self.inference(inputs)
        return self.postprocess(outputs)

    def initialize(self, context: Any = None) -> None:
        """Load the tokenizer and model from ``model_path`` onto ``self.device``.

        Args:
            context: Accepted for compatibility with callers that pass a
                serving context; it is not used.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            trust_remote_code=True,
        )
        # Drop tokens from the start of an over-long prompt rather than from
        # the end, where the user message and the reply cue live.
        self.tokenizer.truncation_side = "left"
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).to(self.device)
        self.model.eval()

    def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a request payload into a dict with a "message" key.

        ``data["inputs"]`` may be a dict (used as-is) or any other value
        (wrapped as ``{"message": value}``). Without an "inputs" key the whole
        payload is used. A top-level "parameters" value is carried along for
        ``inference``. The caller's dict is never mutated.
        """
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            inputs = {"message": inputs}
        parameters = data.get("parameters")
        if parameters and "parameters" not in inputs:
            # Build a new dict so the caller's payload stays untouched.
            inputs = {**inputs, "parameters": parameters}
        return inputs

    def _build_prompt(self, message: Any, context: Any) -> str:
        """Fill the persona prompt template with the request's context and message."""
        return self._PROMPT_TEMPLATE.format(context=context, message=message)

    def inference(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a reply for ``inputs["message"]``.

        Optional keys: "context" (extra background inserted into the prompt)
        and "parameters" (overrides for ``default_generation_kwargs``).

        Returns:
            ``{"generated_text": reply}`` on success, or ``{"error": message}``
            if loading, tokenization, or generation raises.
        """
        try:
            if self.model is None or self.tokenizer is None:
                # Neither __init__(path) nor initialize() ran; load on demand.
                self.initialize(None)

            message = inputs.get("message", "")
            context = inputs.get("context", "")
            prompt = self._build_prompt(message, context)

            # A single prompt needs no padding, and requesting it raises for
            # tokenizers that define no pad token.
            model_inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_input_tokens,
            ).to(self.device)
            prompt_length = model_inputs["input_ids"].shape[-1]

            eos_token_id = self.tokenizer.eos_token_id
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = eos_token_id
            generation_kwargs = {
                **self.default_generation_kwargs,
                "pad_token_id": pad_token_id,
                "eos_token_id": eos_token_id,
                **(inputs.get("parameters") or {}),
            }

            with torch.no_grad():
                outputs = self.model.generate(**model_inputs, **generation_kwargs)

            # Decode only the newly generated tokens, so marker text inside the
            # prompt or the user's message cannot leak into the reply.
            response = self.tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True,
            )
            response = response.split(self._USER_TURN_MARKER, 1)[0].strip()
            return {"generated_text": response}
        except Exception as e:
            # Request boundary: log with traceback and report the error to the
            # caller instead of raising.
            self._logger.exception("推理過程錯誤: %s", e)
            return {"error": str(e)}

    def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the inference output unchanged (hook for future formatting)."""
        return data