import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union

from transformers import PreTrainedModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast

from .configuration_sapnous import SapnousT1Config
from .attention_sapnous import SapnousAttention, SapnousBlock, SapnousVisionEmbeddings, precompute_freqs_cis


class SapnousT1PreTrainedModel(PreTrainedModel):
    """Base class for all Sapnous-T1 models."""

    config_class = SapnousT1Config
    base_model_prefix = "sapnous"

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.config = config

    def _init_weights(self, module):
        """Initialize weights using the model's initialization configuration."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, SapnousAttention):
            module.q_proj.weight.data.normal_(mean=0.0, std=std)
            module.k_proj.weight.data.normal_(mean=0.0, std=std)
            module.v_proj.weight.data.normal_(mean=0.0, std=std)
            module.o_proj.weight.data.normal_(mean=0.0, std=std)


class SapnousT1Model(SapnousT1PreTrainedModel):
    """Base transformer model with advanced attention mechanisms and optional vision support."""

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([SapnousBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Optional vision tower: only instantiated when a vision config is provided.
        self.vision_embed = SapnousVisionEmbeddings(config) if getattr(config, "vision_config", None) else None

        # Precompute rotary position embedding frequencies once; they are moved to the
        # input device on each forward pass.
        self.freqs_cis = precompute_freqs_cis(
            self.config.hidden_size // self.config.num_attention_heads,
            self.config.max_position_embeddings,
            self.config.rope_theta,
        )

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings

    def set_input_embeddings(self, value: nn.Module):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")

        if input_ids is not None:
            inputs_embeds = self.embeddings(input_ids)
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if pixel_values is not None and self.vision_embed is not None:
            # Prepend vision tokens to the text sequence.
            vision_embeds = self.vision_embed(pixel_values)
            inputs_embeds = torch.cat([vision_embeds, inputs_embeds], dim=1)
            if attention_mask is not None:
                # Extend the padding mask so the prepended vision tokens are attended to.
                vision_mask = attention_mask.new_ones(batch_size, vision_embeds.shape[1])
                attention_mask = torch.cat([vision_mask, attention_mask], dim=1)
            seq_length = inputs_embeds.shape[1]

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        if attention_mask is not None:
            # Expand the 2D padding mask to (batch, 1, 1, seq_len) and convert it to an
            # additive mask: 0 for tokens to attend to, a large negative value for padding.
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=inputs_embeds.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(inputs_embeds.dtype).min

        freqs_cis = self.freqs_cis.to(inputs_embeds.device)

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                freqs_cis=freqs_cis,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [
                hidden_states,
                next_decoder_cache,
                all_hidden_states,
                all_self_attns,
            ] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class SapnousT1ForCausalLM(SapnousT1PreTrainedModel):
    """Sapnous-T1 model for causal language modeling with vision support."""

    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.model = SapnousT1Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.embeddings

    def set_input_embeddings(self, value: nn.Module):
        self.model.embeddings = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        # With a cache, only the last token needs to be fed through the model.
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if position_ids is None and attention_mask is not None:
            # Derive positions from the padding mask so left-padded batches stay aligned.
            position_ids = attention_mask.long().cumsum(-1) - 1
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "pixel_values": kwargs.get("pixel_values", None),
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the causal language modeling loss. Indices set to
            `-100` are ignored; the loss is only computed for the remaining tokens.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def tie_weights(self):
        """Tie the weights between the input embeddings and the output embeddings."""
        self.lm_head.weight = self.model.embeddings.weight


AutoModelForCausalLM.register(SapnousT1Config, SapnousT1ForCausalLM)
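

# The block below is a minimal smoke-test sketch, not part of the Sapnous-T1 API.
# It shows how the classes above fit together: build a config, instantiate the
# causal-LM head, and run one text-only forward pass with next-token labels.
# It assumes SapnousT1Config accepts these keyword arguments (all of them are
# attributes read elsewhere in this file) and that SapnousBlock/SapnousAttention
# work with such a tiny configuration; adjust the values to the real defaults.
# Because this module uses relative imports, run it as a module from the package
# root (python -m <package>.modeling_sapnous) rather than as a plain script.
if __name__ == "__main__":
    config = SapnousT1Config(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        max_position_embeddings=128,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_theta=10000.0,
        use_cache=True,
    )
    model = SapnousT1ForCausalLM(config)
    model.tie_weights()

    # Text-only forward pass; vision inputs (pixel_values) are omitted here.
    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print(outputs.loss.item(), tuple(outputs.logits.shape))  # logits: (1, 16, vocab_size)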