import streamlit as st
import torch
import torch.nn as nn
import pickle
import pandas as pd
from transformers import RobertaTokenizerFast, RobertaModel

# Load label mappings
with open("label_mappings.pkl", "rb") as f:
    label_mappings = pickle.load(f)

label_to_team = label_mappings.get("label_to_team", {})
label_to_email = label_mappings.get("label_to_email", {})

# Load tokenizer
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")


# Define RoBERTa model for multi-task classification: one shared encoder
# with separate linear heads for team and email prediction
class RoBertaClassifier(nn.Module):
    def __init__(self, num_teams, num_emails):
        super(RoBertaClassifier, self).__init__()
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        self.team_classifier = nn.Linear(self.roberta.config.hidden_size, num_teams)
        self.email_classifier = nn.Linear(self.roberta.config.hidden_size, num_emails)

    def forward(self, input_ids, attention_mask):
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state of the first (<s>) token serves as the sequence embedding
        cls_output = outputs.last_hidden_state[:, 0, :]
        team_logits = self.team_classifier(cls_output)
        email_logits = self.email_classifier(cls_output)
        return team_logits, email_logits


# Initialize model and load checkpoint
num_teams = len(label_to_team)
num_emails = len(label_to_email)
model = RoBertaClassifier(num_teams, num_emails)
checkpoint = torch.load("ticket_classification_model.pth", map_location=torch.device("cpu"))
# Keep only weights whose names match the current architecture
filtered_checkpoint = {k: v for k, v in checkpoint.items() if k in model.state_dict()}
model.load_state_dict(filtered_checkpoint, strict=False)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()


# Prediction function
def predict_tickets(ticket_descriptions):
    predictions = []
    csv_data = []
    for idx, description in enumerate(ticket_descriptions, start=1):
        inputs = tokenizer(
            description,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128,
        ).to(device)
        with torch.no_grad():
            team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask)
        predicted_team_index = team_logits.argmax(dim=-1).cpu().item()
        predicted_email_index = email_logits.argmax(dim=-1).cpu().item()
        predicted_team = label_to_team.get(predicted_team_index, "Unknown Team")
        predicted_email = label_to_email.get(predicted_email_index, "Unknown Email")
        predictions.append(
            f"{idx}. {description}\n"
            f" - Assigned Team: {predicted_team}\n"
            f" - Team Email: {predicted_email}\n"
        )
        csv_data.append([idx, description, predicted_team, predicted_email])
    df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"])
    return "\n".join(predictions), df
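# Note: Streamlit re-executes this script from top to bottom on every
# interaction, so the model and tokenizer above are reloaded each time.
# Wrapping the loading steps in a function decorated with @st.cache_resource
# would avoid the repeated cost (not done here).
#
# Quick smoke test of the pipeline outside the UI (hypothetical input):
#   text_report, df_report = predict_tickets(["Database performance issue"])
#   print(text_report)
#
# The loop above tokenizes one description at a time; the Hugging Face
# tokenizer also accepts a list of strings, which would allow a single
# batched forward pass for large inputs, e.g.:
#   inputs = tokenizer(batch, return_tensors="pt", truncation=True,
#                      padding=True, max_length=128).to(device)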

# Streamlit UI
st.markdown(
    "<h1 style='text-align: center;'>AI Solution for Defect Ticket Classification</h1>",
    unsafe_allow_html=True,
)

st.markdown(
    """
    <ul>
        <li><b>Supports:</b> Multi-line text input & CSV upload.</li>
        <li><b>Output:</b> Text results & downloadable CSV file.</li>
        <li><b>Model:</b> Fine-tuned RoBERTa for classification.</li>
    </ul>
    """,
    unsafe_allow_html=True,
)

st.markdown(
    "<p>Enter ticket Description/Comment/Summary or upload a CSV file to predict Assigned Team & Team Email.</p>",
    unsafe_allow_html=True,
)

# Choose input method
option = st.radio("📝 Choose Input Method", ["Enter Text", "Upload CSV"])

descriptions = []
if option == "Enter Text":
    text_input = st.text_area(
        "Enter Ticket Description/Comment/Summary (One per line)",
        placeholder="Example:\n - Database performance issue\n - Login fails for admin users...",
    )
    # One ticket per non-empty line
    descriptions = [line.strip() for line in text_input.split("\n") if line.strip()]
else:
    file_input = st.file_uploader("Upload CSV", type=["csv"])
    if file_input is not None:
        df_input = pd.read_csv(file_input)
        if "Description" not in df_input.columns:
            st.error("⚠️ Error: CSV must contain a 'Description' column.")
        else:
            descriptions = df_input["Description"].dropna().tolist()

# Store prediction results in session state so they persist
if "prediction_results" not in st.session_state:
    st.session_state.prediction_results = None
if "df_results" not in st.session_state:
    st.session_state.df_results = None

# Create a horizontal layout for the buttons
col1, col2 = st.columns([1, 1])

with col1:
    if st.button("PREDICT"):
        if not descriptions:
            st.error("⚠️ Please provide valid input.")
        else:
            with st.spinner("Predicting..."):
                results, df_results = predict_tickets(descriptions)
            st.session_state.prediction_results = results
            st.session_state.df_results = df_results

# Display prediction results if available
if st.session_state.prediction_results:
    st.markdown("<h3>Prediction Results</h3>", unsafe_allow_html=True)
    st.text(st.session_state.prediction_results)
    csv_data = st.session_state.df_results.to_csv(index=False).encode("utf-8")
    st.download_button(
        label="📥 Download Predictions CSV",
        data=csv_data,
        file_name="ticket-predictions.csv",
        mime="text/csv",
    )

with col2:
    if st.button("CLEAR"):
        # Clear the prediction results from session state
        st.session_state.prediction_results = None
        st.session_state.df_results = None
        st.rerun()

st.markdown("---")
st.markdown(
    "<p style='text-align: center;'>Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>",
    unsafe_allow_html=True,
)