mintheinwin committed on
Commit 4ea6107 · 1 Parent(s): fa95ad0

Streamlit app

app.py ADDED
@@ -0,0 +1,129 @@
+ import streamlit as st
+ import torch
+ import torch.nn as nn
+ import pickle
+ import pandas as pd
+ from transformers import RobertaTokenizerFast, RobertaModel
+
+ # -------------------------------
+ # Load label mappings
+ with open("label_mappings.pkl", "rb") as f:
+     label_mappings = pickle.load(f)
+
+ label_to_team = label_mappings.get("label_to_team", {})
+ label_to_email = label_mappings.get("label_to_email", {})
+
+ # -------------------------------
+ # Load tokenizer
+ tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
+
+ # -------------------------------
+ # Define RoBERTa model for multi-task classification
+ class RoBertaClassifier(nn.Module):
+     def __init__(self, num_teams, num_emails):
+         super(RoBertaClassifier, self).__init__()
+         self.roberta = RobertaModel.from_pretrained("roberta-base")
+         self.team_classifier = nn.Linear(self.roberta.config.hidden_size, num_teams)
+         self.email_classifier = nn.Linear(self.roberta.config.hidden_size, num_emails)
+
+     def forward(self, input_ids, attention_mask):
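+         # The hidden state of the first token (<s>, RoBERTa's CLS equivalent) serves as a
+         # pooled sentence representation shared by both classification heads.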
+         outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
+         cls_output = outputs.last_hidden_state[:, 0, :]
+         team_logits = self.team_classifier(cls_output)
+         email_logits = self.email_classifier(cls_output)
+         return team_logits, email_logits
+
+ # -------------------------------
+ # Initialize model and load checkpoint
+ num_teams = len(label_to_team)
+ num_emails = len(label_to_email)
+ model = RoBertaClassifier(num_teams, num_emails)
+
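+ # Keep only the weights whose names exist in the current architecture;
+ # strict=False tolerates any remaining missing or unexpected keys.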
+ checkpoint = torch.load("ticket_classification_model.pth", map_location=torch.device("cpu"))
+ filtered_checkpoint = {k: v for k, v in checkpoint.items() if k in model.state_dict()}
+ model.load_state_dict(filtered_checkpoint, strict=False)
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model.to(device)
+ model.eval()
+
+ # -------------------------------
+ # Prediction function
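+ # Returns a human-readable summary string and a DataFrame used for the CSV download.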
+ def predict_tickets(ticket_descriptions):
+     predictions = []
+     csv_data = []
+     for idx, description in enumerate(ticket_descriptions, start=1):
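+         # Tokenize one description at a time; inputs longer than 128 tokens are truncated.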
+         inputs = tokenizer(
+             description,
+             return_tensors="pt",
+             truncation=True,
+             padding="max_length",
+             max_length=128
+         ).to(device)
+         with torch.no_grad():
+             team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask)
+         predicted_team_index = team_logits.argmax(dim=-1).cpu().item()
+         predicted_email_index = email_logits.argmax(dim=-1).cpu().item()
+         predicted_team = label_to_team.get(predicted_team_index, "Unknown Team")
+         predicted_email = label_to_email.get(predicted_email_index, "Unknown Email")
+         predictions.append(
+             f"{idx}. {description}\n - Assigned Team: {predicted_team}\n - Team Email: {predicted_email}\n"
+         )
+         csv_data.append([idx, description, predicted_team, predicted_email])
+     df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"])
+     return "\n".join(predictions), df
+
+ # -------------------------------
+ # Streamlit UI
+ st.title("AI Solution for Defect Ticket Classification")
+ st.markdown("""
+ **Supports:** Multi-line text input & CSV upload.
+ **Output:** Text results & downloadable CSV file.
+ **Model:** Fine-tuned **RoBERTa** for classification.
+ """)
+
+ # Choose input method
+ option = st.radio("📝 Choose Input Method", ["Enter Text", "Upload CSV"])
+
+ if option == "Enter Text":
+     text_input = st.text_area(
+         "Enter Ticket Description/Comment/Summary (One per line)",
+         placeholder="Example:\n - Database performance issue\n - Login fails for admin users..."
+     )
+     descriptions = [line.strip() for line in text_input.split("\n") if line.strip()]
+ else:
+     file_input = st.file_uploader("Upload CSV", type=["csv"])
+     descriptions = []
+     if file_input is not None:
+         df_input = pd.read_csv(file_input)
+         if "Description" not in df_input.columns:
+             st.error("⚠️ Error: CSV must contain a 'Description' column.")
+         else:
+             descriptions = df_input["Description"].dropna().tolist()
+
+ # Trigger prediction when the button is clicked
+ if st.button("PREDICT"):
+     if not descriptions:
+         st.error("⚠️ Please provide valid input.")
+     else:
+         with st.spinner("Predicting..."):
+             results, df_results = predict_tickets(descriptions)
+         st.markdown("## Prediction Results")
+         st.text(results)
+         csv_data = df_results.to_csv(index=False).encode('utf-8')
+         st.download_button(
+             label="📥 Download Predictions CSV",
+             data=csv_data,
+             file_name="ticket-predictions.csv",
+             mime="text/csv"
+         )
+
+ # Clear button: simply reloads the app
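+ # Note: st.experimental_rerun() matches the pinned streamlit==1.24.0; newer releases replace it with st.rerun().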
+ if st.button("CLEAR"):
+     st.experimental_rerun()
+
+ st.markdown("---")
+ st.markdown(
+     "<p style='text-align: center;color: gray;'>Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>",
+     unsafe_allow_html=True
+ )
label_mappings.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:db90d041f35923a582c2bd4e795ca06632b8b23e1e9eaab6622844ccc27c47c7
+ size 315
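label_mappings.pkl is stored via Git LFS, so only the pointer appears in the diff. From the way app.py reads it, the file is expected to unpickle to a dict with "label_to_team" and "label_to_email" keys, each mapping integer class indices to names. A minimal sketch of how such a file could be produced (the team names and email addresses below are hypothetical placeholders, not the real labels):

    import pickle

    # Hypothetical example labels; the real mappings come from the training data.
    label_mappings = {
        "label_to_team": {0: "Database", 1: "Authentication"},
        "label_to_email": {0: "db-team@example.com", 1: "auth-team@example.com"},
    }

    with open("label_mappings.pkl", "wb") as f:
        pickle.dump(label_mappings, f)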
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ streamlit==1.24.0
+ torch==2.0.0
+ transformers==4.27.0
+ pandas==1.5.3
ticket_classification_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e4a507d772bd1ffe5ae6878e31ce1b86a479b69f204d682f775544098a6fbdb0
+ size 498701064