padmanabhbosamia committed
Commit 59124d9 · verified · 1 Parent(s): fb71735

Upload 3 files

Files changed (3)
  1. best_model.pt +3 -0
  2. requirements.txt +8 -0
  3. transformer.py +238 -0
best_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a29447e459ca602f675bd542b2e9c4cc696d6e1737d42d905944e532b8b302b9
+ size 536456766
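
The three lines above are only a Git LFS pointer; the actual checkpoint (536456766 bytes, roughly 512 MB) lives in LFS storage and is fetched with git lfs pull after cloning. A minimal download sketch via the Hub, with "<user>/<repo>" standing in as a hypothetical placeholder for this repository's real id:

from huggingface_hub import hf_hub_download

# "<user>/<repo>" is a placeholder - substitute this repository's actual id
ckpt_path = hf_hub_download(repo_id="<user>/<repo>", filename="best_model.pt")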
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch>=2.0.0
+ transformers>=4.21.0
+ tiktoken>=0.5.1
+ gradio>=3.50.2
+ wandb>=0.15.12
+ tqdm>=4.65.0
+ huggingface-hub>=0.19.4
+ triton>=2.0.0
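
The pinned minimums above install with pip install -r requirements.txt. Of these, only torch and transformers are imported by transformer.py in this commit (transformers solely inside GPT.from_pretrained); tiktoken, gradio, wandb, tqdm, huggingface-hub and triton presumably serve the tokenizer, demo UI, experiment logging and Hub integration in code outside this upload.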
transformer.py ADDED
@@ -0,0 +1,238 @@
+ import math
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from dataclasses import dataclass
+
+ @dataclass
+ class GPTConfig:
+     block_size: int = 1024
+     vocab_size: int = 50257
+     n_layer: int = 12
+     n_head: int = 12
+     n_embd: int = 768
+     dropout: float = 0.1
+     bias: bool = True
+
+ class LayerNorm(nn.Module):
+     def __init__(self, ndim, bias):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(ndim))
+         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+
+     def forward(self, x):
+         return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
+
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.dropout = config.dropout
+         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                                     .view(1, 1, config.block_size, config.block_size))
+
+     def forward(self, x):
+         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+
+         # calculate query, key, values for all heads in batch
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+         att = F.softmax(att, dim=-1)
+         att = self.attn_dropout(att)
+         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+         # output projection
+         y = self.resid_dropout(self.c_proj(y))
+         return y
+
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x):
+         x = F.gelu(self.c_fc(x))
+         x = self.dropout(self.c_proj(x))
+         return x
+
+ class Block(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
+         self.attn = CausalSelfAttention(config)
+         self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+ class GPT(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         assert config.vocab_size is not None
+         assert config.block_size is not None
+         self.config = config
+
+         # Add device attribute
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             drop = nn.Dropout(config.dropout),
+             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f = LayerNorm(config.n_embd, bias=config.bias),
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # init all weights
+         self.apply(self._init_weights)
+         # apply special scaled init to the residual projections, per GPT-2 paper
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight'):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
+         # report number of parameters
+         print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
+
+     def get_num_params(self, non_embedding=True):
+         """
+         Return the number of parameters in the model.
+         For non-embedding count (default), the position embeddings get subtracted.
+         """
+         n_params = sum(p.numel() for p in self.parameters())
+         if non_embedding:
+             n_params -= self.transformer.wpe.weight.numel()
+         return n_params
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def gradient_checkpointing_enable(self):
+         """
+         Enable gradient checkpointing for memory efficiency
+         """
+         self.gradient_checkpointing = True
+
+     def gradient_checkpointing_disable(self):
+         """
+         Disable gradient checkpointing
+         """
+         self.gradient_checkpointing = False
+
+     def forward(self, idx, targets=None):
+         device = idx.device
+         b, t = idx.size()
+         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+         pos = torch.arange(0, t, dtype=torch.long, device=device)
+
+         # forward the GPT model itself
+         tok_emb = self.transformer.wte(idx)
+         pos_emb = self.transformer.wpe(pos)
+         x = self.transformer.drop(tok_emb + pos_emb)
+
+         # Modified forward pass to use gradient checkpointing
+         if hasattr(self, 'gradient_checkpointing') and self.gradient_checkpointing:
+             for block in self.transformer.h:
+                 x = torch.utils.checkpoint.checkpoint(block, x)
+         else:
+             for block in self.transformer.h:
+                 x = block(x)
+
+         x = self.transformer.ln_f(x)
+
+         if targets is not None:
+             logits = self.lm_head(x)
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+         else:
+             logits = self.lm_head(x[:, [-1], :])
+             loss = None
+
+         return logits, loss
+
+     def crop_block_size(self, block_size):
+         # model surgery to decrease the block size if necessary
+         # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
+         # but want to use a smaller block size for training
+         assert block_size <= self.config.block_size
+         self.config.block_size = block_size
+         self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
+         for block in self.transformer.h:
+             if hasattr(block.attn, 'bias'):
+                 block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
+
+     @classmethod
+     def from_pretrained(cls, model_type):
+         """
+         Initialize a pretrained GPT model by copying over the weights
+         from a huggingface/transformers checkpoint.
+         """
+         assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
+         from transformers import GPT2LMHeadModel
+
+         # create a from-scratch initialized minGPT model
+         config = GPTConfig()
+         config.block_size = 1024 # always use block size 1024 for GPT2 models
+
+         # update config based on model type
+         if model_type == 'gpt2':
+             config.n_layer = 12; config.n_head = 12; config.n_embd = 768
+         elif model_type == 'gpt2-medium':
+             config.n_layer = 24; config.n_head = 16; config.n_embd = 1024
+         elif model_type == 'gpt2-large':
+             config.n_layer = 36; config.n_head = 20; config.n_embd = 1280
+         elif model_type == 'gpt2-xl':
+             config.n_layer = 48; config.n_head = 25; config.n_embd = 1600
+
+         # create the model
+         model = GPT(config)
+         sd = model.state_dict()
+
+         # init a huggingface/transformers model
+         model_hf = GPT2LMHeadModel.from_pretrained(model_type)
+         sd_hf = model_hf.state_dict()
+
+         # copy while ensuring all of the parameters are aligned and match in names and shapes
+         keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these
+         transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
+
+         for k in keys:
+             if any(k.endswith(w) for w in transposed):
+                 # special treatment for the Conv1D weights we need to transpose
+                 assert sd_hf[k].shape[::-1] == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k].t())
+             else:
+                 # vanilla copy over the other parameters
+                 assert sd_hf[k].shape == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k])
+
+         return model
+
+     def to(self, device):
+         """Override to method to also update device attribute"""
+         self.device = device
+         return super().to(device)
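
A minimal usage sketch for the module above, assuming best_model.pt holds a plain state_dict matching the default GPTConfig; if the checkpoint nests the weights (e.g. under a 'model' key), the load line would change accordingly:

import torch
from transformer import GPTConfig, GPT

config = GPTConfig()                       # defaults: 12 layers, 12 heads, 768-dim, vocab 50257
model = GPT(config)                        # prints the parameter count on construction
state_dict = torch.load("best_model.pt", map_location="cpu")  # assumption: raw state_dict
model.load_state_dict(state_dict)
model.eval()

idx = torch.randint(0, config.vocab_size, (1, 16))   # dummy batch of 16 token ids
with torch.no_grad():
    logits, loss = model(idx)              # loss is None when no targets are passed
print(logits.shape)                        # torch.Size([1, 1, 50257]): only the last position at inference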