import gradio as gr
import torch
import torch.nn as nn
import pickle
import pandas as pd
from transformers import RobertaTokenizerFast, RobertaModel

# Implemented by MinTheinWin@3907578Y
# Load label mappings
with open("label_mappings.pkl", "rb") as f:
    label_mappings = pickle.load(f)
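# label_mappings.pkl is assumed to hold the integer-to-string lookups saved at
# training time, for example (hypothetical values):
#   {"label_to_team":  {0: "Database Team", 1: "Network Team"},
#    "label_to_email": {0: "db-team@example.com", 1: "net-team@example.com"}}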

label_to_team = label_mappings.get("label_to_team", {})
label_to_email = label_mappings.get("label_to_email", {})

# Load the tokenizer
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

# Define RoBERTa Model
class RoBertaClassifier(nn.Module):
    def __init__(self, num_teams, num_emails):
        super(RoBertaClassifier, self).__init__()
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        self.team_classifier = nn.Linear(self.roberta.config.hidden_size, num_teams)
        self.email_classifier = nn.Linear(self.roberta.config.hidden_size, num_emails)

    def forward(self, input_ids, attention_mask):
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]

        team_logits = self.team_classifier(cls_output)
        email_logits = self.email_classifier(cls_output)

        return team_logits, email_logits
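
# Shape sketch for RoBertaClassifier.forward: for a batch of B tickets padded to
# length L, input_ids and attention_mask are (B, L); cls_output takes the first
# (<s>) token representation, shape (B, hidden_size) = (B, 768) for roberta-base;
# the two heads return (B, num_teams) and (B, num_emails) logits.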

# Load Model
num_teams = len(label_to_team)
num_emails = len(label_to_email)
model = RoBertaClassifier(num_teams, num_emails)
checkpoint = torch.load("ticket_classification_model.pth", map_location=torch.device("cpu"))
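# Keep only parameters whose names exist in the current model; strict=False then
# tolerates any keys that remain unmatched (those layers keep their fresh init).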
filtered_checkpoint = {k: v for k, v in checkpoint.items() if k in model.state_dict()}
model.load_state_dict(filtered_checkpoint, strict=False)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()

# Prediction Function
def predict_tickets(ticket_descriptions):
    predictions = []
    csv_data = []
    for idx, description in enumerate(ticket_descriptions, start=1):
        inputs = tokenizer(description, return_tensors="pt", truncation=True, padding="max_length", max_length=128).to(device)
        with torch.no_grad():
            team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask)
        predicted_team_index = team_logits.argmax(dim=-1).cpu().item()
        predicted_email_index = email_logits.argmax(dim=-1).cpu().item()
        predicted_team = label_to_team.get(predicted_team_index, "Unknown Team")
        predicted_email = label_to_email.get(predicted_email_index, "Unknown Email")
        predictions.append(f"**{idx}. {description}**\n   - **Assigned Team:** {predicted_team}\n   - **Team Email:** {predicted_email}\n")
        csv_data.append([idx, description, predicted_team, predicted_email])

    df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"])
    csv_file = "ticket-predictions.csv"
    df.to_csv(csv_file, index=False)
    return "\n".join(predictions), csv_file

# Gradio Functions
def gradio_predict(option, text_input, file_input):
    if option == "Enter Text":
        descriptions = text_input.split("\n")
        descriptions = [desc.strip() for desc in descriptions if desc.strip()]
    elif option == "Upload CSV" and file_input is not None:
        df = pd.read_csv(file_input)
        if "Description" not in df.columns:
            return "⚠️ Error: CSV must contain a 'Description' column.", None
        descriptions = df["Description"].tolist()
    else:
        return "⚠️ Please provide input.", None

    results, csv_file = predict_tickets(descriptions)
    return results, csv_file
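
# The uploaded CSV only needs a "Description" column; any other columns are ignored.
# A minimal example file (hypothetical contents):
#   Description
#   Database performance issue
#   Login fails for admin users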

def clear_inputs():
    return "Enter Text", "", None, "", None

# Custom CSS for improved UI and fixed input container sizes
custom_css = """
.gradio-container { 
    max-width: 1000px !important; 
    margin: auto !important; 
}
#title { 
    text-align: center; 
    font-size: 26px !important; 
    font-weight: bold; 
}
#predict-button, #clear-button, #download-button { 
    width: 100% !important; 
    height: 55px !important; 
    font-size: 18px !important; 
}
#results-box { 
    height: 350px !important; 
    overflow-y: auto !important; 
    background: #f9f9f9; 
    padding: 15px; 
    border-radius: 10px; 
    font-size: 16px; 
}
/* Reduce vertical padding for the radio component */
#choose_input_method {
    padding-top: 5px !important;
    padding-bottom: 5px !important;
}
/* Force both input components to have the same min-height */
#text_input, #file_input {
    min-height: 200px !important;
    /* Optionally add a consistent border and padding to match styling */
    border: 1px solid #ccc;
    padding: 10px;
}
"""

# Gradio App UI
with gr.Blocks(css=custom_css) as app:
    gr.Markdown(
        """
        # AI Solution for Defect Ticket Classification

        **Supports:** Multi-line text input & CSV upload.  
        **Output:** Text results & downloadable CSV file.  
        **Model:** Fine-tuned **RoBERTa** for classification.  

        Enter ticket Description/Comment/Summary or upload a **CSV file** to predict Assigned Team & Team Email.
        """,
        elem_id="title"
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Radio component with elem_id for CSS targeting
            option = gr.Radio(
                ["Enter Text", "Upload CSV"], 
                label="πŸ“ Choose Input Method", 
                value="Enter Text", 
                elem_id="choose_input_method"
            )

            # Both inputs are given an element id to force consistent dimensions.
            text_input = gr.Textbox(
                label="Enter Ticket Description/Comment/Summary (One per line)", 
                visible=True,
                lines=6, 
                placeholder="Example:\n - Database performance issue\n - Login fails for admin users...",
                elem_id="text_input"
            )

            file_input = gr.File(
                label="πŸ“‚ Upload CSV (Optional)", 
                type="filepath", 
                visible=False,
                elem_id="file_input"
            )

        with gr.Column(scale=1):
            gr.Markdown("## Prediction Results")
            results_output = gr.Markdown(elem_id="results-box", visible=True)
            download_csv = gr.File(label="πŸ“₯ Download Predictions CSV", interactive=False)

    with gr.Row():
        predict_btn = gr.Button("PREDICT", variant="primary")
        clear_btn = gr.Button("CLEAR", variant="secondary")

    # Toggle which input component is visible based on the selected input method
    def toggle_input(selected_option):
        if selected_option == "Enter Text":
            return gr.update(visible=True), gr.update(visible=False)
        else:
            return gr.update(visible=False), gr.update(visible=True)

    option.change(fn=toggle_input, inputs=[option], outputs=[text_input, file_input])
    predict_btn.click(fn=gradio_predict, inputs=[option, text_input, file_input], outputs=[results_output, download_csv])
    clear_btn.click(fn=clear_inputs, inputs=[], outputs=[option, text_input, file_input, results_output, download_csv])

    # Footer
    gr.Markdown("---")
    gr.HTML(
        """
        <div style="text-align: center; color: gray; padding-top: 10px;">
            <p>Developed by NYP student Min Thein Win (Student ID: 3907578Y)</p>
        </div>
        """
    )

# Launch App (share=True also requests a temporary public Gradio link)
app.launch(share=True)