mintheinwin commited on
Commit
5f47787
Β·
1 Parent(s): b01a119

update app file and extra other file

Browse files
app.py CHANGED
@@ -1,7 +1,140 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import pickle
5
+ import pandas as pd
6
+ from transformers import RobertaTokenizerFast, RobertaModel
7
 
8
+ # Load label mappings
9
+ with open("label_mappings.pkl", "rb") as f:
10
+ label_mappings = pickle.load(f)
11
 
12
+ label_to_team = label_mappings.get("label_to_team", {})
13
+ label_to_email = label_mappings.get("label_to_email", {})
14
+
15
+ # Load the tokenizer
16
+ tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
17
+
18
+ # Define RoBERTa Model
19
+ class RoBertaClassifier(nn.Module):
20
+ def __init__(self, num_teams, num_emails):
21
+ super(RoBertaClassifier, self).__init__()
22
+ self.roberta = RobertaModel.from_pretrained("roberta-base")
23
+ self.team_classifier = nn.Linear(self.roberta.config.hidden_size, num_teams)
24
+ self.email_classifier = nn.Linear(self.roberta.config.hidden_size, num_emails)
25
+
26
+ def forward(self, input_ids, attention_mask):
27
+ outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
28
+ cls_output = outputs.last_hidden_state[:, 0, :]
29
+
30
+ team_logits = self.team_classifier(cls_output)
31
+ email_logits = self.email_classifier(cls_output)
32
+
33
+ return team_logits, email_logits
34
+
35
+ # Load Model
36
+ num_teams = len(label_to_team)
37
+ num_emails = len(label_to_email)
38
+ model = RoBertaClassifier(num_teams, num_emails)
39
+ checkpoint = torch.load("ticket_classification_model.pth", map_location=torch.device("cpu"))
40
+ filtered_checkpoint = {k: v for k, v in checkpoint.items() if k in model.state_dict()}
41
+ model.load_state_dict(filtered_checkpoint, strict=False)
42
+
43
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
44
+ model.to(device)
45
+ model.eval()
46
+
47
+ # Prediction Function
48
+ def predict_tickets(ticket_descriptions):
49
+ predictions = []
50
+ csv_data = []
51
+ for idx, description in enumerate(ticket_descriptions, start=1):
52
+ inputs = tokenizer(description, return_tensors="pt", truncation=True, padding="max_length", max_length=128).to(device)
53
+ with torch.no_grad():
54
+ team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask)
55
+ predicted_team_index = team_logits.argmax(dim=-1).cpu().item()
56
+ predicted_email_index = email_logits.argmax(dim=-1).cpu().item()
57
+ predicted_team = label_to_team.get(predicted_team_index, "Unknown Team")
58
+ predicted_email = label_to_email.get(predicted_email_index, "Unknown Email")
59
+ predictions.append(f"**{idx}. {description}**\n - **Assigned Team:** {predicted_team}\n - **Team Email:** {predicted_email}\n")
60
+ csv_data.append([idx, description, predicted_team, predicted_email])
61
+
62
+ df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"])
63
+ csv_file = "ticket-predictions.csv"
64
+ df.to_csv(csv_file, index=False)
65
+ return "\n".join(predictions), csv_file
66
+
67
+ # Gradio Functions
68
+ def gradio_predict(option, text_input, file_input):
69
+ if option == "Enter Text":
70
+ descriptions = text_input.split("\n")
71
+ descriptions = [desc.strip() for desc in descriptions if desc.strip()]
72
+ elif option == "Upload CSV" and file_input is not None:
73
+ df = pd.read_csv(file_input)
74
+ if "Description" not in df.columns:
75
+ return "⚠️ Error: CSV must contain a 'Description' column.", None
76
+ descriptions = df["Description"].tolist()
77
+ else:
78
+ return "⚠️ Please provide input.", None
79
+
80
+ results, csv_file = predict_tickets(descriptions)
81
+ return results, csv_file
82
+
83
+ def clear_inputs():
84
+ return "Enter Text", "", None, "", None
85
+
86
+ # Gradio App UI
87
+ with gr.Blocks(css=".gradio-container {max-width: 1100px; margin: auto;}") as app:
88
+ gr.Markdown(
89
+ """
90
+ # Multi-Ticket AI Classification System
91
+
92
+ **Supports:** Multi-line text input & CSV upload.
93
+ **Output:** Text results & downloadable CSV file.
94
+ **Model:** Fine-tuned **RoBERTa** for classification.
95
+
96
+ Enter ticket Description/Comment/Summary or upload a **CSV file** to predict Assigned Teams & Team Emails.
97
+
98
+ """,
99
+ elem_id="title"
100
+ )
101
+
102
+ with gr.Row():
103
+ with gr.Column(scale=1):
104
+ option = gr.Radio(["Enter Text", "Upload CSV"], label="πŸ“ Choose Input Method", value="Enter Text")
105
+
106
+ text_input = gr.Textbox(
107
+ label="Enter Ticket Description/Comment/Summary (One per line)",
108
+ lines=6,
109
+ placeholder="Example:\n - Database performance issue\n - Login fails for admin users..."
110
+ )
111
+
112
+ file_input = gr.File(label="πŸ“‚ Upload CSV (Optional)", type="filepath", visible=False)
113
+
114
+ with gr.Column(scale=1):
115
+ gr.Markdown("## Prediction Results") # **Title for Prediction Results**
116
+ results_output = gr.Markdown(elem_id="results-box", visible=True)
117
+ download_csv = gr.File(label="πŸ“₯ Download Predictions CSV", interactive=False)
118
+
119
+ with gr.Row():
120
+ predict_btn = gr.Button("PREDICT", variant="primary")
121
+ clear_btn = gr.Button("CLEAR", variant="secondary")
122
+
123
+ # Logic for Showing/ Hiding Input Fields
124
+ def toggle_input(selected_option):
125
+ return gr.update(visible=(selected_option == "Enter Text")), gr.update(visible=(selected_option == "Upload CSV"))
126
+
127
+ option.change(fn=toggle_input, inputs=[option], outputs=[text_input, file_input])
128
+
129
+ predict_btn.click(fn=gradio_predict, inputs=[option, text_input, file_input], outputs=[results_output, download_csv])
130
+ clear_btn.click(fn=clear_inputs, inputs=[], outputs=[option, text_input, file_input, results_output, download_csv])
131
+
132
+ # Footer view
133
+ gr.Markdown("---")
134
+ gr.Markdown(
135
+ "<p style='text-align: center;color: gray;'>"
136
+ "Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>"
137
+ )
138
+
139
+ # Launch App
140
+ app.launch(share=True)
label_mappings.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db90d041f35923a582c2bd4e795ca06632b8b23e1e9eaab6622844ccc27c47c7
3
+ size 315
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ torchaudio
5
+ transformers
6
+ pandas
ticket_classification_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a33510b236c5dea7a1c00276dd933b302a10711c2a411fa059ca81e5651de030
3
+ size 498701000