import gradio as gr
import torch

from transformers import BertTokenizer, BertForSequenceClassification

# Load the fine-tuned BERT spam classifier and its tokenizer from the Hugging Face Hub.
model_name = "crimson78/spam_classifier_models"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name)
model.eval()  # inference only; no gradients or dropout needed

# Map the model's output indices to human-readable labels.
LABELS = {0: "HAM", 1: "SPAM"}


def classify_text(message):
    """Classify a single message as HAM or SPAM."""
    # Tokenize the input, truncating to BERT's 512-token limit.
    inputs = tokenizer(message, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # Run the model without tracking gradients and pick the highest-scoring class.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    prediction = torch.argmax(logits, dim=-1).item()
    return LABELS[prediction]
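

# Optional variant (a sketch, not wired into the app below): gr.Label can also display
# per-class confidences when the function returns a dict of {label: probability}.
# Swapping fn=classify_text for fn=classify_text_with_scores in the Interface would show
# the softmax scores for HAM and SPAM instead of a single label.
def classify_text_with_scores(message):
    inputs = tokenizer(message, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1).squeeze(0)
    return {LABELS[i]: probs[i].item() for i in range(len(LABELS))}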


# Build the Gradio UI: a textbox for the message in, a label for the prediction out.
iface = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(label="Enter your message"),
    outputs=gr.Label(label="Classification"),
    title="Spam Classifier",
    description="Enter a message to check if it's SPAM or HAM using a fine-tuned BERT model.",
)


if __name__ == "__main__":
    iface.launch()
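
# To run locally (assuming this file is saved as app.py and gradio, torch, and
# transformers are installed):
#   python app.py
# Gradio prints a local URL (http://127.0.0.1:7860 by default) to open in a browser.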