leygit committed
Commit c907c32 · verified · 1 Parent(s): 933a235

Upload distilbert.py

Files changed (1):
  1. distilbert.py +117 -0
distilbert.py ADDED
@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
"""DistilBERT.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1qXwFT-lCqgfmQYxeJ7cb-iuvTLqLkiim
"""

# DISTILBERT RUN 3, added weight_decay=0.01
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from sklearn.metrics import classification_report

# Load dataset
file_path = 'spam_ham_dataset.csv'
df = pd.read_csv(file_path)

# Convert labels to numeric (ham -> 0, spam -> 1)
df['label_num'] = df['label'].map({'ham': 0, 'spam': 1})

# Load tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Tokenize dataset up front; pad/truncate every message to 128 tokens
encodings = tokenizer(df['text'].tolist(), padding=True, truncation=True, max_length=128, return_tensors="pt")
labels = torch.tensor(df['label_num'].values)

# Custom Dataset over the pre-tokenized encodings
class SpamDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        # labels is already a tensor; clone/detach instead of re-wrapping it
        # with torch.tensor(), which raises a UserWarning
        item['labels'] = self.labels[idx].clone().detach().long()
        return item

# Create dataset
dataset = SpamDataset(encodings, labels)

# Split dataset (80% train, 20% validation)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
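
# (Note) random_split is unseeded here, so the 80/20 split differs across
# runs. A reproducible variant (a sketch, not in the committed file) would be:
#   generator = torch.Generator().manual_seed(42)
#   train_dataset, val_dataset = torch.utils.data.random_split(
#       dataset, [train_size, val_size], generator=generator)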

# Collate: stack the per-sample tensors into batch tensors
def collate_fn(batch):
    keys = batch[0].keys()
    return {key: torch.stack([b[key] for b in batch]) for key in keys}

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

# Load DistilBERT model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
model.to(device)

# Define optimizer and loss function
optimizer = optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
loss_fn = nn.CrossEntropyLoss()

# Training loop
EPOCHS = 10
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for batch in train_loader:
        optimizer.zero_grad()

        inputs = {key: val.to(device) for key, val in batch.items()}
        labels = inputs.pop("labels")  # already moved to device above

        outputs = model(**inputs)
        loss = loss_fn(outputs.logits, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

# Save trained model weights
torch.save(model.state_dict(), "distilbert_spam_model.pt")

# Evaluation on the validation split
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in val_loader:
        inputs = {key: val.to(device) for key, val in batch.items()}
        labels = inputs.pop("labels")  # already moved to device above

        outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"Validation Accuracy: {accuracy:.4f}")