|
import streamlit as st |
|
import torch |
|
import torch.nn as nn |
|
import pickle |
|
import pandas as pd |
|
from transformers import RobertaTokenizerFast, RobertaModel |
|
|
|
|
|
# Read the label mappings from the local pickle file.
# NOTE(review): unpickling runs arbitrary code; this presumably is a trusted
# artifact shipped with the app — confirm.
with open("label_mappings.pkl", mode="rb") as mappings_file:
    label_mappings = pickle.load(mappings_file)

# Both lookups fall back to an empty dict when the key is absent. They are used
# later to translate predicted class indices into a team name / email.
label_to_team, label_to_email = (
    label_mappings.get(mapping_name, {})
    for mapping_name in ("label_to_team", "label_to_email")
)

# Tokenizer for the "roberta-base" checkpoint, the same one the classifier's encoder loads.
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
|
|
|
|
|
class RoBertaClassifier(nn.Module):
    """RoBERTa encoder with two linear heads: one for teams, one for emails.

    Both heads read the hidden state at sequence position 0 and each emits one
    logit per class.
    """

    def __init__(self, num_teams, num_emails):
        """Create the encoder and the two classification heads.

        Args:
            num_teams: Output size of the team head.
            num_emails: Output size of the email head.
        """
        super().__init__()
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        hidden_size = self.roberta.config.hidden_size
        self.team_classifier = nn.Linear(hidden_size, num_teams)
        self.email_classifier = nn.Linear(hidden_size, num_emails)

    def forward(self, input_ids, attention_mask):
        """Return ``(team_logits, email_logits)`` for the encoded input."""
        encoded = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the last hidden state feeds both heads.
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return (
            self.team_classifier(first_token_state),
            self.email_classifier(first_token_state),
        )
|
|
|
|
|
@st.cache_resource
def _load_classifier(num_teams, num_emails):
    """Build the classifier, load the fine-tuned weights and prepare it for inference.

    Cached with ``st.cache_resource`` because Streamlit re-executes this script
    on every widget interaction; without the cache the RoBERTa weights and the
    checkpoint file would be loaded again on each rerun. The cache is keyed on
    the arguments only, so a changed checkpoint file needs a cache clear or an
    app restart to be picked up.

    Args:
        num_teams: Number of team classes (output size of the team head).
        num_emails: Number of email classes (output size of the email head).

    Returns:
        A ``(model, device, missing_keys)`` tuple. ``model`` is on ``device``
        and in eval mode; ``missing_keys`` lists the model state-dict entries
        that the checkpoint did not provide.
    """
    classifier = RoBertaClassifier(num_teams, num_emails)
    checkpoint = torch.load(
        "ticket_classification_model.pth", map_location=torch.device("cpu")
    )

    # Build the key set once; calling state_dict() inside the comprehension
    # would reconstruct the whole dict for every checkpoint entry.
    expected_keys = set(classifier.state_dict())
    filtered_checkpoint = {
        key: value for key, value in checkpoint.items() if key in expected_keys
    }
    # strict=False tolerates missing entries; keep the report so they can be surfaced.
    load_report = classifier.load_state_dict(filtered_checkpoint, strict=False)

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    classifier.to(target_device)
    classifier.eval()
    return classifier, target_device, list(load_report.missing_keys)


num_teams = len(label_to_team)
num_emails = len(label_to_email)
model, device, _missing_checkpoint_keys = _load_classifier(num_teams, num_emails)

# Parameters absent from the checkpoint keep their initial values, which can
# silently degrade predictions, so make that visible instead of ignoring it.
if _missing_checkpoint_keys:
    st.warning(
        f"Checkpoint did not provide {len(_missing_checkpoint_keys)} model "
        f"parameter(s) (e.g. '{_missing_checkpoint_keys[0]}'); they keep their "
        "initial values and predictions may be unreliable."
    )
|
|
|
|
|
|
|
def predict_tickets(ticket_descriptions):
    """Predict the assigned team and team email for each ticket description.

    Each description is tokenized on its own (truncated/padded to 128 tokens)
    and passed through the module-level ``model``; the arg-max class of each
    head is translated through ``label_to_team`` / ``label_to_email``.

    Args:
        ticket_descriptions: Iterable of ticket texts. Items that are not
            ``str`` (e.g. numbers read from a CSV column) are converted with
            ``str`` because the tokenizer only accepts string input.

    Returns:
        A ``(report, frame)`` tuple: ``report`` is the human-readable summary
        with one numbered entry per ticket, and ``frame`` is a DataFrame with
        the columns "Index", "Description", "Assigned Team" and "Team Email".
        An empty input yields an empty string and an empty DataFrame.
    """
    report_entries = []
    rows = []

    # Inference only: disable gradient tracking once for the whole loop.
    with torch.no_grad():
        for idx, raw_description in enumerate(ticket_descriptions, start=1):
            description = (
                raw_description
                if isinstance(raw_description, str)
                else str(raw_description)
            )
            inputs = tokenizer(
                description,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=128,
            ).to(device)
            team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask)

            # .item() already copies the scalar to the host, so no explicit .cpu() is needed.
            predicted_team_index = team_logits.argmax(dim=-1).item()
            predicted_email_index = email_logits.argmax(dim=-1).item()

            # Indices without a mapping entry fall back to a placeholder label.
            predicted_team = label_to_team.get(predicted_team_index, "Unknown Team")
            predicted_email = label_to_email.get(predicted_email_index, "Unknown Email")

            report_entries.append(
                f"{idx}. {description}\n - Assigned Team: {predicted_team}\n - Team Email: {predicted_email}\n"
            )
            rows.append([idx, description, predicted_team, predicted_email])

    df = pd.DataFrame(rows, columns=["Index", "Description", "Assigned Team", "Team Email"])
    return "\n".join(report_entries), df
|
|
|
|
|
|
|
# Header HTML, rendered top to bottom: page title, feature summary, usage hint.
_INTRO_HTML_SECTIONS = (
    "<h2 style='text-align: center; font-size:22px;'>AI Solution for Defect Ticket Classification</h2>",
    """
<p style='text-align: center; font-size:16px;'><strong>Supports:</strong> Multi-line text input & CSV upload.</p>
<p style='text-align: center; font-size:16px;'><strong>Output:</strong> Text results & downloadable CSV file.</p>
<p style='text-align: center; font-size:16px;'><strong>Model:</strong> Fine-tuned <strong>RoBERTa</strong> for classification.</p>
""",
    "<h3 style='font-size:16px;'>Enter ticket Description/Comment/Summary or upload a CSV file to predict Assigned Team & Team Email.</h3>",
)

# Raw HTML is used for centring and font sizes, hence unsafe_allow_html.
for _section_html in _INTRO_HTML_SECTIONS:
    st.markdown(_section_html, unsafe_allow_html=True)
|
|
|
|
|
|
|
# Let the user choose between typing tickets and uploading a CSV file.
option = st.radio("๐ Choose Input Method", ["Enter Text", "Upload CSV"])

# Ticket texts collected from whichever input method is active; stays empty
# when nothing usable was provided.
descriptions = []
if option == "Enter Text":
    text_input = st.text_area(
        "Enter Ticket Description/Comment/Summary (One per line)",
        placeholder="Example:\n - Database performance issue\n - Login fails for admin users...",
    )
    # One ticket per line; surrounding whitespace and blank lines are dropped.
    stripped_lines = (line.strip() for line in text_input.split("\n"))
    descriptions = [line for line in stripped_lines if line]
else:
    file_input = st.file_uploader("Upload CSV", type=["csv"])
    if file_input is not None:
        # The upload is user-supplied: an empty, malformed or wrongly encoded
        # file must produce a readable message instead of a traceback.
        try:
            df_input = pd.read_csv(file_input)
        except (
            pd.errors.EmptyDataError,
            pd.errors.ParserError,
            UnicodeDecodeError,
        ) as read_error:
            df_input = None
            st.error(f"Error: Could not read the uploaded CSV file ({read_error}).")

        if df_input is not None:
            if "Description" not in df_input.columns:
                st.error("โ ๏ธ Error: CSV must contain a 'Description' column.")
            else:
                # Mirror the text path: coerce to str (the column may be
                # numeric), strip whitespace and skip empty entries.
                cleaned = df_input["Description"].dropna().astype(str).str.strip()
                descriptions = cleaned[cleaned != ""].tolist()
|
|
|
|
|
|
|
# Make sure both result slots exist in the session state, defaulting to None,
# without overwriting values that are already stored.
for _state_key in ("prediction_results", "df_results"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
|
|
|
|
|
# Two equal-width columns: PREDICT (and its output) on the left, CLEAR on the right.
col1, col2 = st.columns([1, 1])

with col1:
    if st.button("PREDICT"):
        if not descriptions:
            # Nothing usable was entered or uploaded.
            st.error("โ ๏ธ Please provide valid input.")
        else:
            with st.spinner("Predicting..."):
                results, df_results = predict_tickets(descriptions)
                # Store the output in session state so it can be shown on later reruns too.
                st.session_state.prediction_results = results
                st.session_state.df_results = df_results

    # NOTE(review): the original indentation was lost; this results section is
    # assumed to live inside col1 (it could also be top-level) — confirm.
    if st.session_state.prediction_results:
        st.markdown("<h3 style='font-size:16px;'>Prediction Results</h3>", unsafe_allow_html=True)
        st.text(st.session_state.prediction_results)
        # Serialise the stored DataFrame to UTF-8 CSV bytes for the download button.
        csv_data = st.session_state.df_results.to_csv(index=False).encode('utf-8')
        st.download_button(
            label="๐ฅ Download Predictions CSV",
            data=csv_data,
            file_name="ticket-predictions.csv",
            mime="text/csv"
        )

with col2:
    if st.button("CLEAR"):
        # Drop the stored results, then rerun the script so the page is redrawn without them.
        st.session_state.prediction_results = None
        st.session_state.df_results = None
        st.rerun()
|
|
|
# Footer: a horizontal rule followed by a centred, grey attribution line.
_FOOTER_HTML = "<p style='text-align: center;color: gray; font-size:14px;'>Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>"

st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)