Kororinpa commited on
Commit
2ddb15e
·
verified ·
1 Parent(s): ac4dae6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +69 -54
handler.py CHANGED
@@ -1,8 +1,6 @@
1
  from typing import List, Dict
2
  import torch
3
  from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
4
- # 移除相对导入
5
- # from .modeling import BinaryClassifier
6
 
7
  class EndpointHandler:
8
  def __init__(self, path=""):
@@ -16,71 +14,88 @@ class EndpointHandler:
16
  # 初始化tokenizer
17
  self.tokenizer = AutoTokenizer.from_pretrained(path)
18
 
19
- # 设置最大长度,可以根据你的需求调整
20
  self.max_length = 512
21
 
22
  def __call__(self, data: List[Dict[str, str]]) -> List[Dict[str, float]]:
23
  """
24
  处理文本推理请求
25
- Args:
26
- data: 输入数据列表,每个元素是一个字典
27
- 例如:[{"inputs": "这是一段测试文本"}]
28
- Returns:
29
- 预测结果列表
30
  """
31
- # 获取所有输入文本
32
- texts = [item["inputs"] for item in data]
33
-
34
- # tokenization
35
- encoded_inputs = self.tokenizer(
36
- texts,
37
- padding=True,
38
- truncation=True,
39
- max_length=self.max_length,
40
- return_tensors="pt"
41
- )
42
-
43
- # 进行预测
44
- with torch.no_grad():
45
- outputs = self.model(**encoded_inputs)
46
- logits = outputs.logits
47
- predictions = torch.softmax(logits, dim=-1)
48
-
49
- # 格式化输出
50
- results = []
51
- for pred in predictions:
52
- label_id = pred.argmax().item()
53
- score = pred[label_id].item()
54
- results.append({
55
- "label": str(label_id), # 0 或 1
56
- "score": float(score) # 预测概率
57
- })
58
-
59
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  def preprocess(self, text: str) -> Dict[str, torch.Tensor]:
62
  """
63
  预处理方法
64
  """
65
- encoded = self.tokenizer(
66
- text,
67
- padding=True,
68
- truncation=True,
69
- max_length=self.max_length,
70
- return_tensors="pt"
71
- )
72
- return encoded
 
 
 
 
73
 
74
  def postprocess(self, model_outputs) -> Dict:
75
  """
76
  后处理方法
77
  """
78
- logits = model_outputs.logits
79
- predictions = torch.softmax(logits, dim=-1)
80
- label_id = predictions[0].argmax().item()
81
- score = predictions[0][label_id].item()
82
-
83
- return {
84
- "label": str(label_id),
85
- "score": float(score)
86
- }
 
 
 
 
 
1
  from typing import List, Dict
2
  import torch
3
  from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
 
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
 
14
  # 初始化tokenizer
15
  self.tokenizer = AutoTokenizer.from_pretrained(path)
16
 
17
+ # 设置最大长度
18
  self.max_length = 512
19
 
20
  def __call__(self, data: List[Dict[str, str]]) -> List[Dict[str, float]]:
21
  """
22
  处理文本推理请求
 
 
 
 
 
23
  """
24
+ try:
25
+ # 获取所有输入文本
26
+ texts = []
27
+ for item in data:
28
+ # 确保我们正确处理输入数据
29
+ if isinstance(item, dict) and "inputs" in item:
30
+ texts.append(item["inputs"])
31
+ elif isinstance(item, str):
32
+ texts.append(item)
33
+ else:
34
+ raise ValueError(f"Unexpected input format: {item}")
35
+
36
+ # tokenization
37
+ encoded_inputs = self.tokenizer(
38
+ texts,
39
+ padding=True,
40
+ truncation=True,
41
+ max_length=self.max_length,
42
+ return_tensors="pt"
43
+ )
44
+
45
+ # 进行预测
46
+ with torch.no_grad():
47
+ outputs = self.model(**encoded_inputs)
48
+ logits = outputs.logits
49
+ probabilities = torch.softmax(logits, dim=-1)
50
+
51
+ # 格式化输出
52
+ results = []
53
+ for probs in probabilities:
54
+ label_id = int(torch.argmax(probs).item())
55
+ confidence = float(probs[label_id].item())
56
+ results.append({
57
+ "label": str(label_id), # 转换为字符串
58
+ "score": confidence # 预测概率
59
+ })
60
+
61
+ return results
62
+
63
+ except Exception as e:
64
+ # 添加错误处理和日志记录
65
+ print(f"Error in prediction: {str(e)}")
66
+ return [{"error": str(e)}]
67
 
68
  def preprocess(self, text: str) -> Dict[str, torch.Tensor]:
69
  """
70
  预处理方法
71
  """
72
+ try:
73
+ encoded = self.tokenizer(
74
+ text,
75
+ padding=True,
76
+ truncation=True,
77
+ max_length=self.max_length,
78
+ return_tensors="pt"
79
+ )
80
+ return encoded
81
+ except Exception as e:
82
+ print(f"Error in preprocessing: {str(e)}")
83
+ raise e
84
 
85
  def postprocess(self, model_outputs) -> Dict:
86
  """
87
  后处理方法
88
  """
89
+ try:
90
+ logits = model_outputs.logits
91
+ probabilities = torch.softmax(logits, dim=-1)
92
+ label_id = int(torch.argmax(probabilities[0]).item())
93
+ confidence = float(probabilities[0][label_id].item())
94
+
95
+ return {
96
+ "label": str(label_id),
97
+ "score": confidence
98
+ }
99
+ except Exception as e:
100
+ print(f"Error in postprocessing: {str(e)}")
101
+ raise e