File size: 5,981 Bytes
4ea6107 f511b08 4ea6107 f511b08 4ea6107 f511b08 4ea6107 f511b08 4ea6107 f511b08 4ea6107 f511b08 4ea6107 f511b08 4ea6107 f511b08 4ea6107 f511b08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import streamlit as st
import torch
import torch.nn as nn
import pickle
import pandas as pd
from transformers import RobertaTokenizerFast, RobertaModel
# Load label mappings and the tokenizer.
# Streamlit re-executes this whole script on every widget interaction, so the
# loaders are wrapped in st.cache_resource: they run once per server process
# instead of on every click.


@st.cache_resource
def _load_label_mappings(path="label_mappings.pkl"):
    """Return the object unpickled from *path*.

    NOTE(security): pickle can execute arbitrary code while loading; only ship a
    label_mappings.pkl produced by a trusted training run.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


@st.cache_resource
def _load_tokenizer(name="roberta-base"):
    """Return the fast RoBERTa tokenizer for the Hugging Face model *name*."""
    return RobertaTokenizerFast.from_pretrained(name)


label_mappings = _load_label_mappings()
# Both mappings are looked up by the predicted class index (see predict_tickets).
label_to_team = label_mappings.get("label_to_team", {})
label_to_email = label_mappings.get("label_to_email", {})
# Empty mappings would give classifier heads with zero outputs, which fails later
# with an obscure size-mismatch / argmax error; stop here with a clear message.
if not label_to_team or not label_to_email:
    st.error(
        "label_mappings.pkl must contain non-empty 'label_to_team' and "
        "'label_to_email' mappings."
    )
    st.stop()

tokenizer = _load_tokenizer()
# Define RoBERTa Model for multi-task classification
class RoBertaClassifier(nn.Module):
    """RoBERTa encoder shared by two linear heads: assigned team and team email.

    Both heads read the final-layer hidden state at sequence position 0.
    """

    def __init__(self, num_teams, num_emails):
        super().__init__()
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        width = self.roberta.config.hidden_size
        self.team_classifier = nn.Linear(width, num_teams)
        self.email_classifier = nn.Linear(width, num_emails)

    def forward(self, input_ids, attention_mask):
        """Return ``(team_logits, email_logits)`` for a tokenized batch."""
        encoded = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # last_hidden_state is indexed as (batch, seq, hidden); take token 0.
        first_token = encoded.last_hidden_state[:, 0, :]
        return self.team_classifier(first_token), self.email_classifier(first_token)
# Initialize model and load checkpoint.
# Cached with st.cache_resource so the model is built and its weights loaded
# once per server process rather than on every Streamlit rerun.


@st.cache_resource
def _load_model(num_teams, num_emails, checkpoint_path="ticket_classification_model.pth"):
    """Build RoBertaClassifier, load fine-tuned weights, and move it to a device.

    Returns:
        (model, device, missing_keys): the model in eval mode, the device it was
        moved to (CUDA when available, else CPU), and the names of model
        parameters/buffers the checkpoint did not provide (those keep the
        values set by RoBertaClassifier.__init__).
    """
    net = RoBertaClassifier(num_teams, num_emails)
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    expected = net.state_dict()
    # Keep only keys the model defines; strict=False tolerates the rest.
    filtered = {k: v for k, v in state.items() if k in expected}
    load_result = net.load_state_dict(filtered, strict=False)
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    net.to(target)
    net.eval()
    return net, target, list(load_result.missing_keys)


num_teams = len(label_to_team)
num_emails = len(label_to_email)
model, device, _missing_keys = _load_model(num_teams, num_emails)
# forward() classifies from last_hidden_state, not the pooler output, so pooler
# weights may legitimately be absent. Anything else missing stays at its
# initial value and would silently degrade predictions, so surface it.
_missing_keys = [k for k in _missing_keys if not k.startswith("roberta.pooler.")]
if _missing_keys:
    st.warning(
        f"{len(_missing_keys)} model weight(s) were not found in the checkpoint "
        f"(e.g. '{_missing_keys[0]}'); predictions may be unreliable."
    )
# Prediction function
def predict_tickets(ticket_descriptions, batch_size=32):
    """Classify ticket descriptions into an assigned team and a team email.

    Uses the module-level ``tokenizer`` (truncate/pad to 128 tokens), ``model``
    and ``device``, and maps the arg-max class indices through
    ``label_to_team`` / ``label_to_email``. Indices without a mapping are
    reported as "Unknown Team" / "Unknown Email".

    Args:
        ticket_descriptions: iterable of descriptions. Each item is converted
            with ``str()`` before tokenizing, because CSV cells may not be strings.
        batch_size: descriptions per forward pass (must be >= 1). Batching
            avoids one model call per ticket; every row is padded to the same
            length, so results match one-at-a-time inference.

    Returns:
        (text, df): a numbered (1-based) human-readable report, and a DataFrame
        with columns "Index", "Description", "Assigned Team", "Team Email".

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    texts = [str(description) for description in ticket_descriptions]

    team_indices, email_indices = [], []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            inputs = tokenizer(
                texts[start:start + batch_size],
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=128,
            ).to(device)
            team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask)
            # tolist() yields plain Python ints, the same keys .item() produced.
            team_indices.extend(team_logits.argmax(dim=-1).cpu().tolist())
            email_indices.extend(email_logits.argmax(dim=-1).cpu().tolist())

    predictions = []
    csv_data = []
    rows = zip(texts, team_indices, email_indices)
    for idx, (description, team_index, email_index) in enumerate(rows, start=1):
        predicted_team = label_to_team.get(team_index, "Unknown Team")
        predicted_email = label_to_email.get(email_index, "Unknown Email")
        predictions.append(
            f"{idx}. {description}\n - Assigned Team: {predicted_team}\n - Team Email: {predicted_email}\n"
        )
        csv_data.append([idx, description, predicted_team, predicted_email])
    df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"])
    return "\n".join(predictions), df
# Streamlit UI: page title, feature summary, then input instructions.
_TITLE_HTML = "<h2 style='text-align: center; font-size:22px;'>AI Solution for Defect Ticket Classification</h2>"
_FEATURE_LINES = (
    "<p style='text-align: center; font-size:16px;'><strong>Supports:</strong> Multi-line text input & CSV upload.</p>",
    "<p style='text-align: center; font-size:16px;'><strong>Output:</strong> Text results & downloadable CSV file.</p>",
    "<p style='text-align: center; font-size:16px;'><strong>Model:</strong> Fine-tuned <strong>RoBERTa</strong> for classification.</p>",
)
_INSTRUCTIONS_HTML = "<h3 style='font-size:16px;'>Enter ticket Description/Comment/Summary or upload a CSV file to predict Assigned Team & Team Email.</h3>"

# The feature summary is one markdown element whose text starts and ends with a newline.
for _html in (_TITLE_HTML, "\n" + "\n".join(_FEATURE_LINES) + "\n", _INSTRUCTIONS_HTML):
    st.markdown(_html, unsafe_allow_html=True)
# Choose input method and collect the list of ticket descriptions to classify.
option = st.radio("Choose Input Method", ["Enter Text", "Upload CSV"])
descriptions = []
if option == "Enter Text":
    text_input = st.text_area(
        "Enter Ticket Description/Comment/Summary (One per line)",
        placeholder="Example:\n - Database performance issue\n - Login fails for admin users..."
    )
    descriptions = [line.strip() for line in text_input.split("\n") if line.strip()]
else:
    file_input = st.file_uploader("Upload CSV", type=["csv"])
    if file_input is not None:
        try:
            df_input = pd.read_csv(file_input)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
            # Show a readable message instead of a raw traceback for bad uploads.
            st.error(f"⚠️ Error: could not read the CSV file ({err}).")
        else:
            if "Description" not in df_input.columns:
                st.error("⚠️ Error: CSV must contain a 'Description' column.")
            else:
                # Mirror the text-input path: coerce cells to str, strip, and
                # skip blank ones (NaN cells are dropped first).
                cells = (str(value).strip() for value in df_input["Description"].dropna())
                descriptions = [cell for cell in cells if cell]
# Keep prediction output in session state so it persists across Streamlit
# reruns; seed each key with None the first time a session runs the script.
for _state_key in ("prediction_results", "df_results"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
# Create a horizontal layout for the PREDICT / CLEAR buttons.
col1, col2 = st.columns([1, 1])
with col1:
    predict_clicked = st.button("PREDICT")
with col2:
    clear_clicked = st.button("CLEAR")

# Handle CLEAR before drawing results. Stale output is then never rendered, and
# no extra st.rerun() (a full re-execution of the script) is needed.
if clear_clicked:
    st.session_state.prediction_results = None
    st.session_state.df_results = None
elif predict_clicked:
    if not descriptions:
        st.error("⚠️ Please provide valid input.")
    else:
        with st.spinner("Predicting..."):
            results, df_results = predict_tickets(descriptions)
        st.session_state.prediction_results = results
        st.session_state.df_results = df_results

# Display prediction results if available. They are read from session state,
# so they persist across reruns.
if st.session_state.prediction_results and st.session_state.df_results is not None:
    st.markdown("<h3 style='font-size:16px;'>Prediction Results</h3>", unsafe_allow_html=True)
    st.text(st.session_state.prediction_results)
    csv_data = st.session_state.df_results.to_csv(index=False).encode("utf-8")
    st.download_button(
        label="📥 Download Predictions CSV",
        data=csv_data,
        file_name="ticket-predictions.csv",
        mime="text/csv",
    )
# Footer: a horizontal rule followed by a centered, grey author credit.
_FOOTER_HTML = (
    "<p style='text-align: center;color: gray; font-size:14px;'>"
    "Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>"
)
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)