# WikiGPT-25M / modeling_custom_char_gpt.py
# Uploaded by Ma7ee7: "Upload character-level GPT model for WikiText-2" (commit 5bef0ce, verified).
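# ACTION REQUIRED: Manually copy the model class definitions into this file
# (see the "Add ... Class Definition Below" placeholders further down).
# See the warning message in the notebook output for details.
# Imports required by the model code (torch, torch.nn, etc.):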
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
# --- RoPE Helper Functions (Copied for model file) ---
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, device='cpu'):
    """Precompute the complex rotary-embedding (RoPE) rotation factors.

    Entry ``[p, i]`` of the result is ``exp(1j * p * theta ** (-2 * i / dim))``,
    i.e. a unit-magnitude complex number whose angle grows linearly with the
    position ``p`` at a per-channel-pair frequency.

    Args:
        dim: Feature dimension that will be rotated; ``dim // 2`` frequencies
            are produced (one per pair of channels).
        end: Number of positions to precompute (``0 .. end - 1``).
        theta: Base of the geometric frequency progression.
        device: Device on which the returned tensor is created.

    Returns:
        Complex tensor of shape ``(end, dim // 2)``.
    """
    # Exponents 0/dim, 2/dim, 4/dim, ...; the slice keeps exactly dim // 2
    # entries even when ``dim`` is odd. Everything is created directly on the
    # target device so no host-to-device copy is needed.
    exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32)[: dim // 2] / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=device, dtype=torch.float32)
    # angles[p, i] = p * inv_freq[i]
    angles = torch.outer(positions, inv_freq)
    # Unit magnitude + angle -> cos(angle) + 1j * sin(angle).
    return torch.polar(torch.ones_like(angles), angles)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``x`` is expected to carry the sequence axis at dim 1 and the (complex)
    feature axis at dim -1; every other axis of the returned view has size 1.
    For a 3-D ``x`` of shape ``(B, T, F)`` the result is ``(1, T, F)``; for a
    4-D ``x`` of shape ``(B, T, H, F)`` it is ``(1, T, 1, F)``.

    Args:
        freqs_cis: Tensor of shape ``(x.shape[1], x.shape[-1])``.
        x: Tensor with at least 2 dimensions.

    Returns:
        A view of ``freqs_cis`` with ``max(x.ndim, 2)`` dimensions.

    Raises:
        ValueError: If ``x`` has fewer than 2 dimensions or the shape of
            ``freqs_cis`` does not match dims 1 and -1 of ``x``.
    """
    ndim = x.ndim
    if ndim < 2:
        raise ValueError(f"x must have at least 2 dimensions, got shape {x.shape}")
    if freqs_cis.shape != (x.shape[1], x.shape[-1]):  # Check dimensions used
        raise ValueError(f"Freqs shape {freqs_cis.shape} does not match x shape {x.shape} at dims 1 and -1")
    seq_len, feat = freqs_cis.shape
    # Keep the sequence length on axis 1 (where the check above found it) and
    # the features on the last axis; any axes in between (e.g. heads) get
    # size 1. For ndim == 2 there is no leading batch axis to pad.
    leading = [1] if ndim > 2 else []
    shape = leading + [seq_len] + [1] * (ndim - 3) + [feat]
    return freqs_cis.view(*shape)


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Apply rotary position embeddings to query and key tensors.

    Consecutive pairs along the last axis of ``xq`` / ``xk`` are treated as
    (real, imaginary) parts of a complex number and multiplied by the
    matching entry of ``freqs_cis``. The sequence axis must be dim 1, so both
    ``(B, T, D)`` and ``(B, T, H, D)`` layouts are supported.

    Args:
        xq: Query tensor; last dimension must be even.
        xk: Key tensor; last dimension must be even.
        freqs_cis: Complex tensor of shape ``(T, D // 2)`` where ``T`` is
            ``xq.shape[1]`` and ``D`` is ``xq.shape[-1]``.

    Returns:
        Tuple ``(xq_out, xk_out)`` with the same shapes and dtypes as the
        corresponding inputs.

    Raises:
        ValueError: Propagated from ``reshape_for_broadcast`` when
            ``freqs_cis`` does not match the query layout.
    """
    # The complex view needs float32 and a trailing axis of size 2.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Rotate, then merge the trailing (pair, 2) axes back into the feature
    # axis. Flattening from -2 (rather than a fixed axis 2) keeps any axes
    # between sequence and features intact.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(start_dim=-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(start_dim=-2)
    # Cast back in case the inputs were half precision.
    return xq_out.type_as(xq), xk_out.type_as(xk)
# --- End RoPE Helpers ---
# --- Add Head Class Definition Below ---
# --- Add MultiHeadAttention Class Definition Below ---
# --- Add SwiGLU Class Definition Below ---
# --- Add Block Class Definition Below ---
# --- Add GPTLanguageModel Class Definition Below ---