leygit committed
Commit 563702e · verified · 1 parent: 116aa9b

Create app.py

Files changed (1): app.py (+219, -0)
app.py ADDED
@@ -0,0 +1,219 @@
+ # DistilBERT run 3; added weight_decay=0.01 (see the training sketch below)
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
+ from sklearn.metrics import classification_report
+
+ # Load dataset
+ file_path = 'spam_ham_dataset.csv'
+ df = pd.read_csv(file_path)
+
+ # Convert labels to numeric (ham -> 0, spam -> 1)
+ df['label_num'] = df['label'].map({'ham': 0, 'spam': 1})
+
+ # Load tokenizer
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
+
+ # Tokenize dataset
+ encodings = tokenizer(df['text'].tolist(), padding=True, truncation=True, max_length=128, return_tensors="pt")
+ labels = torch.tensor(df['label_num'].values)
+
+ # Custom Dataset
+ class SpamDataset(Dataset):
+     def __init__(self, encodings, labels):
+         self.encodings = encodings
+         self.labels = labels
+
+     def __len__(self):
+         return len(self.labels)
+
+     def __getitem__(self, idx):
+         item = {key: val[idx] for key, val in self.encodings.items()}
+         # labels is already a tensor; re-wrapping it with torch.tensor() raises a UserWarning
+         item['labels'] = self.labels[idx].long()
+         return item
+
+ # Create dataset
+ dataset = SpamDataset(encodings, labels)
+
+ # Split dataset (80% train, 20% validation)
+ train_size = int(0.8 * len(dataset))
+ val_size = len(dataset) - train_size
+ train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
+
+ # DataLoader with batch size
+ def collate_fn(batch):
+     keys = batch[0].keys()
+     return {key: torch.stack([b[key] for b in batch]) for key in keys}
+
+ train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
+ val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)
+
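+ # --- Hypothetical training sketch (not in the original commit) ---
+ # The header comment mentions weight_decay=0.01, but this file only loads a
+ # saved checkpoint. A loop like the one below, using AdamW with
+ # weight_decay=0.01, is presumably what produced distilbert_spam_model.pt.
+ def train_model(num_epochs=3, lr=5e-5):
+     train_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     net = DistilBertForSequenceClassification.from_pretrained(
+         "distilbert-base-uncased", num_labels=2
+     ).to(train_device)
+     optimizer = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=0.01)
+     net.train()
+     for epoch in range(num_epochs):
+         for batch in train_loader:
+             batch = {key: val.to(train_device) for key, val in batch.items()}
+             outputs = net(**batch)  # batch includes 'labels', so outputs.loss is set
+             optimizer.zero_grad()
+             outputs.loss.backward()
+             optimizer.step()
+     torch.save(net.state_dict(), "distilbert_spam_model.pt")
+     return net
+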
+ # Load the trained model
+ def load_model(model_path="distilbert_spam_model.pt"):
+     model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
+     model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))  # Load fine-tuned weights
+     model.eval()  # Set model to evaluation mode
+     return model
+
+ # Select device and instantiate the model (the original used `model` and
+ # `device` here without ever defining them; load_model() already sets eval mode)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = load_model().to(device)
+
+ # Evaluation
+ correct = 0
+ total = 0
+ with torch.no_grad():
+     for batch in val_loader:
+         inputs = {key: val.to(device) for key, val in batch.items()}
+         batch_labels = inputs.pop("labels")
+
+         outputs = model(**inputs)
+         predictions = torch.argmax(outputs.logits, dim=1)
+         correct += (predictions == batch_labels).sum().item()
+         total += batch_labels.size(0)
+
+ accuracy = correct / total
+ print(f"Validation Accuracy: {accuracy:.4f}")
+
+ # Classification function
+ def classify_email(email_text):
+     model.eval()  # Ensure evaluation mode
+
+     with torch.no_grad():
+         # Tokenize and convert input text to tensors
+         # (note: max_length=256 here vs. 128 at training time, so long emails are truncated later than during training)
+         inputs = tokenizer(email_text, padding=True, truncation=True, max_length=256, return_tensors="pt")
+
+         # Move inputs to the appropriate device
+         inputs = {key: val.to(device) for key, val in inputs.items()}
+
+         # Get model predictions
+         outputs = model(**inputs)
+         logits = outputs.logits
+
+         # Convert logits to predicted class
+         predictions = torch.argmax(logits, dim=1)
+
+         # Convert logits to probabilities using softmax
+         probs = F.softmax(logits, dim=1)
+         confidence = torch.max(probs).item() * 100  # Convert to percentage
+
+     # Convert numeric prediction to label
+     result = "Spam" if predictions.item() == 1 else "Ham"
+
+     return {
+         "result": result,
+         "confidence": f"{confidence:.2f}%",
+     }
+
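+ # Example usage (illustrative):
+ #   classify_email("Congratulations, you won a free prize! Click here.")
+ #   -> {"result": "Spam", "confidence": "..."}  # values depend on the trained weights
+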
+ # Evaluation function with detailed classification report
+ def evaluate_model_with_report(val_loader):
+     model.eval()  # Set model to evaluation mode
+     y_true = []
+     y_pred = []
+     correct = 0
+     total = 0
+
+     with torch.no_grad():
+         for batch in val_loader:
+             inputs = {key: val.to(device) for key, val in batch.items()}
+             batch_labels = inputs.pop("labels")
+
+             outputs = model(**inputs)
+             predictions = torch.argmax(outputs.logits, dim=1)
+
+             # Collect labels and predictions
+             y_true.extend(batch_labels.cpu().numpy())
+             y_pred.extend(predictions.cpu().numpy())
+
+             # Accumulate counts for accuracy
+             correct += (predictions == batch_labels).sum().item()
+             total += batch_labels.size(0)
+
+     # Calculate accuracy
+     accuracy = correct / total if total > 0 else 0
+     print(f"Validation Accuracy: {accuracy:.4f}")
+
+     # Print classification report
+     print("\nClassification Report:")
+     print(classification_report(y_true, y_pred, target_names=["Ham", "Spam"]))
+
+     return accuracy
+
+ # Run evaluation with classification report
+ accuracy = evaluate_model_with_report(val_loader)
+ print(f"Model Validation Accuracy: {accuracy:.4f}")
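+
+ # --- Hypothetical helper (not in the original commit) ---
+ # create_interface() below calls generate_performance_metrics(), which this
+ # file never defines. A minimal sketch under that assumption: re-run the model
+ # over val_loader, compute the metrics the UI expects, and render the
+ # confusion matrix as a base64-encoded PNG for the gr.HTML image tag.
+ import base64
+ import io
+ import matplotlib
+ matplotlib.use("Agg")  # headless backend, suitable for a server
+ import matplotlib.pyplot as plt
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, ConfusionMatrixDisplay
+
+ def generate_performance_metrics():
+     y_true, y_pred = [], []
+     with torch.no_grad():
+         for batch in val_loader:
+             inputs = {key: val.to(device) for key, val in batch.items()}
+             batch_labels = inputs.pop("labels")
+             preds = torch.argmax(model(**inputs).logits, dim=1)
+             y_true.extend(batch_labels.cpu().numpy())
+             y_pred.extend(preds.cpu().numpy())
+
+     # Render the confusion matrix to a base64-encoded PNG
+     fig, ax = plt.subplots(figsize=(4, 4))
+     ConfusionMatrixDisplay.from_predictions(y_true, y_pred, display_labels=["Ham", "Spam"], ax=ax)
+     buf = io.BytesIO()
+     fig.savefig(buf, format="png", bbox_inches="tight")
+     plt.close(fig)
+     plot_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
+
+     return {
+         "accuracy": f"{accuracy_score(y_true, y_pred):.4f}",
+         "precision": f"{precision_score(y_true, y_pred):.4f}",
+         "recall": f"{recall_score(y_true, y_pred):.4f}",
+         "f1_score": f"{f1_score(y_true, y_pred):.4f}",
+         "confusion_matrix_plot": plot_b64,
+     }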
+
+ ## Gradio Interface
+
+ import gradio as gr
+
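+ # --- Hypothetical helpers (not in the original commit) ---
+ # create_interface() also references custom_css and email_analysis_pipeline,
+ # neither of which appears in this file. Minimal sketches under those
+ # assumptions: a small stylesheet for the metric boxes, and a wrapper that
+ # adapts classify_email()'s dict to the three Gradio outputs.
+ custom_css = """
+ .metric textarea { font-weight: bold; text-align: center; }
+ """
+
+ def email_analysis_pipeline(email_text):
+     prediction = classify_email(email_text)
+     color = "#d9534f" if prediction["result"] == "Spam" else "#5cb85c"
+     result_html = f"<h3 style='color: {color};'>{prediction['result']}</h3>"
+     # The third output reuses the overall validation accuracy computed above
+     return result_html, prediction["confidence"], f"{accuracy:.4f}"
+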
+ # Create Gradio Interface
+ def create_interface():
+     performance_metrics = generate_performance_metrics()
+
+     # Introduction - Title + Brief Description
+     with gr.Blocks(css=custom_css) as interface:
+         gr.Markdown("# Spam Email Classification")
+         gr.Markdown(
+             """
+             Paste an email below and a fine-tuned DistilBERT model will classify it
+             as Spam or Ham, along with a confidence score.
+             """
+         )
+
+         # Email Text Input
+         with gr.Row():
+             email_input = gr.Textbox(
+                 lines=8, placeholder="Type or paste your email content here...", label="Email Content"
+             )
+
+         # Email Text Results and Analysis
+         with gr.Row():
+             result_output = gr.HTML(label="Classification Result")  # filled with HTML from email_analysis_pipeline
+             confidence_output = gr.Textbox(label="Confidence Score", interactive=False)
+             accuracy_output = gr.Textbox(label="Accuracy", interactive=False)
+
+         analyze_button = gr.Button("Analyze Email 🕵️‍♂️")
+
+         analyze_button.click(
+             fn=email_analysis_pipeline,
+             inputs=email_input,
+             outputs=[result_output, confidence_output, accuracy_output]
+         )
+
+         # Analysis
+         gr.Markdown("## 📊 Model Performance Analytics")
+         with gr.Row():
+             with gr.Column():
+                 gr.Textbox(value=performance_metrics["accuracy"], label="Accuracy", interactive=False, elem_classes=["metric"])
+                 gr.Textbox(value=performance_metrics["precision"], label="Precision", interactive=False, elem_classes=["metric"])
+                 gr.Textbox(value=performance_metrics["recall"], label="Recall", interactive=False, elem_classes=["metric"])
+                 gr.Textbox(value=performance_metrics["f1_score"], label="F1 Score", interactive=False, elem_classes=["metric"])
+             with gr.Column():
+                 gr.Markdown("### Confusion Matrix")
+                 gr.HTML(f"<img src='data:image/png;base64,{performance_metrics['confusion_matrix_plot']}' style='max-width: 100%; height: auto;' />")
+
+         gr.Markdown("## 📘 Glossary and Explanation of Labels")
+         gr.Markdown(
+             """
+             ### Labels:
+             - **Spam:** Unwanted or harmful emails flagged by the system.
+             - **Ham:** Legitimate, safe emails.
+             ### Metrics:
+             - **Accuracy:** The percentage of correct classifications.
+             - **Precision:** Of the emails predicted as Spam, how many are actually Spam.
+             - **Recall:** Of all actual Spam emails, how many are predicted as Spam.
+             - **F1 Score:** Harmonic mean of Precision and Recall.
+             """
+         )
+
+     return interface
+
+ # Launch the interface
+ interface = create_interface()
+ interface.launch(share=True)
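+
+ # Dependencies assumed by this script (versions are not pinned in the commit):
+ # pandas, torch, transformers, scikit-learn, matplotlib, gradio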