import gradio as gr import torch import torch.nn as nn import pickle import pandas as pd from transformers import RobertaTokenizerFast, RobertaModel # Implement by MinTheinWin@3907578Y # Load label mappings with open("label_mappings.pkl", "rb") as f: label_mappings = pickle.load(f) label_to_team = label_mappings.get("label_to_team", {}) label_to_email = label_mappings.get("label_to_email", {}) # Load the tokenizer tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base") # Define RoBERTa Model class RoBertaClassifier(nn.Module): def __init__(self, num_teams, num_emails): super(RoBertaClassifier, self).__init__() self.roberta = RobertaModel.from_pretrained("roberta-base") self.team_classifier = nn.Linear(self.roberta.config.hidden_size, num_teams) self.email_classifier = nn.Linear(self.roberta.config.hidden_size, num_emails) def forward(self, input_ids, attention_mask): outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask) cls_output = outputs.last_hidden_state[:, 0, :] team_logits = self.team_classifier(cls_output) email_logits = self.email_classifier(cls_output) return team_logits, email_logits # Load Model num_teams = len(label_to_team) num_emails = len(label_to_email) model = RoBertaClassifier(num_teams, num_emails) checkpoint = torch.load("ticket_classification_model.pth", map_location=torch.device("cpu")) filtered_checkpoint = {k: v for k, v in checkpoint.items() if k in model.state_dict()} model.load_state_dict(filtered_checkpoint, strict=False) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model.to(device) model.eval() # Prediction Function def predict_tickets(ticket_descriptions): predictions = [] csv_data = [] for idx, description in enumerate(ticket_descriptions, start=1): inputs = tokenizer(description, return_tensors="pt", truncation=True, padding="max_length", max_length=128).to(device) with torch.no_grad(): team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask) predicted_team_index = team_logits.argmax(dim=-1).cpu().item() predicted_email_index = email_logits.argmax(dim=-1).cpu().item() predicted_team = label_to_team.get(predicted_team_index, "Unknown Team") predicted_email = label_to_email.get(predicted_email_index, "Unknown Email") predictions.append(f"**{idx}. {description}**\n - **Assigned Team:** {predicted_team}\n - **Team Email:** {predicted_email}\n") csv_data.append([idx, description, predicted_team, predicted_email]) df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"]) csv_file = "ticket-predictions.csv" df.to_csv(csv_file, index=False) return "\n".join(predictions), csv_file # Gradio Functions def gradio_predict(option, text_input, file_input): if option == "Enter Text": descriptions = text_input.split("\n") descriptions = [desc.strip() for desc in descriptions if desc.strip()] elif option == "Upload CSV" and file_input is not None: df = pd.read_csv(file_input) if "Description" not in df.columns: return "⚠️ Error: CSV must contain a 'Description' column.", None descriptions = df["Description"].tolist() else: return "⚠️ Please provide input.", None results, csv_file = predict_tickets(descriptions) return results, csv_file def clear_inputs(): return "Enter Text", "", None, "", None # Custom CSS for improved UI and fixed input container sizes custom_css = """ .gradio-container { max-width: 1000px !important; margin: auto !important; } #title { text-align: center; font-size: 26px !important; font-weight: bold; } #predict-button, #clear-button, #download-button { width: 100% !important; height: 55px !important; font-size: 18px !important; } #results-box { height: 350px !important; overflow-y: auto !important; background: #f9f9f9; padding: 15px; border-radius: 10px; font-size: 16px; } /* Reduce vertical padding for the radio component */ #choose_input_method { padding-top: 5px !important; padding-bottom: 5px !important; } /* Force both input components to have the same min-height */ #text_input, #file_input { min-height: 200px !important; /* Optionally add a consistent border and padding to match styling */ border: 1px solid #ccc; padding: 10px; } """ # Gradio App UI with gr.Blocks(css=custom_css) as app: gr.Markdown( """ # AI Solution for Defect Ticket Classification **Supports:** Multi-line text input & CSV upload. **Output:** Text results & downloadable CSV file. **Model:** Fine-tuned **RoBERTa** for classification. Enter ticket Description/Comment/Summary or upload a **CSV file** to predict Assigned Team & Team Email. """, elem_id="title" ) with gr.Row(): with gr.Column(scale=1): # Radio component with elem_id for CSS targeting option = gr.Radio( ["Enter Text", "Upload CSV"], label="📝 Choose Input Method", value="Enter Text", elem_id="choose_input_method" ) # Both inputs are given an element id to force consistent dimensions. text_input = gr.Textbox( label="Enter Ticket Description/Comment/Summary (One per line)", visible=True, lines=6, placeholder="Example:\n - Database performance issue\n - Login fails for admin users...", elem_id="text_input" ) file_input = gr.File( label="📂 Upload CSV (Optional)", type="filepath", visible=False, elem_id="file_input" ) with gr.Column(scale=1): gr.Markdown("## Prediction Results") results_output = gr.Markdown(elem_id="results-box", visible=True) download_csv = gr.File(label="📥 Download Predictions CSV", interactive=False) with gr.Row(): predict_btn = gr.Button("PREDICT", variant="primary") clear_btn = gr.Button("CLEAR", variant="secondary") # Toggle the visibility of input components to ensure consistent sizing def toggle_input(selected_option): if selected_option == "Enter Text": return gr.update(visible=True), gr.update(visible=False) else: return gr.update(visible=False), gr.update(visible=True) option.change(fn=toggle_input, inputs=[option], outputs=[text_input, file_input]) predict_btn.click(fn=gradio_predict, inputs=[option, text_input, file_input], outputs=[results_output, download_csv]) clear_btn.click(fn=clear_inputs, inputs=[], outputs=[option, text_input, file_input, results_output, download_csv]) # Footer view gr.Markdown("---") gr.HTML( """
Developed by NYP student @ Min Thein Win: Student ID: 3907578Y