Character-Level GPT Model Trained on WikiText-2

This repository contains a character-level GPT model trained on the WikiText-2 dataset. The model architecture is a custom implementation of a transformer-based language model with Rotary Positional Embeddings (RoPE) and SwiGLU feed-forward networks.

This project was largely AI-generated. That includes essentially all of the code used to build the model, as well as a large part of this description.

Model Description

  • Model Type: Custom Transformer-based Causal Language Model
  • Training Data: WikiText-2
  • Tokenization: Character-level
  • Architecture Details:
    • Rotary Positional Embeddings (RoPE)
    • SwiGLU Feed-Forward Networks
  • Parameters (a rough size estimate from these values is sketched after this list):
    • n_layer: 8
    • n_head: 8
    • n_embd: 512
    • block_size: 512
    • dropout: 0.1
    • vocab_size: 283
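
The repository name suggests roughly 25M parameters. As a rough sanity check, the count implied by the hyperparameters above (using the layer shapes from the inference code further down) works out to about 25.5M; the snippet below is an estimate derived from those shapes, not an exact count read from the checkpoint.

n_layer, n_head, n_embd, block_size, vocab_size = 8, 8, 512, 512, 283
head_size = n_embd // n_head                                 # 64
attn = n_layer * (n_head * 3 * n_embd * head_size            # per-head K/Q/V projections
                  + n_embd * n_embd + n_embd)                # output projection (with bias)
hidden = int(4 * n_embd * 2 / 3)                             # SwiGLU hidden dim (1365)
ffn = n_layer * 3 * n_embd * hidden                          # w1, w2, w3 (no bias)
norms = n_layer * 2 * 2 * n_embd + 2 * n_embd                # per-block LayerNorms + final ln_f
embed = vocab_size * n_embd                                   # token embedding table
lm_head = n_embd * vocab_size                                 # output head (no bias)
print(f"~{(attn + ffn + norms + embed + lm_head) / 1e6:.1f}M parameters")  # ~25.5M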

Intended Use

This model is intended for research purposes, including:

  • Experimenting with character-level language modeling.
  • Studying the effects of different training techniques on transformer models.

Limitations

  • Character-Level Tokenization: The model uses character-level tokenization, which is less efficient than subword tokenization (e.g., Byte-Pair Encoding) at capturing long-range dependencies and generating coherent text, so the quality of generated text may be limited compared to models using subword tokenization (a short illustration follows this list).
  • Limited Training Data: The model was trained on the WikiText-2 dataset, which is relatively small. Training on a larger dataset would likely improve performance.
  • Custom Architecture: This is a custom model implementation, not a standard pre-trained model from the transformers library.
  • Requires Manual Intervention: Loading and using this model requires manual intervention and a deeper understanding of the architecture. The AutoModelForCausalLM class from transformers cannot be used.
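
To make the first limitation concrete: at the character level every letter, space, and punctuation mark is its own token, so the 512-token context window holds only a few hundred characters of text (roughly 80-100 English words). A toy illustration follows; the vocabulary here is made up for the example, while the real mapping lives in the repository's vocab.json.

text = "Artificial intelligence"
stoi = {ch: i for i, ch in enumerate(sorted(set(text)))}  # illustrative char -> id map
ids = [stoi[ch] for ch in text]
print(len(text), len(ids))  # 23 characters -> 23 tokens, one token per character
# A subword (e.g. BPE) tokenizer would typically cover the same text in a handful
# of tokens, so a 512-token context holds far less text at the character level.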

Inference Code

import torch
import torch.nn as nn
from torch.nn import functional as F
import json
from huggingface_hub import hf_hub_download
from huggingface_hub import hf_hub_download

# --- Configuration ---
repo_id = "Ma7ee7/WikiGPT-25M"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# --- Download Necessary Files ---
print(f"Downloading files from {repo_id}...")
try:
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
    vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.json")
    # tokenizer_config_path = hf_hub_download(repo_id=repo_id, filename="tokenizer_config.json") # Optional
    print("Files downloaded successfully.")
except Exception as e:
    print(f"Error downloading files: {e}")
    print("Please ensure the repository ID is correct and the files exist.")
    exit()

# --- Load Configuration ---
print("Loading configuration...")
try:
    with open(config_path, 'r') as f:
        config = json.load(f)
    print("Configuration loaded:")
    print(config)

    # Extract necessary hyperparameters from config
    vocab_size = config["vocab_size"]
    n_layer = config["n_layer"]
    n_head = config["n_head"]
    n_embd = config["n_embd"]
    block_size = config["block_size"]
    dropout = config["dropout"]

except Exception as e:
    print(f"Error loading config.json: {e}")
    exit()

# --- RoPE Helper Functions ---
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, device='cpu'):
    freqs_part = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    freqs = 1.0 / (theta ** freqs_part).to(device)
    t = torch.arange(end, device=device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    if not freqs_cis.shape == (x.shape[1], x.shape[-1]): # Check dimensions used
         raise ValueError(f"Freqs shape {freqs_cis.shape} does not match x shape {x.shape} at dims 1 and -1")
    shape = [1] * (ndim - 2) + list(freqs_cis.shape)
    return freqs_cis.view(*shape)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out_complex = xq_ * freqs_cis
    xk_out_complex = xk_ * freqs_cis
    xq_out = torch.view_as_real(xq_out_complex).flatten(start_dim=2)
    xk_out = torch.view_as_real(xk_out_complex).flatten(start_dim=2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
# --- End RoPE Helpers ---
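
# Optional sanity check for the helpers above (added for illustration; not part of
# the original script). RoPE rotates adjacent channel pairs of q and k, so output
# shapes match input shapes: a (2, 16, 64) q/k pair uses a (16, 32) complex table.
_freqs_demo = precompute_freqs_cis(dim=64, end=16)
_q_demo, _k_demo = torch.randn(2, 16, 64), torch.randn(2, 16, 64)
_q_rot, _k_rot = apply_rotary_emb(_q_demo, _k_demo, _freqs_demo)
assert _q_rot.shape == _q_demo.shape and _k_rot.shape == _k_demo.shape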


class Head(nn.Module):
    """ one head of self-attention with RoPE """
    def __init__(self, head_size):
        super().__init__()
        # Use parameters from loaded config directly if possible, or pass them
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Register buffer requires size, use block_size from config
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout) # Use dropout from config

    def forward(self, x, freqs_cis):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)

        q_rope, k_rope = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        head_size = q_rope.shape[-1]
        # Use tensor for scale
        scale = torch.tensor(head_size ** -0.5, device=q_rope.device, dtype=q_rope.dtype)
        wei = (q_rope @ k_rope.transpose(-2, -1)) * scale

        mask = self.tril[:T, :T] == 0
        wei = wei.masked_fill(mask.to(wei.device), float('-inf')) # Ensure mask on correct device

        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v = self.value(x)
        out = wei @ v
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel with RoPE """
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd) # head_size * num_heads = n_embd
        self.dropout = nn.Dropout(dropout) # Use dropout from config

    def forward(self, x, freqs_cis):
        head_outputs = [h(x, freqs_cis) for h in self.heads]
        out = torch.cat(head_outputs, dim=-1)
        out = self.dropout(self.proj(out))
        return out

class SwiGLU(nn.Module):
    """ SwiGLU Feed-Forward Network """
    def __init__(self, n_embd, hidden_dim=None, dropout=0.1): # Allow dropout override
        super().__init__()
        if hidden_dim is None:
            hidden_dim = int(4 * n_embd * 2 / 3)
        self.w1 = nn.Linear(n_embd, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, n_embd, bias=False)
        self.w3 = nn.Linear(n_embd, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout) # Use passed dropout

    def forward(self, x):
        gate = self.w3(x)
        value = self.w1(x)
        swish_gate = F.silu(gate)
        out = swish_gate * value
        out = self.dropout(self.w2(out))
        return out

class Block(nn.Module):
    """ Transformer block using RoPE attention and SwiGLU FFN """
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        # Pass dropout rate from config to SwiGLU
        self.ffwd = SwiGLU(n_embd, dropout=dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, freqs_cis):
        x = x + self.sa(self.ln1(x), freqs_cis)
        x = x + self.ffwd(self.ln2(x))
        return x

class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        self.block_size = block_size # Store block_size

        # Precompute RoPE frequencies using parameters
        head_dim = n_embd // n_head
        # Pass device explicitly during precomputation if model is moved later
        freqs_cis_buffer = precompute_freqs_cis(head_dim, block_size * 2, device='cpu') # Compute on CPU first
        self.register_buffer("freqs_cis", freqs_cis_buffer, persistent=False) # Register but don't save in state_dict

        # No weight initialization needed here for inference

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # Ensure T doesn't exceed block size for freqs_cis slicing
        # In generate, idx_cond handles this, but good check here too
        T_used = min(T, self.block_size)

        tok_emb = self.token_embedding_table(idx[:, -T_used:]) # Use only last block_size tokens if T > block_size

        # Retrieve precomputed RoPE frequencies for the actual sequence length T_used
        # Move required part of freqs_cis to the same device as embeddings
        freqs_cis_for_block = self.freqs_cis[:T_used].to(tok_emb.device)

        x = tok_emb
        for block in self.blocks:
            x = block(x, freqs_cis_for_block)

        x = self.ln_f(x)

        if targets is not None:
            # This path isn't typically used during inference
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            # Inference: compute logits only for the last token
            logits = self.lm_head(x[:, [-1], :]) # (B, 1, vocab_size)
            loss = None

        return logits, loss

    @torch.no_grad() # Ensure no gradients are computed during generation
    def generate(self, idx, max_new_tokens):
        self.eval() # Ensure model is in eval mode
        for _ in range(max_new_tokens):
            # Crop context if it exceeds block size *before* forward pass
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            # Forward pass for inference (gets logits for the last token)
            logits, _ = self(idx_cond) # Call forward with targets=None
            logits = logits[:, -1, :] # Shape (B, vocab_size)

            # Softmax and sampling
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append
            idx = torch.cat((idx, idx_next), dim=1)
        self.train() # Optional: set back to train mode if needed elsewhere
        return idx
# --- End Model Definitions ---


# --- Instantiate Model ---
print("Instantiating model...")
try:
    model = GPTLanguageModel(
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_head=n_head,
        n_layer=n_layer,
        block_size=block_size,
        dropout=dropout
    )
    print("Model instantiated.")
except Exception as e:
    print(f"Error instantiating model: {e}")
    print("Ensure the class definitions above match the configuration.")
    exit()

# --- Load Weights ---
print("Loading model weights...")
try:
    state_dict = torch.load(weights_path, map_location=torch.device('cpu')) # Load to CPU first
    # Adapt state_dict if necessary (e.g., if module names changed)
    # Example: remove unexpected keys like 'freqs_cis' if they were accidentally saved
    state_dict.pop("freqs_cis", None)

    load_result = model.load_state_dict(state_dict, strict=True) # Use strict=True initially
    print(f"Weight loading result: {load_result}")
    print("Model weights loaded successfully.")
except Exception as e:
    print(f"Error loading weights: {e}")
    print("Ensure the model architecture definition matches the saved weights.")
    # If using strict=False, check missing/unexpected keys printed by load_state_dict
    exit()

# --- Setup for Inference ---
model.eval() # Set to evaluation mode (disable dropout etc.)
model.to(device) # Move model to target device
print(f"Model moved to {device} and set to evaluation mode.")

# --- Load Vocabulary and Define Encode/Decode ---
print("Loading vocabulary...")
try:
    with open(vocab_path, 'r', encoding='utf-8') as f:
        stoi = json.load(f)
    itos = {i: ch for ch, i in stoi.items()}
    print(f"Vocabulary loaded ({len(stoi)} chars).")

    # Define encoding/decoding functions based on loaded vocab
    encode = lambda s: [stoi.get(c, stoi.get('\n')) for c in s] # Use \n as fallback? Or a dedicated UNK?
    decode = lambda l: ''.join([itos.get(i, '?') for i in l])   # Use ? for unknown indices

except Exception as e:
    print(f"Error loading vocabulary: {e}")
    exit()


# --- Inference ---
print("\n--- Starting Inference ---")
prompt = " = = Artificial Intelligence in Medicine = = \n\n Artificial intelligence ( AI ) has" # Example prompt

max_tokens_to_generate = 500

print(f"Prompt:\n{prompt}")
print(f"\nGenerating ({max_tokens_to_generate} tokens)...")

# Encode the prompt
encoded_prompt = torch.tensor(encode(prompt), dtype=torch.long, device=device).unsqueeze(0)

# Generate text
# Ensure generate method uses torch.no_grad() internally
generated_ids = model.generate(encoded_prompt, max_new_tokens=max_tokens_to_generate)
generated_text = decode(generated_ids[0].tolist()) # Decode the generated indices

print("\n--- Generated Text ---")
print(generated_text)
print("\n----------------------")

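The generate method above samples directly from the unmodified softmax distribution at every step. A temperature / top-k sampling variant is a common tweak for controlling how conservative the output is; the sketch below is not part of the original script, and it assumes the model, encode, decode, and encoded_prompt objects defined above are in scope.

@torch.no_grad()
def generate_with_sampling(model, idx, max_new_tokens, temperature=0.8, top_k=50):
    # Variant of model.generate with temperature scaling and top-k filtering.
    model.eval()
    for _ in range(max_new_tokens):
        idx_cond = idx if idx.size(1) <= model.block_size else idx[:, -model.block_size:]
        logits, _ = model(idx_cond)                       # logits for the last position only
        logits = logits[:, -1, :] / temperature           # <1.0 sharpens, >1.0 flattens
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')   # keep only the top-k candidates
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx

# Example usage with the prompt encoded above:
# sampled_ids = generate_with_sampling(model, encoded_prompt, max_new_tokens=200)
# print(decode(sampled_ids[0].tolist()))
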
This model was trained for about 1 hour and 30 minutes, so it has picked up basic word-level structure in English, but that's about it.
