s986103 commited on
Commit
319373e
·
1 Parent(s): 67f5816

modify input output format

Browse files
Files changed (1) hide show
  1. app.py +28 -6
app.py CHANGED
@@ -14,17 +14,39 @@ model.eval()
14
 
15
  def classify_text(text):
16
  inputs = tokenizer(text, return_tensors="pt",
17
- truncation=True, padding=True)
18
  with torch.no_grad():
19
  logits = model(**inputs)
20
- prediction = torch.argmax(logits, dim=1).item()
21
  return prediction
22
 
23
 
24
- iface = gr.Interface(fn=classify_text,
25
- inputs="text",
26
- outputs="label",
27
- description="自動作文評分")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # 啟動 UI
30
  iface.launch()
 
14
 
15
  def classify_text(text):
16
  inputs = tokenizer(text, return_tensors="pt",
17
+ truncation=True, max_length=1024)
18
  with torch.no_grad():
19
  logits = model(**inputs)
20
+ prediction = torch.argmax(logits, dim=-1).item() + 1
21
  return prediction
22
 
23
 
24
+ # 自定义 CSS 样式
25
+ custom_css = """
26
+ #input_textbox textarea {
27
+ border: 2px solid #1E90FF !important; /* 设置输入框边框颜色为蓝色 */
28
+ border-radius: 10px !important; /* 设置边框圆角 */
29
+ }
30
+
31
+ #output_textbox textarea {
32
+ border: 2px solid #FFA500 !important; /* 设置输出框边框颜色为橘色 */
33
+ border-radius: 10px !important; /* 设置边框圆角 */
34
+ font-size: 24px !important; /* 设置字体大小为24px */
35
+ text-align: center !important; /* 将文本居中对齐 */
36
+ display: flex;
37
+ justify-content: center;
38
+ align-items: center;
39
+ }
40
+ """
41
+
42
+ # 定义 Gradio 接口
43
+ iface = gr.Interface(
44
+ fn=classify_text,
45
+ inputs=gr.Textbox(label="請輸入文章", elem_id="input_textbox"),
46
+ outputs=gr.Textbox(label="評分結果(1-6)", elem_id="output_textbox"),
47
+ description="自動作文評分",
48
+ css=custom_css # 将自定义 CSS 添加到 Gradio 应用
49
+ )
50
 
51
  # 啟動 UI
52
  iface.launch()