# Character-Level GPT Model Trained on WikiText-2
This repository contains a character-level GPT model trained on the WikiText-2 dataset. The model architecture is a custom implementation of a transformer-based language model with Rotary Positional Embeddings (RoPE) and SwiGLU feed-forward networks.
This was largely AI-generated: essentially all of the code used to build the model, and a large amount of this description, were produced with AI assistance.
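For readers unfamiliar with the two architectural choices named above, here is the standard formulation they refer to (general background, not anything specific to this repository). SwiGLU replaces the usual two-layer MLP with a gated variant, and RoPE encodes positions by rotating each query/key channel pair by a position-dependent angle:

$$
\mathrm{SwiGLU}(x) = \big(\mathrm{SiLU}(x W_g) \odot x W_v\big)\, W_o,
\qquad \mathrm{SiLU}(z) = z \cdot \sigma(z)
$$

$$
\theta_j = 10000^{-2j/d},
\qquad
(q'_{2j} + i\,q'_{2j+1}) = (q_{2j} + i\,q_{2j+1})\, e^{\,i m \theta_j}
$$

where $m$ is the token position and $d$ the per-head dimension. In the inference code below, `w3`/`w1`/`w2` play the roles of $W_g$/$W_v$/$W_o$, and `precompute_freqs_cis` / `apply_rotary_emb` implement the rotation via complex multiplication.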
## Model Description
- Model Type: Custom Transformer-based Causal Language Model
- Training Data: WikiText-2
- Tokenization: Character-level
- Architecture Details:
  - Rotary Positional Embeddings (RoPE)
  - SwiGLU Feed-Forward Networks
- Parameters (a rough size estimate follows this list):
  - `n_layer`: 8
  - `n_head`: 8
  - `n_embd`: 512
  - `block_size`: 512
  - `dropout`: 0.1
  - `vocab_size`: 283
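For reference, these hyperparameters are consistent with the ~25M-parameter figure in the repository name. The sketch below is a rough back-of-the-envelope estimate that mirrors the layer shapes defined in the inference code further down; it is not output from the training script.

```python
# Rough parameter-count estimate from the hyperparameters above (approximate).
n_layer, n_head, n_embd, vocab_size = 8, 8, 512, 283
head_size = n_embd // n_head
hidden_dim = int(4 * n_embd * 2 / 3)          # SwiGLU hidden width used in the code below

attn = 3 * n_embd * head_size * n_head        # K/Q/V projections (no bias)
attn += n_embd * n_embd + n_embd              # output projection (with bias)
ffn = 3 * n_embd * hidden_dim                 # w1, w2, w3 of the SwiGLU block
norms = 2 * 2 * n_embd                        # two LayerNorms per block
per_block = attn + ffn + norms

embeddings = vocab_size * n_embd              # token embedding table
lm_head = n_embd * vocab_size                 # untied output head (no bias)
final_norm = 2 * n_embd

total = n_layer * per_block + embeddings + lm_head + final_norm
print(f"~{total / 1e6:.1f}M parameters")      # roughly 25M
```

Character-level tokenization keeps the embedding and output matrices tiny (283 × 512), so nearly all of the parameter budget sits in the transformer blocks.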
## Intended Use
This model is intended for research purposes, including:
- Experimenting with character-level language modeling (a minimal tokenization sketch follows this list).
- Studying the effects of different training techniques on transformer models.
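To make the character-level setup concrete, here is a minimal sketch of how a vocabulary like the 283-entry `vocab.json` could be built from raw text. This is illustrative only; the actual preprocessing used for training is not included in this card, so the file path and sample usage here are assumptions.

```python
# Illustrative only: building a character-level vocabulary from raw text.
# The shipped vocab.json was produced by the (unpublished) training script.
text = open("wikitext-2-train.txt", encoding="utf-8").read()  # hypothetical path

chars = sorted(set(text))                     # every distinct character is a token
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> integer id
itos = {i: ch for ch, i in stoi.items()}      # integer id -> char

encode = lambda s: [stoi[c] for c in s]
decode = lambda ids: "".join(itos[i] for i in ids)

ids = encode("AI")
print(len(chars), ids, decode(ids))           # one id per character
```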
## Limitations
- Character-Level Tokenization: The model uses character-level tokenization, which is less efficient than subword tokenization (e.g., Byte-Pair Encoding) for capturing long-range dependencies and generating coherent text. As a result, the quality of generated text may be limited compared to models using subword tokenization.
- Limited Training Data: The model was trained on the WikiText-2 dataset, which is relatively small. Training on a larger dataset would likely improve performance.
- Custom Architecture: This is a custom model implementation, not a standard pre-trained model from the `transformers` library.
- Requires Manual Setup: Loading and using this model requires manually defining the architecture (as in the inference code below) and loading the weights yourself; the `AutoModelForCausalLM` class from `transformers` cannot be used.
## Inference Code

```python
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import json
import os
from huggingface_hub import hf_hub_download
# --- Configuration ---
repo_id = "Ma7ee7/WikiGPT-25M"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# --- Download Necessary Files ---
print(f"Downloading files from {repo_id}...")
try:
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
    vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.json")
    # tokenizer_config_path = hf_hub_download(repo_id=repo_id, filename="tokenizer_config.json")  # Optional
    print("Files downloaded successfully.")
except Exception as e:
    print(f"Error downloading files: {e}")
    print("Please ensure the repository ID is correct and the files exist.")
    exit()
# --- Load Configuration ---
print("Loading configuration...")
try:
    with open(config_path, 'r') as f:
        config = json.load(f)
    print("Configuration loaded:")
    print(config)
    # Extract necessary hyperparameters from config
    vocab_size = config["vocab_size"]
    n_layer = config["n_layer"]
    n_head = config["n_head"]
    n_embd = config["n_embd"]
    block_size = config["block_size"]
    dropout = config["dropout"]
except Exception as e:
    print(f"Error loading config.json: {e}")
    exit()
# --- RoPE Helper Functions ---
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, device='cpu'):
    freqs_part = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    freqs = 1.0 / (theta ** freqs_part).to(device)
    t = torch.arange(end, device=device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    if not freqs_cis.shape == (x.shape[1], x.shape[-1]):  # Check dimensions used
        raise ValueError(f"Freqs shape {freqs_cis.shape} does not match x shape {x.shape} at dims 1 and -1")
    shape = [1] * (ndim - 2) + list(freqs_cis.shape)
    return freqs_cis.view(*shape)
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out_complex = xq_ * freqs_cis
    xk_out_complex = xk_ * freqs_cis
    xq_out = torch.view_as_real(xq_out_complex).flatten(start_dim=2)
    xk_out = torch.view_as_real(xk_out_complex).flatten(start_dim=2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
# --- End RoPE Helpers ---
class Head(nn.Module):
""" one head of self-attention with RoPE """
def __init__(self, head_size):
super().__init__()
# Use parameters from loaded config directly if possible, or pass them
self.key = nn.Linear(n_embd, head_size, bias=False)
self.query = nn.Linear(n_embd, head_size, bias=False)
self.value = nn.Linear(n_embd, head_size, bias=False)
# Register buffer requires size, use block_size from config
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
self.dropout = nn.Dropout(dropout) # Use dropout from config
def forward(self, x, freqs_cis):
B, T, C = x.shape
k = self.key(x)
q = self.query(x)
q_rope, k_rope = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
head_size = q_rope.shape[-1]
# Use tensor for scale
scale = torch.tensor(head_size ** -0.5, device=q_rope.device, dtype=q_rope.dtype)
wei = (q_rope @ k_rope.transpose(-2, -1)) * scale
mask = self.tril[:T, :T] == 0
wei = wei.masked_fill(mask.to(wei.device), float('-inf')) # Ensure mask on correct device
wei = F.softmax(wei, dim=-1)
wei = self.dropout(wei)
v = self.value(x)
out = wei @ v
return out
class MultiHeadAttention(nn.Module):
""" multiple heads of self-attention in parallel with RoPE """
def __init__(self, num_heads, head_size):
super().__init__()
self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
self.proj = nn.Linear(n_embd, n_embd) # head_size * num_heads = n_embd
self.dropout = nn.Dropout(dropout) # Use dropout from config
def forward(self, x, freqs_cis):
head_outputs = [h(x, freqs_cis) for h in self.heads]
out = torch.cat(head_outputs, dim=-1)
out = self.dropout(self.proj(out))
return out
class SwiGLU(nn.Module):
""" SwiGLU Feed-Forward Network """
def __init__(self, n_embd, hidden_dim=None, dropout=0.1): # Allow dropout override
super().__init__()
if hidden_dim is None:
hidden_dim = int(4 * n_embd * 2 / 3)
self.w1 = nn.Linear(n_embd, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, n_embd, bias=False)
self.w3 = nn.Linear(n_embd, hidden_dim, bias=False)
self.dropout = nn.Dropout(dropout) # Use passed dropout
def forward(self, x):
gate = self.w3(x)
value = self.w1(x)
swish_gate = F.silu(gate)
out = swish_gate * value
out = self.dropout(self.w2(out))
return out
class Block(nn.Module):
""" Transformer block using RoPE attention and SwiGLU FFN """
def __init__(self, n_embd, n_head):
super().__init__()
head_size = n_embd // n_head
self.sa = MultiHeadAttention(n_head, head_size)
# Pass dropout rate from config to SwiGLU
self.ffwd = SwiGLU(n_embd, dropout=dropout)
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
def forward(self, x, freqs_cis):
x = x + self.sa(self.ln1(x), freqs_cis)
x = x + self.ffwd(self.ln2(x))
return x
class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.block_size = block_size  # Store block_size
        # Precompute RoPE frequencies using parameters
        head_dim = n_embd // n_head
        # Pass device explicitly during precomputation if model is moved later
        freqs_cis_buffer = precompute_freqs_cis(head_dim, block_size * 2, device='cpu')  # Compute on CPU first
        self.register_buffer("freqs_cis", freqs_cis_buffer, persistent=False)  # Register but don't save in state_dict
        # No weight initialization needed here for inference

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # Ensure T doesn't exceed block size for freqs_cis slicing
        # In generate, idx_cond handles this, but good check here too
        T_used = min(T, self.block_size)
        tok_emb = self.token_embedding_table(idx[:, -T_used:])  # Use only last block_size tokens if T > block_size
        # Retrieve precomputed RoPE frequencies for the actual sequence length T_used
        # Move required part of freqs_cis to the same device as embeddings
        freqs_cis_for_block = self.freqs_cis[:T_used].to(tok_emb.device)
        x = tok_emb
        for block in self.blocks:
            x = block(x, freqs_cis_for_block)
        x = self.ln_f(x)
        if targets is not None:
            # This path isn't typically used during inference
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            # Inference: compute logits only for the last token
            logits = self.lm_head(x[:, [-1], :])  # (B, 1, vocab_size)
            loss = None
        return logits, loss

    @torch.no_grad()  # Ensure no gradients are computed during generation
    def generate(self, idx, max_new_tokens):
        self.eval()  # Ensure model is in eval mode
        for _ in range(max_new_tokens):
            # Crop context if it exceeds block size *before* forward pass
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            # Forward pass for inference (gets logits for the last token)
            logits, _ = self(idx_cond)  # Call forward with targets=None
            logits = logits[:, -1, :]  # Shape (B, vocab_size)
            # Softmax and sampling
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append
            idx = torch.cat((idx, idx_next), dim=1)
        self.train()  # Optional: set back to train mode if needed elsewhere
        return idx
# --- End Model Definitions ---
# --- Instantiate Model ---
print("Instantiating model...")
try:
    model = GPTLanguageModel(
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_head=n_head,
        n_layer=n_layer,
        block_size=block_size,
        dropout=dropout
    )
    print("Model instantiated.")
except Exception as e:
    print(f"Error instantiating model: {e}")
    print("Ensure the class definitions above match the configuration.")
    exit()
# --- Load Weights ---
print("Loading model weights...")
try:
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))  # Load to CPU first
    # Adapt state_dict if necessary (e.g., if module names changed)
    # Example: remove unexpected keys like 'freqs_cis' if they were accidentally saved
    state_dict.pop("freqs_cis", None)
    load_result = model.load_state_dict(state_dict, strict=True)  # Use strict=True initially
    print(f"Weight loading result: {load_result}")
    print("Model weights loaded successfully.")
except Exception as e:
    print(f"Error loading weights: {e}")
    print("Ensure the model architecture definition matches the saved weights.")
    # If using strict=False, check missing/unexpected keys printed by load_state_dict
    exit()
# --- Setup for Inference ---
model.eval() # Set to evaluation mode (disable dropout etc.)
model.to(device) # Move model to target device
print(f"Model moved to {device} and set to evaluation mode.")
# --- Load Vocabulary and Define Encode/Decode ---
print("Loading vocabulary...")
try:
    with open(vocab_path, 'r', encoding='utf-8') as f:
        stoi = json.load(f)
    itos = {i: ch for ch, i in stoi.items()}
    print(f"Vocabulary loaded ({len(stoi)} chars).")
    # Define encoding/decoding functions based on loaded vocab
    encode = lambda s: [stoi.get(c, stoi.get('\n')) for c in s]  # Use \n as fallback? Or a dedicated UNK?
    decode = lambda l: ''.join([itos.get(i, '?') for i in l])  # Use ? for unknown indices
except Exception as e:
    print(f"Error loading vocabulary: {e}")
    exit()
# --- Inference ---
print("\n--- Starting Inference ---")
prompt = " = = Artificial Intelligence in Medicine = = \n\n Artificial intelligence ( AI ) has" # Example prompt
max_tokens_to_generate = 500
print(f"Prompt:\n{prompt}")
print(f"\nGenerating ({max_tokens_to_generate} tokens)...")
# Encode the prompt
encoded_prompt = torch.tensor(encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
# Generate text
# Ensure generate method uses torch.no_grad() internally
generated_ids = model.generate(encoded_prompt, max_new_tokens=max_tokens_to_generate)
generated_text = decode(generated_ids[0].tolist()) # Decode the generated indices
print("\n--- Generated Text ---")
print(generated_text)
print("\n----------------------")
This model was trained in only about 1 hour and 30 minutes, so it has learned basic word connections (in English), but that's about it.