Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import torch | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
from huggingface_hub import snapshot_download | |
# Hugging Face Hub repository that hosts the DistilBERT ticket classifier.
MODEL_REPO_ID = "ZAM-ITI-110/Distil_Bert_V3"
# Author details, rendered in italics beneath the app title.
My_info = "Student ID: 6319250G, Name: Aung Hlaing Tun"
# Load Model & Tokenizer from Hugging Face
def load_model(repo_id, cache_dir="/home/user/app/hf_models"):
    """Download (if needed) and load a sequence-classification model and tokenizer.

    Args:
        repo_id: Hugging Face Hub repository id, e.g. ``"user/model"``.
        cache_dir: Local directory used as the ``snapshot_download`` cache.
            Defaults to the path used inside the Space container; pass a
            different directory when running the app elsewhere.

    Returns:
        A ``(model, tokenizer)`` tuple, both loaded from the downloaded
        snapshot directory.
    """
    # Make sure the cache location exists before handing it to the Hub client.
    os.makedirs(cache_dir, exist_ok=True)
    # Files already present in cache_dir are reused; missing ones are fetched.
    download_dir = snapshot_download(repo_id, cache_dir=cache_dir, local_files_only=False)
    model = AutoModelForSequenceClassification.from_pretrained(download_dir)
    tokenizer = AutoTokenizer.from_pretrained(download_dir)
    return model, tokenizer
# Load the classifier and its tokenizer once, when the app starts.
model, tokenizer = load_model(MODEL_REPO_ID)
# Use the GPU when CUDA is available; otherwise keep the weights on the CPU.
_target_device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(torch.device(_target_device))
# Prediction Function
# Classifier output index -> owning team / contact address. Defined once at
# module level instead of being rebuilt on every request.
LABEL_MAPPING = {
    0: "Code Review Team",
    1: "Functional Team",
    2: "Infrastructure Team",
    3: "Performance Team",
    4: "Security Team",
}
# NOTE(review): every address reads "[email protected]" in this copy, which
# looks like an obfuscated scrape of the real addresses -- confirm the values.
EMAIL_MAPPING = {
    0: "[email protected]",
    1: "[email protected]",
    2: "[email protected]",
    3: "[email protected]",
    4: "[email protected]",
}


def predict_team_and_email(text):
    """Predict the responsible team and its contact email for one ticket.

    Args:
        text: Free-text ticket description taken from the Gradio textbox.

    Returns:
        A ``(team_message, email_message)`` pair of display strings. A label
        outside the known mapping is reported as ``'Unknown'``. Blank or
        missing input returns a prompt instead of running the model.
    """
    # Blank input carries no signal, yet the model would still emit some label.
    if not text or not text.strip():
        return "Please enter a ticket description.", ""

    # The tokenized tensors must sit on the same device as the model (which is
    # moved to CUDA when available); otherwise the forward pass fails.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_label = torch.argmax(logits, dim=-1).item()

    team = LABEL_MAPPING.get(predicted_label, "Unknown")
    email = EMAIL_MAPPING.get(predicted_label, "Unknown")
    return f"Predicted Team: {team}", f"Predicted Email: {email}"
# Gradio UI Setup
# Number of independent ticket slots; keep in sync with the "About" text below.
NUM_TICKETS = 6

with gr.Blocks() as interface:
    gr.Markdown("📩 Development of an AI Ticket Classifier Model Using DistilBERT")
    gr.Markdown(f"*{My_info}*")
    gr.Markdown(
    """
    **🔍 About this App**
    - This system predicts the appropriate **team assignment** and **contact email** based on the ticket description.
    - Simply enter up to **6 ticket descriptions**, and the AI will classify them accordingly.
    """
    )
    # Give each ticket its own row so its description, results and button stay
    # together, instead of packing all 18 components into one horizontal row.
    for i in range(1, NUM_TICKETS + 1):
        with gr.Row():
            ticket_input = gr.Textbox(lines=2, placeholder=f"Enter ticket description {i}...", label=f"Ticket {i}")
            team_output = gr.Textbox(label=f"Predicted Team {i}")
            email_output = gr.Textbox(label=f"Predicted Email {i}")
        predict_button = gr.Button(f"Predict for Ticket {i}")
        # Wired immediately with this iteration's components (no closure), so
        # each button drives only its own ticket.
        predict_button.click(
            fn=predict_team_and_email,
            inputs=ticket_input,
            outputs=[team_output, email_output],
        )

# Launch the interface
interface.launch()