aunghlaing commited on
Commit
d5c56ed
·
verified ·
1 Parent(s): f9763ca

Added individual predict buttons

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
+ from huggingface_hub import snapshot_download
6
+
7
+ # Student Information
8
+ My_info = "Student ID: 6319250G, Name: Aung Hlaing Tun"
9
+
10
+ # Define Hugging Face Model Repo
11
+ MODEL_REPO_ID = "ZAM-ITI-110/Distil_Bert_V3"
12
+
13
+ # Load Model & Tokenizer from Hugging Face
14
+ def load_model(repo_id):
15
+ """Download and load the model and tokenizer."""
16
+ cache_dir = "/home/user/app/hf_models"
17
+ os.makedirs(cache_dir, exist_ok=True)
18
+ download_dir = snapshot_download(repo_id, cache_dir=cache_dir, local_files_only=False)
19
+ model = AutoModelForSequenceClassification.from_pretrained(download_dir)
20
+ tokenizer = AutoTokenizer.from_pretrained(download_dir)
21
+ return model, tokenizer
22
+
23
+ # Load Model
24
+ model, tokenizer = load_model(MODEL_REPO_ID)
25
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26
+ model.to(device)
27
+ model.eval()
28
+
29
+ # Prediction Function (Single Ticket)
30
+ def predict_team_and_email(text):
31
+ """Predict team and email for a single ticket description."""
32
+ if not text.strip(): # Return empty if no input
33
+ return "", ""
34
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
35
+ with torch.no_grad():
36
+ logits = model(**inputs).logits
37
+ pred = torch.argmax(logits, dim=-1).item()
38
+
39
+ label_mapping = {
40
+ 0: "Code Review Team", 1: "Functional Team", 2: "Infrastructure Team",
41
+ 3: "Performance Team", 4: "Security Team"
42
+ }
43
+ email_mapping = {
44
45
46
47
+ }
48
+
49
+ return label_mapping.get(pred, "Unknown"), email_mapping.get(pred, "Unknown")
50
+
51
+ # Send Ticket Function (Simulation)
52
+ def send_tickets(*args):
53
+ """Simulate sending tickets based on predictions."""
54
+ tickets = []
55
+ for i, (text, team, email) in enumerate(zip(args[::2], args[1::2], args[2::2]), 1):
56
+ if text.strip() and team and email:
57
+ tickets.append(f"Ticket {i}: '{text}' -> {team} ({email})")
58
+ if tickets:
59
+ return "\n".join(tickets) + "\n\nSent successfully!"
60
+ return "No tickets to send."
61
+
62
+ # Clear Function
63
+ def clear_all():
64
+ """Clear all inputs and outputs."""
65
+ return [""] * 19 # 6 tickets x (input, team, email) + 1 sent_output = 19 fields
66
+
67
+ # Gradio UI Setup
68
+ with gr.Blocks(title="AI Ticket Classifier") as interface:
69
+ gr.Markdown("📩 **Development of an AI Ticket Classifier Model Using DistilBERT**")
70
+ gr.Markdown(f"*{My_info}*")
71
+ gr.Markdown(
72
+ """
73
+ **🔍 About this App**
74
+ - Predicts the appropriate **team** and **email** for up to 6 ticket descriptions.
75
+ - Click 'Predict' for each ticket, then 'Send Tickets' to process.
76
+ """
77
+ )
78
+
79
+ # Ticket Entry Section
80
+ with gr.Column():
81
+ gr.Markdown("### Enter Ticket Descriptions")
82
+ inputs = []
83
+ outputs = []
84
+ buttons = []
85
+ for i in range(6):
86
+ with gr.Row():
87
+ ticket_input = gr.Textbox(lines=2, placeholder=f"Ticket {i+1} description...", label=f"Ticket {i+1}")
88
+ team_output = gr.Textbox(label="Predicted Team", interactive=False)
89
+ email_output = gr.Textbox(label="Team Email", interactive=False)
90
+ predict_btn = gr.Button(f"Predict {i+1}")
91
+ inputs.append(ticket_input)
92
+ outputs.extend([team_output, email_output])
93
+ buttons.append(predict_btn)
94
+
95
+ # Action Buttons
96
+ with gr.Row():
97
+ send_btn = gr.Button("Send Tickets")
98
+ clear_btn = gr.Button("Clear")
99
+
100
+ # Output for Sent Tickets
101
+ sent_output = gr.Textbox(label="Sent Tickets", interactive=False)
102
+
103
+ # Event Handlers for Predict Buttons
104
+ for i, btn in enumerate(buttons):
105
+ btn.click(
106
+ fn=predict_team_and_email,
107
+ inputs=inputs[i],
108
+ outputs=[outputs[i*2], outputs[i*2 + 1]] # Team and email for this ticket
109
+ )
110
+
111
+ # Send and Clear Handlers
112
+ send_btn.click(
113
+ fn=send_tickets,
114
+ inputs=inputs + outputs,
115
+ outputs=sent_output
116
+ )
117
+ clear_btn.click(
118
+ fn=clear_all,
119
+ inputs=None,
120
+ outputs=inputs + outputs + [sent_output]
121
+ )
122
+
123
+ # Launch the interface
124
+ interface.launch()