import torch
import torch.nn as nn

from transformers import PreTrainedModel, AutoModelForCausalLM

from configuration_sapnous import SapnousT1Config


class SapnousT1PreTrainedModel(PreTrainedModel):
    """Base class for all Sapnous-T1 models."""

    config_class = SapnousT1Config

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.config = config

    def _init_weights(self, module):
        """Initialize weights if required."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # Token embeddings use the same normal initialization as linear layers.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)


class SapnousT1Model(SapnousT1PreTrainedModel):
    """Base Transformer model; returns the final hidden states."""

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                batch_first=True,  # inputs arrive as (batch, seq_len, hidden_size)
            ),
            num_layers=config.num_hidden_layers,
        )
        self.post_init()

    def forward(self, input_ids):
        x = self.embeddings(input_ids)
        # Causal mask so each position attends only to earlier positions.
        mask = nn.Transformer.generate_square_subsequent_mask(x.size(1)).to(x.device)
        x = self.encoder(x, mask=mask)
        # The vocabulary projection lives in SapnousT1ForCausalLM.
        return x


class SapnousT1ForCausalLM(SapnousT1PreTrainedModel):
    """Sapnous-T1 model with a language-modeling head for causal text generation."""

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.model = SapnousT1Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids):
        hidden_states = self.model(input_ids)  # (batch, seq_len, hidden_size)
        logits = self.lm_head(hidden_states)   # (batch, seq_len, vocab_size)
        return logits


AutoModelForCausalLM.register(SapnousT1Config, SapnousT1ForCausalLM)
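

# Minimal smoke-test sketch (assumptions: SapnousT1Config accepts these fields as
# keyword arguments, as is usual for a PretrainedConfig subclass; the values below
# are placeholders, not the real Sapnous-T1 hyperparameters).
if __name__ == "__main__":
    config = SapnousT1Config(
        vocab_size=1000,
        hidden_size=64,
        num_attention_heads=4,
        num_hidden_layers=2,
        initializer_range=0.02,
    )
    model = SapnousT1ForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq_len)
    logits = model(input_ids)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])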