NGen4: A Causal Language Model with Rotary Positional Embeddings
NGen4 is a decoder-only causal language model architecture built around modern techniques, notably Rotary Positional Embeddings (RoPE) for effective long-context handling and configurable attention implementations. This repository contains the PyTorch implementation using the Hugging Face Transformers library.
This model card describes the NGen4 architecture. Specific instances (e.g., a 50M parameter model pre-trained on Wikitext-2) will have their details in respective sections or separate model cards if uploaded individually.
Model Approach & Architecture
NGen4 is built upon the standard Transformer decoder architecture, incorporating several key design choices for performance, flexibility, and effective long-context modeling.
1. Core Architecture: Decoder-Only Causal LM
- Type: Autoregressive, decoder-only Transformer.
- Objective: Trained for next-token prediction, making it suitable for text generation tasks.
- Causal Masking: A causal attention mask is applied in the self-attention layers to ensure that predictions for a token at position `i` can only depend on known outputs at positions less than `i`.
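For illustration (not taken from the NGen4 source), such a mask can be expressed as a lower-triangular boolean matrix:

```python
import torch

seq_len = 5
# True where attention is permitted: the token at position i may attend to positions 0..i,
# so its prediction depends only on its own input and earlier positions.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
print(causal_mask)
```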
2. Positional Embeddings: Rotary Positional Embeddings (RoPE)
NGen4 replaces traditional absolute or learned positional embeddings with Rotary Positional Embeddings (RoPE).
- Mechanism: RoPE encodes absolute positional information by rotating pairs of features in the query and key projections based on their position. This is applied after the Q/K projections but before the attention dot product.
- Advantages:
  - Long Context Scalability: RoPE has shown strong generalization to sequence lengths longer than those seen during training.
  - Relative Position Encoding: While encoding absolute positions, RoPE implicitly captures relative positional information in the self-attention mechanism through its rotational property.
  - No Trainable Parameters: RoPE itself adds no trainable parameters for positional encoding.
- Implementation Details (see the sketch below):
  - `rope_theta`: The base period for the rotary encodings (default: `10000.0`).
  - `rope_pct`: The fraction of each head's dimensions to which RoPE is applied (default: `1.0`, i.e., all dimensions). The actual `rope_dim` is calculated as `head_dim * rope_pct`.
  - The implementation caches the `sin` and `cos` values to improve efficiency.
  - The configuration includes `rope_scaling` (default: `None`) as a placeholder for future integration of RoPE scaling strategies (e.g., NTK-aware scaling, YaRN) to further enhance long-context capabilities.
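A minimal sketch of this rotary mechanism, assuming the common GPT-NeoX-style rotation; the helper names (`build_rope_cache`, `rotate_half`, `apply_rope`) are illustrative and do not necessarily match those in `ngen4.py`:

```python
import torch

def build_rope_cache(seq_len, rope_dim, theta=10000.0, device=None):
    # Inverse frequencies for each rotated feature pair, with `theta` as the base period.
    inv_freq = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, device=device).float() / rope_dim))
    positions = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(positions, inv_freq)      # (seq_len, rope_dim / 2)
    emb = torch.cat((freqs, freqs), dim=-1)       # (seq_len, rope_dim)
    return emb.cos(), emb.sin()                   # cached and reused across forward passes

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # q, k: (batch, n_head, seq_len, rope_dim); cos/sin broadcast over batch and heads.
    # Applied after the Q/K projections and before the attention dot product.
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot
```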
3. Attention Mechanism: Multi-Head Self-Attention (MHSA)
- Standard MHSA: The model uses multi-head self-attention as the core mechanism for information aggregation.
- Configurable Implementations: The `NGen4Config` allows specifying different attention implementations via the `attn_implementation` parameter (a dispatch sketch follows this list):
  - `"eager"`: The standard PyTorch implementation. Provides a clear reference but can be less memory- and compute-efficient.
  - `"sdpa"` (Scaled Dot Product Attention): Leverages PyTorch 2.0's built-in optimized `F.scaled_dot_product_attention`. Generally faster and more memory-efficient than eager.
  - `"flash_attention_2"`: If the `flash-attn` library is installed and PyTorch >= 2.0 is available, this option can provide significant speedups and memory savings, especially for longer sequences on compatible hardware.
- Projections: Query (Q), Key (K), and Value (V) tensors are produced from the input hidden states by the `c_attn` projection; an output projection (`c_proj`) is applied after attention.
- Dropout: Dropout is applied to the attention weights (`attn_pdrop`) and to the residual connections (`resid_pdrop`).
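A minimal sketch of how the `"eager"` and `"sdpa"` paths differ, assuming inputs already shaped `(batch, n_head, seq_len, head_dim)`; the function and its dispatch logic are illustrative, not the actual NGen4 attention module (the `"flash_attention_2"` path is omitted):

```python
import math
import torch
import torch.nn.functional as F

def causal_attention(q, k, v, attn_implementation="sdpa", dropout_p=0.0):
    # q, k, v: (batch, n_head, seq_len, head_dim)
    if attn_implementation == "sdpa":
        # PyTorch 2.x fused kernel; handles causal masking internally.
        return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=True)

    # "eager" reference path: explicit score matrix plus an explicit causal mask.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    seq_len = q.size(-2)
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v
```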
4. Transformer Blocks (`NGen4Block`)
Each NGen4 block follows a standard Pre-LayerNormalization (Pre-LN) structure:
- Layer Normalization (`ln_1`): Applied to the input hidden states.
- Multi-Head Self-Attention (`attn`): As described above.
- Residual Connection: The output of the attention module is added to the input of `ln_1`.
- Layer Normalization (`ln_2`): Applied to the output of the first residual connection.
- Feed-Forward Network (MLP) (`mlp`):
  - Two linear layers with an activation function in between.
  - The intermediate size (`n_inner`) defaults to `4 * n_embd`.
  - The activation function is configurable via `activation_function` (e.g., `"gelu_new"`).
- Residual Connection: The output of the MLP is added to the input of `ln_2`.
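A rough sketch of this Pre-LN wiring, using `nn.MultiheadAttention` as a stand-in for the RoPE-aware attention module described above (so this is not the actual `NGen4Block`):

```python
import torch.nn as nn

class PreLNBlockSketch(nn.Module):
    """Illustrative Pre-LN block mirroring the structure described above."""

    def __init__(self, n_embd, n_head, n_inner=None, layer_norm_epsilon=1e-5):
        super().__init__()
        n_inner = n_inner or 4 * n_embd
        self.ln_1 = nn.LayerNorm(n_embd, eps=layer_norm_epsilon)
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln_2 = nn.LayerNorm(n_embd, eps=layer_norm_epsilon)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, n_inner),
            nn.GELU(),                       # stand-in for the configured activation_function
            nn.Linear(n_inner, n_embd),
        )

    def forward(self, x, attn_mask=None):
        h = self.ln_1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + attn_out                     # residual around the attention sub-layer
        x = x + self.mlp(self.ln_2(x))       # residual around the MLP sub-layer
        return x
```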
5. Embeddings and Output Layer
- Token Embeddings (`wte`): A standard learnable embedding layer maps input token IDs to dense vectors of `n_embd` dimensions.
- LM Head (`lm_head`): A linear layer maps the final hidden states from the transformer blocks to logits over the vocabulary (`vocab_size`).
- Weight Tying: The weights of the token embedding layer (`wte.weight`) and the LM head (`lm_head.weight`) are typically tied. This is declared in the model via `_tied_weights_keys = ["lm_head.weight"]` and in the configuration via `tie_word_embeddings=True`. Tying reduces the parameter count and can improve performance.
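A minimal illustration of the weight-tying pattern (a hypothetical module, not NGen4's actual class layout):

```python
import torch.nn as nn

class TiedHeadSketch(nn.Module):
    def __init__(self, vocab_size, n_embd):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.lm_head.weight = self.wte.weight   # one shared parameter tensor for input and output

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, n_embd) -> logits: (batch, seq_len, vocab_size)
        return self.lm_head(hidden_states)
```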
6. Key Configuration Parameters
The architecture is defined by parameters in `NGen4Config`, including:
- `vocab_size`: Size of the vocabulary.
- `n_positions`: Maximum sequence length the model can process (context window).
- `n_embd`: Dimensionality of the token embeddings and hidden states.
- `n_layer`: Number of NGen4 transformer blocks.
- `n_head`: Number of attention heads.
- `n_inner`: Dimensionality of the intermediate layer in the MLP.
- `activation_function`: Activation function for the MLP.
- Dropout rates: `resid_pdrop`, `embd_pdrop`, `attn_pdrop`.
- `layer_norm_epsilon`: Epsilon for LayerNorm stability.
- RoPE parameters: `use_rope`, `rope_theta`, `rope_scaling`, `rope_pct`.
- `attn_implementation`: Choice of attention backend.
- `tie_word_embeddings`: Whether to tie the input and output embeddings.
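For example, a configuration roughly matching the ~50M-parameter setup listed under Training Procedure could be built as follows; the parameter names come from the list above, the vocabulary size assumes the GPT-2 tokenizer, and the exact `NGen4Config` signature should be checked against `ngen4.py`:

```python
from ngen4 import NGen4Config, NGen4ForCausalLM

config = NGen4Config(
    vocab_size=50257,            # GPT-2 tokenizer vocabulary (assumed)
    n_positions=512,
    n_embd=512,
    n_layer=8,
    n_head=8,
    activation_function="gelu_new",
    use_rope=True,
    rope_theta=10000.0,
    rope_pct=1.0,
    attn_implementation="sdpa",
    tie_word_embeddings=True,
)

model = NGen4ForCausalLM(config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```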
7. Gradient Checkpointing
The model supports gradient checkpointing to reduce memory usage during training, trading a small amount of recomputation for lower VRAM consumption. It can be enabled through the Hugging Face `Trainer` or by setting the `gradient_checkpointing` attribute directly on the `NGen4Model` instance.
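For example, with the Hugging Face `Trainer`, checkpointing can be switched on through `TrainingArguments` (the output directory below is a placeholder):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./ngen4_checkpoints",   # placeholder path
    gradient_checkpointing=True,        # trade recomputation for lower VRAM usage
    fp16=True,
)

# Alternatively, on the model instance (standard PreTrainedModel API):
# model.gradient_checkpointing_enable()
```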
Training Data
(This section should be updated based on the specific model instance)
- Initial Pre-training (Example for a 50M parameter model):
  - Dataset: Wikitext-2 (`wikitext`, `wikitext-2-raw-v1` configuration).
  - Preprocessing: Text was tokenized using a GPT-2 tokenizer and then grouped into blocks of the training sequence length (e.g., 512 tokens); see the preprocessing sketch after this list.
- Planned Further Pre-training:
- The model architecture is designed for larger datasets. Future work includes further pre-training on more extensive and diverse corpora like subsets of RedPajama (e.g., the arXiv subset) to enhance general language understanding and generation capabilities.
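A sketch of the Wikitext-2 preprocessing described above, following the standard Hugging Face `datasets` recipe; this is illustrative and may differ from the actual training script:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

block_size = 512
tokenizer = AutoTokenizer.from_pretrained("gpt2")
raw = load_dataset("wikitext", "wikitext-2-raw-v1")

def tokenize(batch):
    return tokenizer(batch["text"])

def group_texts(examples):
    # Concatenate all token ids, then split them into fixed-size blocks.
    concatenated = sum(examples["input_ids"], [])
    total = (len(concatenated) // block_size) * block_size
    blocks = [concatenated[i:i + block_size] for i in range(0, total, block_size)]
    return {"input_ids": blocks, "labels": [b[:] for b in blocks]}

tokenized = raw.map(tokenize, batched=True, remove_columns=raw["train"].column_names)
lm_dataset = tokenized.map(group_texts, batched=True, remove_columns=tokenized["train"].column_names)
```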
Training Procedure
(This section should be updated based on the specific model instance)
- Framework: Trained using PyTorch and the Hugging Face `Trainer` API.
- Example Configuration (for Wikitext-2, ~50M model):
  - `n_embd`: 512
  - `n_layer`: 8
  - `n_head`: 8
  - `n_positions`: 512 (also the `block_size` used for training)
  - `attn_implementation`: `"eager"` (can be set to `"flash_attention_2"` or `"sdpa"` if supported)
- Optimizer: AdamW.
- Learning Rate: E.g., 5e-5 with a linear scheduler and warmup.
- Batch Size: The effective batch size is set through `per_device_train_batch_size` and `gradient_accumulation_steps`.
- Epochs: E.g., 3 epochs for Wikitext-2.
- Mixed Precision (FP16): Enabled to reduce memory and potentially speed up training.
- Gradient Checkpointing: Utilized during training.
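Putting this together, a `Trainer` setup consistent with the hyperparameters above could look like the sketch below; the warmup steps, per-device batch size, and accumulation steps are assumed values, and `model`, `tokenizer`, and `lm_dataset` refer to the earlier sketches:

```python
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

training_args = TrainingArguments(
    output_dir="./ngen4_wikitext2_50M",    # placeholder path
    num_train_epochs=3,
    learning_rate=5e-5,
    lr_scheduler_type="linear",
    warmup_steps=500,                      # assumed warmup value
    per_device_train_batch_size=8,         # assumed; effective batch = this * accumulation steps
    gradient_accumulation_steps=4,         # assumed
    fp16=True,
    gradient_checkpointing=True,
    logging_steps=100,
)

trainer = Trainer(
    model=model,                           # NGen4ForCausalLM instance
    args=training_args,
    train_dataset=lm_dataset["train"],
    eval_dataset=lm_dataset["validation"],
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    # The Trainer uses AdamW by default, matching the optimizer listed above.
)
trainer.train()
```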
Evaluation Results
(This section should be updated based on the specific model instance. Provide perplexity scores, loss, etc., on relevant evaluation datasets.)
- Example (Wikitext-2, ~50M model):
  - Validation Perplexity: [TODO: Add perplexity, e.g., 1.7781]
  - Validation Loss: [TODO: Add eval loss, e.g., 0.5755]
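Perplexity is the exponential of the mean cross-entropy loss, so the validation perplexity can be recovered directly from the `Trainer`'s eval loss:

```python
import math

eval_metrics = trainer.evaluate()
eval_loss = eval_metrics["eval_loss"]
print(f"Validation loss: {eval_loss:.4f} | perplexity: {math.exp(eval_loss):.4f}")
```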
How to Use
Prerequisites
Ensure you have PyTorch, Transformers, and the `ngen4.py` model definition file:
```bash
pip install torch transformers
```
If using Flash Attention 2, install it separately:
```bash
pip install flash-attn --no-build-isolation
```
Loading the Model and Tokenizer
```python
import torch
from transformers import AutoTokenizer
from ngen4 import NGen4ForCausalLM, NGen4Config  # Ensure ngen4.py is in your Python path

model_path = "[path_to_your_saved_ngen4_model_directory]"  # e.g., "./ngen4_wikitext2_50M_tied_v2"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Common practice for GPT-like models

# Load model
# The NGen4Config will be loaded automatically from model_path if config.json exists.
# To override config parameters, load NGen4Config first, modify it,
# and pass it to from_pretrained.
model = NGen4ForCausalLM.from_pretrained(model_path)

# Set device (e.g., CUDA if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # Set to evaluation mode for inference
```
Text Generation Example
```python
prompt = "Once upon a time in a land far away"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
attention_mask = torch.ones_like(input_ids)  # Explicitly create the attention mask

with torch.no_grad():
    output_sequences = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=150,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        do_sample=True,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )

generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
print(generated_text)
```
(For more detailed sampling options, see the `sample_ngen4.py` script if provided in the model repository.)
Intended Use & Limitations
- Intended Use: This model is intended for text generation, capable of tasks like completing prompts, creative writing, and potentially (with further fine-tuning) summarization, question answering, etc. The base pre-trained model is primarily for next-token prediction based on its training data.
- Limitations:
- The model's knowledge is limited to its training data.
- It may generate factually incorrect, biased, or nonsensical text.
- Performance on specific downstream tasks will heavily depend on the quality and nature of further fine-tuning.
- The current ~50M parameter model trained on Wikitext-2 is a demonstration model and will have limited capabilities compared to larger models trained on more diverse datasets.
Future Work
- Further Pre-training: Scale up pre-training using larger and more diverse datasets such as subsets of RedPajama (e.g., arXiv, Books, CommonCrawl).
- Instruction Fine-tuning: Fine-tune the pre-trained NGen4 model on instruction-following datasets (e.g., Alpaca, Dolly, OpenOrca) to improve its ability to follow commands and engage in conversational interactions.
- RoPE Scaling: Implement and evaluate RoPE scaling techniques (NTK-aware, YaRN) to further enhance long-context performance.
- Evaluation: Conduct comprehensive evaluations on a wider range of downstream NLP benchmarks.
*Model architecture and training scripts developed by TNSA AI.*