import streamlit as st
import torch
import torch.nn as nn
import pickle
import pandas as pd
from transformers import RobertaTokenizerFast, RobertaModel
# Load label mappings
with open("label_mappings.pkl", "rb") as f:
    label_mappings = pickle.load(f)

# Dictionaries that map predicted class indices back to team names and team email addresses
label_to_team = label_mappings.get("label_to_team", {})
label_to_email = label_mappings.get("label_to_email", {})
# Load tokenizer
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
# Define RoBERTa Model for multi-task classification
class RoBertaClassifier(nn.Module):
    def __init__(self, num_teams, num_emails):
        super(RoBertaClassifier, self).__init__()
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        # Two classification heads share the same RoBERTa encoder
        self.team_classifier = nn.Linear(self.roberta.config.hidden_size, num_teams)
        self.email_classifier = nn.Linear(self.roberta.config.hidden_size, num_emails)

    def forward(self, input_ids, attention_mask):
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first-token (<s>) hidden state as the pooled sentence representation
        cls_output = outputs.last_hidden_state[:, 0, :]
        team_logits = self.team_classifier(cls_output)
        email_logits = self.email_classifier(cls_output)
        return team_logits, email_logits
# Initialize model and load checkpoint
num_teams = len(label_to_team)
num_emails = len(label_to_email)
model = RoBertaClassifier(num_teams, num_emails)
checkpoint = torch.load("ticket_classification_model.pth", map_location=torch.device("cpu"))
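# Keep only the checkpoint entries whose names exist in the current model; strict=False tolerates any keys the checkpoint does not provide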
filtered_checkpoint = {k: v for k, v in checkpoint.items() if k in model.state_dict()}
model.load_state_dict(filtered_checkpoint, strict=False)
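# Run inference on GPU when available; eval() disables dropout so predictions are deterministic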
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()
# Prediction function
def predict_tickets(ticket_descriptions):
    predictions = []
    csv_data = []
    for idx, description in enumerate(ticket_descriptions, start=1):
        # Tokenize the description and move the tensors to the model's device
        inputs = tokenizer(
            description,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128
        ).to(device)
        # Run inference without gradient tracking
        with torch.no_grad():
            team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask)
        # Map the highest-scoring class indices back to team / email labels
        predicted_team_index = team_logits.argmax(dim=-1).cpu().item()
        predicted_email_index = email_logits.argmax(dim=-1).cpu().item()
        predicted_team = label_to_team.get(predicted_team_index, "Unknown Team")
        predicted_email = label_to_email.get(predicted_email_index, "Unknown Email")
        predictions.append(
            f"{idx}. {description}\n - Assigned Team: {predicted_team}\n - Team Email: {predicted_email}\n"
        )
        csv_data.append([idx, description, predicted_team, predicted_email])
    df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"])
    return "\n".join(predictions), df
# Streamlit UI
st.markdown("<h2 style='text-align: center; font-size:22px;'>AI Solution for Defect Ticket Classification</h2>", unsafe_allow_html=True)
st.markdown("""
<p style='text-align: center; font-size:16px;'><strong>Supports:</strong> Multi-line text input & CSV upload.</p>
<p style='text-align: center; font-size:16px;'><strong>Output:</strong> Text results & downloadable CSV file.</p>
<p style='text-align: center; font-size:16px;'><strong>Model:</strong> Fine-tuned <strong>RoBERTa</strong> for classification.</p>
""", unsafe_allow_html=True)
st.markdown("<h3 style='font-size:16px;'>Enter ticket Description/Comment/Summary or upload a CSV file to predict Assigned Team & Team Email.</h3>", unsafe_allow_html=True)
# Choose input method
option = st.radio("📝 Choose Input Method", ["Enter Text", "Upload CSV"])
descriptions = []
if option == "Enter Text":
    text_input = st.text_area(
        "Enter Ticket Description/Comment/Summary (One per line)",
        placeholder="Example:\n - Database performance issue\n - Login fails for admin users..."
    )
    descriptions = [line.strip() for line in text_input.split("\n") if line.strip()]
else:
    file_input = st.file_uploader("Upload CSV", type=["csv"])
    if file_input is not None:
        df_input = pd.read_csv(file_input)
        if "Description" not in df_input.columns:
            st.error("⚠️ Error: CSV must contain a 'Description' column.")
        else:
            descriptions = df_input["Description"].dropna().tolist()
# Store prediction results in session state so they persist
if "prediction_results" not in st.session_state:
st.session_state.prediction_results = None
if "df_results" not in st.session_state:
st.session_state.df_results = None
# Create a horizontal layout for the buttons
col1, col2 = st.columns([1, 1])
with col1:
    if st.button("PREDICT"):
        if not descriptions:
            st.error("⚠️ Please provide valid input.")
        else:
            with st.spinner("Predicting..."):
                results, df_results = predict_tickets(descriptions)
                st.session_state.prediction_results = results
                st.session_state.df_results = df_results

# Display prediction results if available
if st.session_state.prediction_results:
    st.markdown("<h3 style='font-size:16px;'>Prediction Results</h3>", unsafe_allow_html=True)
    st.text(st.session_state.prediction_results)
    csv_data = st.session_state.df_results.to_csv(index=False).encode("utf-8")
    st.download_button(
        label="📥 Download Predictions CSV",
        data=csv_data,
        file_name="ticket-predictions.csv",
        mime="text/csv"
    )
with col2:
    if st.button("CLEAR"):
        # Clear the prediction results from session state
        st.session_state.prediction_results = None
        st.session_state.df_results = None
        st.rerun()
st.markdown("---")
st.markdown(
"<p style='text-align: center;color: gray; font-size:14px;'>Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>",
unsafe_allow_html=True
)