# ITI110 / app.py
# Last commit by crimson78: "Update app.py" (271f25e, verified)
import gradio as gr
import torch
# from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import BertTokenizer, BertForSequenceClassification
# Hugging Face Hub repository holding the fine-tuned BERT spam classifier.
model_name = "crimson78/spam_classifier_models"

# Load the tokenizer first, then the classification model, at import time.
tokenizer, model = (
    BertTokenizer.from_pretrained(model_name),
    BertForSequenceClassification.from_pretrained(model_name),
)

# Class index -> display name.
# NOTE(review): assumes the model was trained with 0 = HAM and 1 = SPAM;
# confirm against the training labels / model config.
LABELS = dict(enumerate(("HAM", "SPAM")))
def classify_text(message):
    """Classify a single message as spam or ham with the fine-tuned BERT model.

    Args:
        message: The raw text to classify.

    Returns:
        The name from ``LABELS`` of the highest-scoring class
        ("HAM" or "SPAM").
    """
    # Tokenize into PyTorch tensors; anything past 512 tokens is truncated.
    encoded = tokenizer(
        message,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        scores = model(**encoded).logits

    # A single input string yields a batch of one, so .item() gives a plain int.
    best_index = scores.argmax(dim=-1).item()
    return LABELS[best_index]
# Input component: the free-text message typed by the user.
message_input = gr.Textbox(label="Enter your message")
# Output component: shows the label string returned by classify_text.
verdict_output = gr.Label(label="Classification")

# Gradio UI wiring the text box through classify_text to the label display.
iface = gr.Interface(
    fn=classify_text,
    inputs=message_input,
    outputs=verdict_output,
    title="Spam Classifier",
    description="Enter a message to check if it's SPAM or HAM using a fine-tuned BERT model.",
)


def _main():
    """Launch the Gradio app."""
    iface.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    _main()