# Sapnous-VR-6B / modeling_sapnous.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModelForCausalLM

from configuration_sapnous import SapnousT1Config  # local Sapnous-T1 configuration

class SapnousT1PreTrainedModel(PreTrainedModel):
    """Base class for all Sapnous-T1 models."""

    config_class = SapnousT1Config

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.config = config

    def _init_weights(self, module):
        """Initialize module weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # Embeddings use the same normal initialization as linear layers.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

class SapnousT1Model(SapnousT1PreTrainedModel):
    """Base Transformer model that returns hidden states."""

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                batch_first=True,  # inputs are (batch, seq, hidden)
            ),
            num_layers=config.num_hidden_layers,
        )
        # The LM projection lives in SapnousT1ForCausalLM; the base model
        # only produces hidden states.
        self.post_init()

    def forward(self, input_ids):
        x = self.embeddings(input_ids)
        # Causal mask so each position attends only to earlier positions.
        seq_len = input_ids.size(1)
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=input_ids.device),
            diagonal=1,
        )
        return self.encoder(x, mask=causal_mask)

class SapnousT1ForCausalLM(SapnousT1PreTrainedModel):
    """Sapnous-T1 model with a language-modeling head for text generation."""

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.model = SapnousT1Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids):
        hidden_states = self.model(input_ids)  # (batch, seq, hidden)
        logits = self.lm_head(hidden_states)   # (batch, seq, vocab)
        return logits

# Register the model so AutoModelForCausalLM can resolve SapnousT1Config.
AutoModelForCausalLM.register(SapnousT1Config, SapnousT1ForCausalLM)
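
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal smoke test assuming SapnousT1Config accepts the standard kwargs
# referenced above (vocab_size, hidden_size, num_attention_heads,
# num_hidden_layers, initializer_range); check configuration_sapnous.py for
# the real signature before relying on this.
if __name__ == "__main__":
    config = SapnousT1Config(
        vocab_size=32000,
        hidden_size=256,
        num_attention_heads=4,
        num_hidden_layers=2,
        initializer_range=0.02,
    )
    model = SapnousT1ForCausalLM(config)
    dummy_ids = torch.randint(0, config.vocab_size, (1, 8))  # (batch, seq)
    logits = model(dummy_ids)
    print(logits.shape)  # expected: torch.Size([1, 8, 32000])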