Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import torch | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
from huggingface_hub import snapshot_download | |
# --- Student information rendered in the UI header ---
My_info: str = "Student ID: 6319250G, Name: Aung Hlaing Tun"

# --- Hugging Face Hub repository the sequence-classification model is loaded from ---
MODEL_REPO_ID: str = "ZAM-ITI-110/Distil_Bert_V3"
# Load Model & Tokenizer from Hugging Face
def load_model(repo_id, cache_dir="/home/user/app/hf_models"):
    """Download a model snapshot from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository id, e.g. ``"user/model-name"``.
        cache_dir: Local directory used as the ``snapshot_download`` cache.
            It is created if missing. Defaults to the path this app was
            originally hard-coded to (presumably the Hugging Face Spaces app
            directory); pass another path when running elsewhere.

    Returns:
        A ``(model, tokenizer)`` tuple: an ``AutoModelForSequenceClassification``
        and an ``AutoTokenizer``, both loaded from the downloaded snapshot
        directory.
    """
    os.makedirs(cache_dir, exist_ok=True)
    # Always allow network access so a missing snapshot is fetched.
    download_dir = snapshot_download(repo_id, cache_dir=cache_dir, local_files_only=False)
    model = AutoModelForSequenceClassification.from_pretrained(download_dir)
    tokenizer = AutoTokenizer.from_pretrained(download_dir)
    return model, tokenizer
# Load the classifier and its tokenizer once, at import time.
model, tokenizer = load_model(MODEL_REPO_ID)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Move the weights to the chosen device and switch to evaluation mode
# (affects layers such as dropout), since the model is only used for inference.
model.to(device).eval()
# Prediction Function (Single Ticket)
# Maps the classifier's predicted label index to (team name, team email).
# NOTE(review): all five addresses are the identical literal "[email protected]";
# confirm the real per-team addresses.
TEAM_ROUTES = {
    0: ("Code Review Team", "[email protected]"),
    1: ("Functional Team", "[email protected]"),
    2: ("Infrastructure Team", "[email protected]"),
    3: ("Performance Team", "[email protected]"),
    4: ("Security Team", "[email protected]"),
}


def predict_team_and_email(text):
    """Predict the responsible team and its email for one ticket description.

    Args:
        text: Ticket description from the UI textbox. ``None``, empty and
            whitespace-only values are treated as "no input".

    Returns:
        A ``(team, email)`` tuple of strings: ``("", "")`` when there is no
        input, ``("Unknown", "Unknown")`` if the model predicts a label index
        not present in ``TEAM_ROUTES``.
    """
    # Guard against None as well as blank text; the model is never called then.
    if not text or not text.strip():
        return "", ""
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():  # inference only: no autograd bookkeeping
        logits = model(**encoded).logits
    pred = torch.argmax(logits, dim=-1).item()
    return TEAM_ROUTES.get(pred, ("Unknown", "Unknown"))
# Send Ticket Function (Simulation)
def send_tickets(*args):
    """Simulate sending every ticket that has a description and a prediction.

    The UI passes its components as ``inputs + outputs``, so ``args`` is laid
    out as N ticket texts followed by N ``(team, email)`` pairs::

        text_1, ..., text_N, team_1, email_1, ..., team_N, email_N

    Args:
        *args: ``3 * N`` values in the layout above. ``None`` descriptions are
            treated like empty strings.

    Returns:
        One line per sendable ticket followed by ``"Sent successfully!"``, or
        ``"No tickets to send."`` when no ticket has a description, a team
        and an email.
    """
    count = len(args) // 3
    texts = args[:count]
    # The (team, email) pairs follow the texts; de-interleave them.
    pairs = args[count:3 * count]
    teams = pairs[0::2]
    emails = pairs[1::2]

    tickets = [
        f"Ticket {i}: '{text}' -> {team} ({email})"
        for i, (text, team, email) in enumerate(zip(texts, teams, emails), 1)
        if (text or "").strip() and team and email
    ]
    if tickets:
        return "\n".join(tickets) + "\n\nSent successfully!"
    return "No tickets to send."
# Clear Function
def clear_all():
    """Return a blank value for every component the Clear button resets.

    That is 6 tickets x 3 components (description, team, email) plus the
    single "Sent Tickets" box: 6 * 3 + 1 == 19 empty strings.
    """
    field_count = 6 * 3 + 1
    return ["" for _ in range(field_count)]
# Gradio UI Setup
# Number of ticket rows rendered. NOTE: clear_all() returns a fixed number of
# values (one per component in `inputs + outputs + [sent_output]`), so keep it
# in sync with this count.
NUM_TICKETS = 6

with gr.Blocks(title="AI Ticket Classifier") as interface:
    gr.Markdown("📩 **Development of an AI Ticket Classifier Model Using DistilBERT**")
    gr.Markdown(f"*{My_info}*")
    gr.Markdown(
        f"""
        **🔍 About this App**
        - Predicts the appropriate **team** and **email** for up to {NUM_TICKETS} ticket descriptions.
        - Click 'Predict' for each ticket, then 'Send Tickets' to process.
        """
    )

    # Ticket Entry Section: one row per ticket holding the description box,
    # the two read-only prediction boxes, and that ticket's Predict button.
    with gr.Column():
        gr.Markdown("### Enter Ticket Descriptions")
        inputs = []   # description textboxes, in ticket order
        outputs = []  # flat list: [team_1, email_1, team_2, email_2, ...]
        buttons = []  # one Predict button per ticket
        for i in range(NUM_TICKETS):
            with gr.Row():
                ticket_input = gr.Textbox(
                    lines=2,
                    placeholder=f"Ticket {i+1} description...",
                    label=f"Ticket {i+1}",
                )
                team_output = gr.Textbox(label="Predicted Team", interactive=False)
                email_output = gr.Textbox(label="Team Email", interactive=False)
                predict_btn = gr.Button(f"Predict {i+1}")
            inputs.append(ticket_input)
            outputs.extend([team_output, email_output])
            buttons.append(predict_btn)

    # Action Buttons
    with gr.Row():
        send_btn = gr.Button("Send Tickets")
        clear_btn = gr.Button("Clear")

    # Output for Sent Tickets
    sent_output = gr.Textbox(label="Sent Tickets", interactive=False)

    # Each Predict button classifies only its own ticket and fills that
    # ticket's team/email boxes.
    for ticket_box, team_box, email_box, btn in zip(inputs, outputs[0::2], outputs[1::2], buttons):
        btn.click(
            fn=predict_team_and_email,
            inputs=ticket_box,
            outputs=[team_box, email_box],
        )

    # send_tickets receives all descriptions first, then the (team, email)
    # pairs, i.e. exactly the order of `inputs + outputs`.
    send_btn.click(
        fn=send_tickets,
        inputs=inputs + outputs,
        outputs=sent_output,
    )
    clear_btn.click(
        fn=clear_all,
        inputs=None,
        outputs=inputs + outputs + [sent_output],
    )

# Launch the interface only when run as a script, so importing this module
# does not start the Gradio server.
if __name__ == "__main__":
    interface.launch()