---
language: en
license: apache-2.0
tags:
- character-level
- causal-lm
- gpt
- wikitext
- custom-model
pipeline_tag: text-generation
---

# Character-Level GPT Model Trained on WikiText-2

This repository contains a character-level GPT model trained on the WikiText-2 dataset. The model architecture is a custom implementation of a transformer-based language model with Rotary Positional Embeddings (RoPE) and SwiGLU feed-forward networks.

This was largely AI-generated: all of the code used to build the model, as well as a large amount of this description.

## Model Description

* **Model Type:** Custom Transformer-based Causal Language Model
* **Training Data:** WikiText-2
* **Tokenization:** Character-level
* **Architecture Details:**
    * Rotary Positional Embeddings (RoPE)
    * SwiGLU Feed-Forward Networks
* **Parameters:**
    * `n_layer`: 8
    * `n_head`: 8
    * `n_embd`: 512
    * `block_size`: 512
    * `dropout`: 0.1
    * `vocab_size`: 283

## Intended Use

This model is intended for research purposes, including:

* Experimenting with character-level language modeling.
* Studying the effects of different training techniques on transformer models.

## Limitations

* **Character-Level Tokenization:** The model uses character-level tokenization, which is less efficient than subword tokenization (e.g., Byte-Pair Encoding) at capturing long-range dependencies and generating coherent text. As a result, the quality of generated text may be limited compared to models using subword tokenization (see the sketch after this list).
* **Limited Training Data:** The model was trained on the WikiText-2 dataset, which is relatively small. Training on a larger dataset would likely improve performance.
* **Custom Architecture:** This is a custom model implementation, not a standard pre-trained model from the `transformers` library.
* **Requires Manual Intervention:** Loading and using this model requires manual intervention and a deeper understanding of the architecture. The `AutoModelForCausalLM` class from `transformers` *cannot* be used; the self-contained inference script below defines the architecture and loads the weights directly.
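To make the tokenization limitation concrete, here is a minimal sketch of character-level encoding and decoding. It assumes a `vocab.json` that maps single characters to integer ids (the same file and format the inference script below loads); the snippet itself is illustrative and not part of the released code.

```python
import json

# Illustrative only: character-level encoding/decoding with a char-to-id vocabulary.
# Assumes vocab.json maps each character to an integer id, e.g. {"a": 4, "b": 5, ...}.
with open("vocab.json", "r", encoding="utf-8") as f:
    stoi = json.load(f)                       # character -> id
itos = {i: ch for ch, i in stoi.items()}      # id -> character

def encode(text):
    # One token per character; unknown characters fall back to the newline id here.
    return [stoi.get(ch, stoi.get("\n")) for ch in text]

def decode(ids):
    return "".join(itos.get(i, "?") for i in ids)

sample = "Artificial intelligence"
print(len(sample), "characters ->", len(encode(sample)), "tokens")
# A subword tokenizer would cover the same phrase in a handful of tokens,
# which is why character-level models see a much shorter effective context.
```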
## Inference Code

```python
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import json
import os
from huggingface_hub import hf_hub_download

# --- Configuration ---
repo_id = "Ma7ee7/WikiGPT-25M"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# --- Download Necessary Files ---
print(f"Downloading files from {repo_id}...")
try:
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
    vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.json")
    # tokenizer_config_path = hf_hub_download(repo_id=repo_id, filename="tokenizer_config.json")  # Optional
    print("Files downloaded successfully.")
except Exception as e:
    print(f"Error downloading files: {e}")
    print("Please ensure the repository ID is correct and the files exist.")
    exit()

# --- Load Configuration ---
print("Loading configuration...")
try:
    with open(config_path, 'r') as f:
        config = json.load(f)
    print("Configuration loaded:")
    print(config)
    # Extract the necessary hyperparameters from the config
    vocab_size = config["vocab_size"]
    n_layer = config["n_layer"]
    n_head = config["n_head"]
    n_embd = config["n_embd"]
    block_size = config["block_size"]
    dropout = config["dropout"]
except Exception as e:
    print(f"Error loading config.json: {e}")
    exit()

# --- RoPE Helper Functions ---
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, device='cpu'):
    # Complex-valued rotation factors for each (position, dimension-pair)
    freqs_part = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    freqs = 1.0 / (theta ** freqs_part).to(device)
    t = torch.arange(end, device=device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    if not freqs_cis.shape == (x.shape[1], x.shape[-1]):  # Check the dimensions actually used
        raise ValueError(f"Freqs shape {freqs_cis.shape} does not match x shape {x.shape} at dims 1 and -1")
    shape = [1] * (ndim - 2) + list(freqs_cis.shape)
    return freqs_cis.view(*shape)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # Rotate query/key pairs in the complex plane, then flatten back to real tensors
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out_complex = xq_ * freqs_cis
    xk_out_complex = xk_ * freqs_cis
    xq_out = torch.view_as_real(xq_out_complex).flatten(start_dim=2)
    xk_out = torch.view_as_real(xk_out_complex).flatten(start_dim=2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
# --- End RoPE Helpers ---

class Head(nn.Module):
    """One head of self-attention with RoPE."""

    def __init__(self, head_size):
        super().__init__()
        # Uses n_embd, block_size and dropout from the loaded config (module-level variables)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Causal mask buffer sized with block_size from the config
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)  # Use dropout from config

    def forward(self, x, freqs_cis):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        q_rope, k_rope = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
        head_size = q_rope.shape[-1]
        # Use a tensor for the scale factor
        scale = torch.tensor(head_size ** -0.5, device=q_rope.device, dtype=q_rope.dtype)
        wei = (q_rope @ k_rope.transpose(-2, -1)) * scale
        mask = self.tril[:T, :T] == 0
        wei = wei.masked_fill(mask.to(wei.device), float('-inf'))  # Ensure mask is on the correct device
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out

class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention in parallel with RoPE."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)  # head_size * num_heads == n_embd
        self.dropout = nn.Dropout(dropout)  # Use dropout from config

    def forward(self, x, freqs_cis):
        head_outputs = [h(x, freqs_cis) for h in self.heads]
        out = torch.cat(head_outputs, dim=-1)
        out = self.dropout(self.proj(out))
        return out

class SwiGLU(nn.Module):
    """SwiGLU feed-forward network."""

    def __init__(self, n_embd, hidden_dim=None, dropout=0.1):  # Allow dropout override
        super().__init__()
        if hidden_dim is None:
            hidden_dim = int(4 * n_embd * 2 / 3)
        self.w1 = nn.Linear(n_embd, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, n_embd, bias=False)
        self.w3 = nn.Linear(n_embd, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)  # Use the passed dropout rate

    def forward(self, x):
        gate = self.w3(x)
        value = self.w1(x)
        swish_gate = F.silu(gate)
        out = swish_gate * value
        out = self.dropout(self.w2(out))
        return out

class Block(nn.Module):
    """Transformer block using RoPE attention and a SwiGLU FFN."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        # Pass the dropout rate from the config to SwiGLU
        self.ffwd = SwiGLU(n_embd, dropout=dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, freqs_cis):
        x = x + self.sa(self.ln1(x), freqs_cis)
        x = x + self.ffwd(self.ln2(x))
        return x

class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.block_size = block_size  # Store block_size for context cropping
        # Precompute RoPE frequencies on CPU; they are moved to the right device in forward()
        head_dim = n_embd // n_head
        freqs_cis_buffer = precompute_freqs_cis(head_dim, block_size * 2, device='cpu')
        self.register_buffer("freqs_cis", freqs_cis_buffer, persistent=False)  # Registered but not saved in the state_dict
        # No weight initialization needed here for inference

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # Ensure T doesn't exceed block_size for the freqs_cis slicing
        # (generate() already crops the context, but this is a useful safeguard)
        T_used = min(T, self.block_size)
        tok_emb = self.token_embedding_table(idx[:, -T_used:])  # Use only the last block_size tokens if T > block_size
        # Slice the precomputed RoPE frequencies to the actual sequence length T_used
        # and move them to the same device as the embeddings
        freqs_cis_for_block = self.freqs_cis[:T_used].to(tok_emb.device)
        x = tok_emb
        for block in self.blocks:
            x = block(x, freqs_cis_for_block)
        x = self.ln_f(x)
        if targets is not None:
            # Training path; not typically used during inference
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            # Inference: compute logits only for the last position
            logits = self.lm_head(x[:, [-1], :])  # (B, 1, vocab_size)
            loss = None
        return logits, loss
    @torch.no_grad()  # No gradients are needed during generation
    def generate(self, idx, max_new_tokens):
        self.eval()  # Ensure the model is in eval mode
        for _ in range(max_new_tokens):
            # Crop the context if it exceeds block_size *before* the forward pass
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            # Forward pass for inference (returns logits for the last token only)
            logits, _ = self(idx_cond)  # Call forward with targets=None
            logits = logits[:, -1, :]  # Shape (B, vocab_size)
            # Softmax and sampling
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append the sampled token
            idx = torch.cat((idx, idx_next), dim=1)
        self.train()  # Optional: set back to train mode if needed elsewhere
        return idx
# --- End Model Definitions ---

# --- Instantiate Model ---
print("Instantiating model...")
try:
    model = GPTLanguageModel(
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_head=n_head,
        n_layer=n_layer,
        block_size=block_size,
        dropout=dropout
    )
    print("Model instantiated.")
except Exception as e:
    print(f"Error instantiating model: {e}")
    print("Ensure the class definitions above match the configuration.")
    exit()

# --- Load Weights ---
print("Loading model weights...")
try:
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))  # Load to CPU first
    # Adapt the state_dict if necessary (e.g., if module names changed).
    # Example: remove unexpected keys like 'freqs_cis' if they were accidentally saved.
    state_dict.pop("freqs_cis", None)
    load_result = model.load_state_dict(state_dict, strict=True)  # Use strict=True initially
    print(f"Weight loading result: {load_result}")
    print("Model weights loaded successfully.")
except Exception as e:
    print(f"Error loading weights: {e}")
    print("Ensure the model architecture definition matches the saved weights.")
    # If using strict=False, check the missing/unexpected keys reported by load_state_dict
    exit()

# --- Setup for Inference ---
model.eval()  # Set to evaluation mode (disables dropout etc.)
model.to(device)  # Move the model to the target device
print(f"Model moved to {device} and set to evaluation mode.")

# --- Load Vocabulary and Define Encode/Decode ---
print("Loading vocabulary...")
try:
    with open(vocab_path, 'r', encoding='utf-8') as f:
        stoi = json.load(f)
    itos = {i: ch for ch, i in stoi.items()}
    print(f"Vocabulary loaded ({len(stoi)} chars).")
    # Define encoding/decoding functions based on the loaded vocab
    encode = lambda s: [stoi.get(c, stoi.get('\n')) for c in s]  # Fall back to '\n' for unknown characters
    decode = lambda l: ''.join([itos.get(i, '?') for i in l])  # Use '?' for unknown indices
except Exception as e:
    print(f"Error loading vocabulary: {e}")
    exit()

# --- Inference ---
print("\n--- Starting Inference ---")
prompt = " = = Artificial Intelligence in Medicine = = \n\n Artificial intelligence ( AI ) has"  # Example prompt
max_tokens_to_generate = 500
print(f"Prompt:\n{prompt}")
print(f"\nGenerating ({max_tokens_to_generate} tokens)...")
# Encode the prompt
encoded_prompt = torch.tensor(encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
# Generate text (the generate method uses torch.no_grad() internally)
generated_ids = model.generate(encoded_prompt, max_new_tokens=max_tokens_to_generate)
generated_text = decode(generated_ids[0].tolist())  # Decode the generated indices
print("\n--- Generated Text ---")
print(generated_text)
print("\n----------------------")
```

This model was trained for about 1 hour and 30 minutes, so it has learned to string together basic English words, but that's about it.
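Given the short training run, decoding settings can noticeably change how readable the samples are. The snippet below is a sketch, not part of the released code, showing one common tweak: temperature scaling and top-k filtering in the sampling loop. It reuses the `model`, `encoded_prompt`, and `decode` objects defined in the script above; the function name and default values are illustrative.

```python
import torch
from torch.nn import functional as F

@torch.no_grad()
def generate_with_temperature(model, idx, max_new_tokens, temperature=0.8, top_k=20):
    # Illustrative variant of model.generate(): temperature < 1 sharpens the
    # distribution, and top_k restricts sampling to the k most likely characters.
    model.eval()
    for _ in range(max_new_tokens):
        # Crop the context to the model's block size, as in model.generate()
        idx_cond = idx if idx.size(1) <= model.block_size else idx[:, -model.block_size:]
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # Mask out everything below the k-th largest logit per row
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx

# Example usage, assuming the variables from the inference script above are in scope:
ids = generate_with_temperature(model, encoded_prompt, max_new_tokens=200)
print(decode(ids[0].tolist()))
```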