jasminsongsimin committed
Commit f0754d3 · verified · 1 Parent(s): f25c9c8

Upload handler.py

Files changed (1)
1. handler.py +68 -0
handler.py ADDED
@@ -0,0 +1,68 @@
+ import os
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+
+ # Define the model handler class
+ class ModelHandler(object):
+     def __init__(self):
+         self.tokenizer = None
+         self.model = None
+         self.device = None
+
+     def load_model(self, model_dir):
+         # Use the GPU when available, otherwise fall back to the CPU
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         model_path = model_dir
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         self.model = AutoModelForCausalLM.from_pretrained(model_path).to(self.device)
+         self.model.eval()
+
+         print(f"Tokenizer and Model loaded from: {model_path} to device: {self.device}")
+
+     def preprocess(self, request):
+         # Accept the prompt under either the 'inputs' or the 'text' key
+         input_text = request.get("inputs", request.get("text"))
+         if not input_text:
+             raise ValueError("Input text is missing in the request. Please provide 'inputs' or 'text' in your request.")
+
+         # Wrap the prompt in a single-turn chat history, render it with the
+         # model's chat template, and tokenize it on the target device
+         history = [{"role": "user", "content": input_text}]
+         conversation = self.tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False)
+         encoding = self.tokenizer(conversation, return_tensors="pt").to(self.device)
+         return encoding
+
+     def predict(self, model_input):
+         # Inference only, so gradient tracking is disabled
+         with torch.no_grad():
+             output = self.model.generate(
+                 **model_input,
+                 max_new_tokens=1024,
+                 temperature=1.5,
+                 do_sample=True,
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+         return output
+
+     def postprocess(self, prediction):
+         # Decode the generated token ids (prompt included) into plain text
+         generated_text = self.tokenizer.decode(prediction[0], skip_special_tokens=True)
+         return {"response": generated_text}
+
+
+ _service = ModelHandler()
+
+
+ def load():
+     model_dir = '/home/aistudio/export'
+     _service.load_model(model_dir)
+
+
+ def preprocess(request):
+     return _service.preprocess(request)
+
+
+ def predict(data):
+     return _service.predict(data)
+
+
+ def postprocess(prediction):
+     return _service.postprocess(prediction)
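
For reference, a minimal local smoke test of the module-level entry points might look like the sketch below. It assumes this file is importable as `handler` and that a serving runtime would call `load`, `preprocess`, `predict`, and `postprocess` in that order; the request dict uses the `inputs` key the handler already reads, and the example prompt is made up.

import handler

handler.load()  # loads the tokenizer and model from '/home/aistudio/export'
encoding = handler.preprocess({"inputs": "Hello, who are you?"})
output_ids = handler.predict(encoding)
result = handler.postprocess(output_ids)
print(result["response"])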