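"""Streamlit app for defect ticket classification.

Loads a fine-tuned RoBERTa multi-task model and predicts the Assigned Team
and Team Email for ticket descriptions entered as text or uploaded as a CSV.
"""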
import streamlit as st
import torch
import torch.nn as nn
import pickle
import pandas as pd
from transformers import RobertaTokenizerFast, RobertaModel

# Load label mappings (class index -> team name, class index -> team email)
with open("label_mappings.pkl", "rb") as f:
    label_mappings = pickle.load(f)

label_to_team = label_mappings.get("label_to_team", {})
label_to_email = label_mappings.get("label_to_email", {})


# Load tokenizer
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

# Define RoBERTa Model for multi-task classification
class RoBertaClassifier(nn.Module):
    def __init__(self, num_teams, num_emails):
        super(RoBertaClassifier, self).__init__()
        self.roberta = RobertaModel.from_pretrained("roberta-base")
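        # Two independent linear heads on top of the shared encoder:
        # one predicts the team label, the other the team email label.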
        self.team_classifier = nn.Linear(self.roberta.config.hidden_size, num_teams)
        self.email_classifier = nn.Linear(self.roberta.config.hidden_size, num_emails)

    def forward(self, input_ids, attention_mask):
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
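        # Use the hidden state of the first token (<s>, RoBERTa's CLS equivalent)
        # as the pooled sequence representation for both heads.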
        cls_output = outputs.last_hidden_state[:, 0, :]
        team_logits = self.team_classifier(cls_output)
        email_logits = self.email_classifier(cls_output)
        return team_logits, email_logits

# Initialize model and load checkpoint
num_teams = len(label_to_team)
num_emails = len(label_to_email)
model = RoBertaClassifier(num_teams, num_emails)

checkpoint = torch.load("ticket_classification_model.pth", map_location=torch.device("cpu"))
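# Keep only weights whose keys exist in the current model; strict=False
# tolerates any remaining missing keys instead of raising an error.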
filtered_checkpoint = {k: v for k, v in checkpoint.items() if k in model.state_dict()}
model.load_state_dict(filtered_checkpoint, strict=False)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()


# Prediction function
def predict_tickets(ticket_descriptions):
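    """Classify each description and return (formatted text, results DataFrame)."""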
    predictions = []
    csv_data = []
    for idx, description in enumerate(ticket_descriptions, start=1):
        inputs = tokenizer(
            description, 
            return_tensors="pt", 
            truncation=True, 
            padding="max_length", 
            max_length=128
        ).to(device)
        with torch.no_grad():
            team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask)
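        # argmax over the logits gives the predicted class index for each head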
        predicted_team_index = team_logits.argmax(dim=-1).cpu().item()
        predicted_email_index = email_logits.argmax(dim=-1).cpu().item()
        predicted_team = label_to_team.get(predicted_team_index, "Unknown Team")
        predicted_email = label_to_email.get(predicted_email_index, "Unknown Email")
        predictions.append(
            f"{idx}. {description}\n   - Assigned Team: {predicted_team}\n   - Team Email: {predicted_email}\n"
        )
        csv_data.append([idx, description, predicted_team, predicted_email])
    df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"])
    return "\n".join(predictions), df


# Streamlit UI
st.markdown("<h2 style='text-align: center; font-size:22px;'>AI Solution for Defect Ticket Classification</h2>", unsafe_allow_html=True)

st.markdown("""
<p style='text-align: center; font-size:16px;'><strong>Supports:</strong> Multi-line text input & CSV upload.</p>  
<p style='text-align: center; font-size:16px;'><strong>Output:</strong> Text results & downloadable CSV file.</p>  
<p style='text-align: center; font-size:16px;'><strong>Model:</strong> Fine-tuned <strong>RoBERTa</strong> for classification.</p>
""", unsafe_allow_html=True)

st.markdown("<h3 style='font-size:16px;'>Enter ticket Description/Comment/Summary or upload a CSV file to predict Assigned Team & Team Email.</h3>", unsafe_allow_html=True)


# Choose input method
option = st.radio("📝 Choose Input Method", ["Enter Text", "Upload CSV"])

descriptions = []
if option == "Enter Text":
    text_input = st.text_area(
        "Enter Ticket Description/Comment/Summary (One per line)", 
        placeholder="Example:\n - Database performance issue\n - Login fails for admin users..."
    )
    descriptions = [line.strip() for line in text_input.split("\n") if line.strip()]
else:
    file_input = st.file_uploader("Upload CSV", type=["csv"])
    if file_input is not None:
        df_input = pd.read_csv(file_input)
        if "Description" not in df_input.columns:
            st.error("โš ๏ธ Error: CSV must contain a 'Description' column.")
        else:
            descriptions = df_input["Description"].dropna().tolist()


# Store prediction results in session state so they persist
if "prediction_results" not in st.session_state:
    st.session_state.prediction_results = None
if "df_results" not in st.session_state:
    st.session_state.df_results = None

# Create a horizontal layout for the buttons
col1, col2 = st.columns([1, 1])

with col1:
    if st.button("PREDICT"):
        if not descriptions:
            st.error("โš ๏ธ Please provide valid input.")
        else:
            with st.spinner("Predicting..."):
                results, df_results = predict_tickets(descriptions)
            st.session_state.prediction_results = results
            st.session_state.df_results = df_results

# Display prediction results if available
if st.session_state.prediction_results:
    st.markdown("<h3 style='font-size:16px;'>Prediction Results</h3>", unsafe_allow_html=True)
    st.text(st.session_state.prediction_results)
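    # Serialize the results to UTF-8 CSV bytes for the download button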
    csv_data = st.session_state.df_results.to_csv(index=False).encode('utf-8')
    st.download_button(
        label="๐Ÿ“ฅ Download Predictions CSV",
        data=csv_data,
        file_name="ticket-predictions.csv",
        mime="text/csv"
    )

with col2:
    if st.button("CLEAR"):
        # Clear the prediction results from session state
        st.session_state.prediction_results = None
        st.session_state.df_results = None
        st.rerun()

st.markdown("---")
st.markdown(
    "<p style='text-align: center;color: gray; font-size:14px;'>Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>", 
    unsafe_allow_html=True
)