import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel


class SapnousT1Config(PretrainedConfig):
    # Minimal config so the model can be registered under a model_type;
    # the default sizes are placeholder assumptions, not released hyperparameters.
    model_type = "sapnous_t1"

    def __init__(self, vocab_size=32000, hidden_size=768, num_hidden_layers=12, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        super().__init__(**kwargs)


class SapnousT1ForCausalLM(PreTrainedModel):
    config_class = SapnousT1Config

    def __init__(self, config):
        super().__init__(config)
        self.hidden_size = config.hidden_size
        # Token embeddings, a stack of per-layer projections, and an untied LM head.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            nn.Linear(config.hidden_size, config.hidden_size)
            for _ in range(config.num_hidden_layers)
        ])
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids):
        # Embed the token ids, run them through each layer, then project to vocabulary logits.
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits


# Register the config under its model_type, then map that config class to the model class
# so AutoConfig / AutoModelForCausalLM can resolve "sapnous_t1".
AutoConfig.register("sapnous_t1", SapnousT1Config)
AutoModelForCausalLM.register(SapnousT1Config, SapnousT1ForCausalLM)
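

# Minimal usage sketch (assumption: executed in the same session as the registrations above;
# the config sizes below are illustrative, not values from a released checkpoint).
if __name__ == "__main__":
    config = SapnousT1Config(vocab_size=1000, hidden_size=64, num_hidden_layers=2)
    model = AutoModelForCausalLM.from_config(config)  # resolves to SapnousT1ForCausalLM

    input_ids = torch.randint(0, config.vocab_size, (1, 8))  # (batch, sequence_length)
    logits = model(input_ids)
    print(logits.shape)  # expected: torch.Size([1, 8, 1000])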