mintheinwin committed on
Commit
f511b08
·
1 Parent(s): 120cd79

update app

Browse files
Files changed (1) hide show
  1. app.py +52 -34
app.py CHANGED
@@ -5,7 +5,6 @@ import pickle
5
  import pandas as pd
6
  from transformers import RobertaTokenizerFast, RobertaModel
7
 
8
- # -------------------------------
9
  # Load label mappings
10
  with open("label_mappings.pkl", "rb") as f:
11
  label_mappings = pickle.load(f)
@@ -13,11 +12,10 @@ with open("label_mappings.pkl", "rb") as f:
13
  label_to_team = label_mappings.get("label_to_team", {})
14
  label_to_email = label_mappings.get("label_to_email", {})
15
 
16
- # -------------------------------
17
  # Load tokenizer
18
  tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
19
 
20
- # -------------------------------
21
  # Define RoBERTa Model for multi-task classification
22
  class RoBertaClassifier(nn.Module):
23
  def __init__(self, num_teams, num_emails):
@@ -33,7 +31,6 @@ class RoBertaClassifier(nn.Module):
33
  email_logits = self.email_classifier(cls_output)
34
  return team_logits, email_logits
35
 
36
- # -------------------------------
37
  # Initialize model and load checkpoint
38
  num_teams = len(label_to_team)
39
  num_emails = len(label_to_email)
@@ -47,7 +44,7 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
47
  model.to(device)
48
  model.eval()
49
 
50
- # -------------------------------
51
  # Prediction function
52
  def predict_tickets(ticket_descriptions):
53
  predictions = []
@@ -73,19 +70,23 @@ def predict_tickets(ticket_descriptions):
73
  df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"])
74
  return "\n".join(predictions), df
75
 
76
- # -------------------------------
77
  # Streamlit UI
78
- #st.title("AI Solution for Defect Ticket Classification")
79
- st.markdown("<h2 style='text-align: center;'>AI Solution for Defect Ticket Classification</h2>", unsafe_allow_html=True)
80
  st.markdown("""
81
- **Supports:** Multi-line text input & CSV upload.
82
- **Output:** Text results & downloadable CSV file.
83
- **Model:** Fine-tuned **RoBERTa** for classification.
84
- """)
 
 
 
85
 
86
  # Choose input method
87
  option = st.radio("📝 Choose Input Method", ["Enter Text", "Upload CSV"])
88
 
 
89
  if option == "Enter Text":
90
  text_input = st.text_area(
91
  "Enter Ticket Description/Comment/Summary (One per line)",
@@ -94,7 +95,6 @@ if option == "Enter Text":
94
  descriptions = [line.strip() for line in text_input.split("\n") if line.strip()]
95
  else:
96
  file_input = st.file_uploader("Upload CSV", type=["csv"])
97
- descriptions = []
98
  if file_input is not None:
99
  df_input = pd.read_csv(file_input)
100
  if "Description" not in df_input.columns:
@@ -102,29 +102,47 @@ else:
102
  else:
103
  descriptions = df_input["Description"].dropna().tolist()
104
 
105
- # Trigger prediction when the button is clicked
106
- if st.button("PREDICT"):
107
- if not descriptions:
108
- st.error("⚠️ Please provide valid input.")
109
- else:
110
- with st.spinner("Predicting..."):
111
- results, df_results = predict_tickets(descriptions)
112
- st.markdown("## Prediction Results")
113
- st.text(results)
114
- csv_data = df_results.to_csv(index=False).encode('utf-8')
115
- st.download_button(
116
- label="📥 Download Predictions CSV",
117
- data=csv_data,
118
- file_name="ticket-predictions.csv",
119
- mime="text/csv"
120
- )
121
 
122
- # Clear button: simply reloads the app
123
- if st.button("CLEAR"):
124
- st.experimental_rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  st.markdown("---")
127
  st.markdown(
128
- "<p style='text-align: center;color: gray;'>Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>",
129
  unsafe_allow_html=True
130
- )
 
5
  import pandas as pd
6
  from transformers import RobertaTokenizerFast, RobertaModel
7
 
 
8
  # Load label mappings
9
  with open("label_mappings.pkl", "rb") as f:
10
  label_mappings = pickle.load(f)
 
12
  label_to_team = label_mappings.get("label_to_team", {})
13
  label_to_email = label_mappings.get("label_to_email", {})
14
 
15
+
16
  # Load tokenizer
17
  tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
18
 
 
19
  # Define RoBERTa Model for multi-task classification
20
  class RoBertaClassifier(nn.Module):
21
  def __init__(self, num_teams, num_emails):
 
31
  email_logits = self.email_classifier(cls_output)
32
  return team_logits, email_logits
33
 
 
34
  # Initialize model and load checkpoint
35
  num_teams = len(label_to_team)
36
  num_emails = len(label_to_email)
 
44
  model.to(device)
45
  model.eval()
46
 
47
+
48
  # Prediction function
49
  def predict_tickets(ticket_descriptions):
50
  predictions = []
 
70
  df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"])
71
  return "\n".join(predictions), df
72
 
73
+
74
  # Streamlit UI
75
+ st.markdown("<h2 style='text-align: center; font-size:22px;'>AI Solution for Defect Ticket Classification</h2>", unsafe_allow_html=True)
76
+
77
  st.markdown("""
78
+ <p style='text-align: center; font-size:16px;'><strong>Supports:</strong> Multi-line text input & CSV upload.</p>
79
+ <p style='text-align: center; font-size:16px;'><strong>Output:</strong> Text results & downloadable CSV file.</p>
80
+ <p style='text-align: center; font-size:16px;'><strong>Model:</strong> Fine-tuned <strong>RoBERTa</strong> for classification.</p>
81
+ """, unsafe_allow_html=True)
82
+
83
+ st.markdown("<h3 style='font-size:16px;'>Enter ticket Description/Comment/Summary or upload a CSV file to predict Assigned Team & Team Email.</h3>", unsafe_allow_html=True)
84
+
85
 
86
  # Choose input method
87
  option = st.radio("📝 Choose Input Method", ["Enter Text", "Upload CSV"])
88
 
89
+ descriptions = []
90
  if option == "Enter Text":
91
  text_input = st.text_area(
92
  "Enter Ticket Description/Comment/Summary (One per line)",
 
95
  descriptions = [line.strip() for line in text_input.split("\n") if line.strip()]
96
  else:
97
  file_input = st.file_uploader("Upload CSV", type=["csv"])
 
98
  if file_input is not None:
99
  df_input = pd.read_csv(file_input)
100
  if "Description" not in df_input.columns:
 
102
  else:
103
  descriptions = df_input["Description"].dropna().tolist()
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ # Store prediction results in session state so they persist
107
+ if "prediction_results" not in st.session_state:
108
+ st.session_state.prediction_results = None
109
+ if "df_results" not in st.session_state:
110
+ st.session_state.df_results = None
111
+
112
+ # Create a horizontal layout for the buttons
113
+ col1, col2 = st.columns([1, 1])
114
+
115
+ with col1:
116
+ if st.button("PREDICT"):
117
+ if not descriptions:
118
+ st.error("⚠️ Please provide valid input.")
119
+ else:
120
+ with st.spinner("Predicting..."):
121
+ results, df_results = predict_tickets(descriptions)
122
+ st.session_state.prediction_results = results
123
+ st.session_state.df_results = df_results
124
+
125
+ # Display prediction results if available
126
+ if st.session_state.prediction_results:
127
+ st.markdown("<h3 style='font-size:16px;'>Prediction Results</h3>", unsafe_allow_html=True)
128
+ st.text(st.session_state.prediction_results)
129
+ csv_data = st.session_state.df_results.to_csv(index=False).encode('utf-8')
130
+ st.download_button(
131
+ label="📥 Download Predictions CSV",
132
+ data=csv_data,
133
+ file_name="ticket-predictions.csv",
134
+ mime="text/csv"
135
+ )
136
+
137
+ with col2:
138
+ if st.button("CLEAR"):
139
+ # Clear the prediction results from session state
140
+ st.session_state.prediction_results = None
141
+ st.session_state.df_results = None
142
+ st.rerun()
143
 
144
  st.markdown("---")
145
  st.markdown(
146
+ "<p style='text-align: center;color: gray; font-size:14px;'>Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>",
147
  unsafe_allow_html=True
148
+ )