import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import classification_report, accuracy_score
import gradio as gr


# Load the labelled spam/ham dataset.
file_path = 'spam_ham_dataset.csv'
df = pd.read_csv(file_path)

# Encode the string labels as integers (alphabetical category order: ham -> 0, spam -> 1).
df['label_num'] = df['label'].astype('category').cat.codes

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Tokenize the full corpus once; padding=True pads every example to the same length,
# which lets the simple stacking collate function below work.
encodings = tokenizer(df['text'].tolist(), padding=True, truncation=True, max_length=128, return_tensors="pt")
labels = torch.tensor(df['label_num'].values, dtype=torch.long)


class SpamDataset(Dataset):
    """Wraps the pre-tokenized encodings and integer labels as a PyTorch Dataset."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        # self.labels is already a LongTensor, so index it directly rather than
        # re-wrapping it with torch.tensor (which triggers a copy warning).
        item['labels'] = self.labels[idx]
        return item


dataset = SpamDataset(encodings, labels)
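
# Illustrative only: a single dataset item is a dict of equal-length tensors, roughly
# {'input_ids': tensor([...]), 'attention_mask': tensor([...]), 'labels': tensor(0)};
# the exact keys depend on the tokenizer output (the values shown are hypothetical).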

# 80/20 train/validation split.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])


def get_top_words(corpus, n=None):
    """Return the n most frequent words in the corpus (English stop words removed)."""
    vec = CountVectorizer(stop_words='english').fit(corpus)
    bag_of_words = vec.transform(corpus)
    sum_words = bag_of_words.sum(axis=0)
    words_freq = [(word, sum_words[0, idx]) for word, idx in vec.vocabulary_.items()]
    words_freq = sorted(words_freq, key=lambda x: x[1], reverse=True)
    return words_freq[:n]
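
# Illustrative usage (the counts shown are hypothetical):
#   get_top_words(df[df['label'] == 'spam']['text'], n=3)
#   -> [('subject', 1500), ('http', 900), ('free', 750)]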


def collate_fn(batch):
    # Every example was padded to the same length during tokenization,
    # so a batch can be built by stacking tensors key by key.
    keys = batch[0].keys()
    collated = {key: torch.stack([b[key] for b in batch]) for key in keys}
    return collated


train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)


def load_model(model_path="distilbert_spam_model.pt"):
    """Load the fine-tuned DistilBERT checkpoint for binary spam/ham classification."""
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model


model = load_model()
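
# The checkpoint "distilbert_spam_model.pt" is assumed to come from a separate
# fine-tuning run of DistilBertForSequenceClassification on this dataset; a minimal
# sketch of producing it would be: torch.save(model.state_dict(), "distilbert_spam_model.pt").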


def classify_email(email_text):
    model.eval()

    with torch.no_grad():
        inputs = tokenizer(email_text, padding=True, truncation=True, max_length=256, return_tensors="pt")
        inputs = {key: val.to(device) for key, val in inputs.items()}
        outputs = model(**inputs)
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=1)
        probs = F.softmax(logits, dim=1)
        confidence = torch.max(probs).item() * 100

    # Label 1 corresponds to "spam" under the categorical encoding above.
    result = "Spam" if predictions.item() == 1 else "Ham"
    return result, f"{confidence:.2f}%"
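
# Quick illustrative check (the output values are hypothetical):
#   classify_email("Congratulations, you have won a free prize! Click here now.")
#   -> ("Spam", "97.42%")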


def evaluate_model_with_report(val_loader):
    model.eval()
    y_true = []
    y_pred = []
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in val_loader:
            inputs = {key: val.to(device) for key, val in batch.items()}
            labels = inputs.pop("labels")

            outputs = model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=1)

            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())

            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total if total > 0 else 0
    print(f"Validation Accuracy: {accuracy:.4f}")

    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, target_names=["Ham", "Spam"]))

    return accuracy
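
# Not called at import time; it can be run manually once the model is loaded, e.g.:
#   evaluate_model_with_report(val_loader)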


def generate_performance_metrics():
    model.eval()

    y_true = []
    y_pred = []

    with torch.no_grad():
        for batch in val_loader:
            inputs = {key: val.to(device) for key, val in batch.items()}
            labels = inputs.pop("labels")

            outputs = model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=1)

            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())

    accuracy = accuracy_score(y_true, y_pred)
    report = classification_report(y_true, y_pred, output_dict=True)

    # The '1' entry of the report is the spam class (ham -> 0, spam -> 1).
    return {
        "accuracy": f"{accuracy:.2%}",
        "precision": f"{report['1']['precision']:.2%}",
        "recall": f"{report['1']['recall']:.2%}",
        "f1_score": f"{report['1']['f1-score']:.2%}",
    }
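
# Shape of the returned dict (the values shown are hypothetical):
#   {"accuracy": "98.12%", "precision": "97.50%", "recall": "96.80%", "f1_score": "97.15%"}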


def create_interface():
    performance_metrics = generate_performance_metrics()
    with gr.Blocks() as interface:
        with gr.Tab("📨 Demo"):
            gr.Markdown("# 📧🔍 Spam and Phishing Email Detection")
            gr.Markdown(
                """
                Welcome to the Spam and Phishing Email Detection demo! This tool leverages DistilBERT, a lightweight yet powerful transformer model, to classify emails as ham (legitimate) or spam/phishing based on their content.

                This project aims to enhance email security by identifying malicious messages with high accuracy, reducing the risk of scams and fraud. Feel free to explore the demo and see how AI can help provide a safer inbox for everyone.
                """)

            email_input = gr.Textbox(
                lines=8, placeholder="Type or paste your email content here...", label="Email Content"
            )

            result_output = gr.Textbox(label="Classification Result")
            confidence_output = gr.Textbox(label="Confidence Score", interactive=False)

            analyze_button = gr.Button("Analyze Email")

            def email_analysis_pipeline(email_text):
                # classify_email returns a (label, confidence) tuple, not a dict.
                result, confidence = classify_email(email_text)
                return result, confidence

            analyze_button.click(
                fn=email_analysis_pipeline,
                inputs=email_input,
                outputs=[result_output, confidence_output]
            )

        with gr.Tab("📈 Analysis"):
            gr.Markdown("## Dataset Overview")
            gr.Markdown("### Dataset Preview")
            gr.DataFrame(df)

            gr.Markdown("### Top Spam Words")
            top_spam_words = get_top_words(df[df['label'] == "spam"]['text'], n=10)
            gr.DataFrame(top_spam_words)

            gr.Markdown("### Top Ham Words")
            top_ham_words = get_top_words(df[df['label'] == "ham"]['text'], n=10)
            gr.DataFrame(top_ham_words)

            gr.Markdown("## 📊 Model Performance Analytics")
            with gr.Row():
                gr.Textbox(value=performance_metrics["accuracy"], label="Accuracy", interactive=False)
                gr.Textbox(value=performance_metrics["precision"], label="Precision", interactive=False)
                gr.Textbox(value=performance_metrics["recall"], label="Recall", interactive=False)
                gr.Textbox(value=performance_metrics["f1_score"], label="F1 Score", interactive=False)

        with gr.Tab("📜 Glossary"):
            with gr.Column():
                gr.Markdown(
                    """
                    ## Label Definitions
                    - Spam: Unwanted or potentially harmful emails detected by the system.
                    - Ham: Legitimate and safe emails.

                    ## Evaluation Metrics
                    - Accuracy: Measures the percentage of correctly classified emails.
                    - Precision: Out of all emails classified as spam, how many were actually spam?
                    - Recall: Out of all actual spam emails, how many were identified correctly?
                    - F1 Score: A balance between precision and recall for overall performance assessment.
                    """
                )

            with gr.Column():
                gr.Markdown("## 🔍 Libraries Used and Their Objectives")
                gr.Markdown(
                    """
                    ### 1. Pandas (`import pandas as pd`)
                    Objective: Data manipulation and preprocessing.
                    Justification: Used for loading, cleaning, and structuring the email dataset for analysis and model training.

                    ### 2. NumPy (`import numpy as np`)
                    Objective: Efficient numerical operations.
                    Justification: Facilitates handling large datasets and computations, such as text vectorization and matrix operations.

                    ### 3. Torch & Torch-related Libraries
                    - `import torch` – Core deep learning framework for model training.
                    - `import torch.nn as nn` – Defines deep learning model architecture.
                    - `import torch.optim as optim` – Implements optimization algorithms.
                    - `import torch.nn.functional as F` – Provides additional functions like activation and loss functions.
                    - `from torch.utils.data import Dataset, DataLoader` – Handles data batching and loading for model training.

                    Justification: Essential for training and fine-tuning DistilBERT on email classification.

                    ### 4. Transformers (`from transformers import DistilBertTokenizer, DistilBertForSequenceClassification`)
                    Objective: Tokenization and model training using DistilBERT.
                    Justification: DistilBERT offers a lighter yet powerful alternative to BERT, improving efficiency while maintaining accuracy.

                    ### 5. Scikit-learn (`sklearn`)
                    Feature Extraction:
                    - `CountVectorizer`: Converts text into a matrix of token counts.
                    - `TfidfVectorizer`: Converts text into TF-IDF features, which measure the importance of words in documents.

                    Model Training & Evaluation:
                    - `MultinomialNB`: Implements the Naïve Bayes classifier for a baseline model.
                    - `train_test_split`: Splits the dataset for training and testing.
                    - `classification_report`, `accuracy_score`, `precision_score`, `recall_score`, `f1_score`: Compute evaluation metrics.

                    Justification: Used for feature extraction, baseline modeling, and performance evaluation of different models.

                    ### 6. Matplotlib & Seaborn (`import matplotlib.pyplot as plt`, `import seaborn as sns`)
                    Objective: Data visualization.
                    Justification: Used to visualize word distributions, spam vs. ham comparisons, and model performance metrics.

                    ### 7. Gradio (`import gradio as gr`)
                    Objective: Building an interactive web-based demo.
                    Justification: Allows users to test the spam detection system by entering emails and viewing real-time predictions.
                    """)

            with gr.Column():
                gr.Markdown("## 🎉 Thanks & Acknowledgments 🎉")
                gr.Markdown("""
                ### 🙌 Special Thanks to Our Contributors

                **🔹 Remus**
                - Led **Data Collection & Preprocessing**, ensuring a clean dataset for training.
                - Developed the **Baseline Model**, which served as the foundation for further improvements.
                - Fine-tuned **BERT**, optimizing hyperparameters to enhance accuracy.

                **🔹 Ashley**
                - Played a key role in **Data Collection & Preprocessing**, improving dataset quality.
                - Successfully handled the **Deployment on Hugging Face**, making the model accessible to users.
                - Implemented and optimized **DistilBERT**, achieving a balance between speed and performance.

                This project was a collaborative effort, and we appreciate the hard work put into making it a success! 🚀
                """)

    return interface


interface = create_interface()
interface.launch(share=True)
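
# Note: share=True asks Gradio to create a temporary public link in addition to the
# local server; drop it (interface.launch()) if only local access is needed.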