import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union
from transformers import PreTrainedModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast

from .configuration_sapnous import SapnousT1Config
from .attention_sapnous import SapnousAttention, SapnousBlock, SapnousVisionEmbeddings, precompute_freqs_cis


class SapnousT1PreTrainedModel(PreTrainedModel):
    """Base class for all Sapnous-T1 models."""

    config_class = SapnousT1Config
    base_model_prefix = "sapnous"

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.config = config

    def _init_weights(self, module):
        """Initialize weights using the model's initialization configuration."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, SapnousAttention):
            module.q_proj.weight.data.normal_(mean=0.0, std=std)
            module.k_proj.weight.data.normal_(mean=0.0, std=std)
            module.v_proj.weight.data.normal_(mean=0.0, std=std)
            module.o_proj.weight.data.normal_(mean=0.0, std=std)


class SapnousT1Model(SapnousT1PreTrainedModel):
    """Base Transformer model with advanced attention mechanisms and optional vision support."""

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([SapnousBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Vision support
        self.vision_embed = SapnousVisionEmbeddings(config) if getattr(config, "vision_config", None) else None

        # Initialize weights and apply final processing
        self.post_init()

        # Compute and cache RoPE frequencies
        self.freqs_cis = precompute_freqs_cis(
            self.config.hidden_size // self.config.num_attention_heads,
            self.config.max_position_embeddings,
            self.config.rope_theta,
        )

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings

    def set_input_embeddings(self, value: nn.Module):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")

        # Process text input
        if input_ids is not None:
            inputs_embeds = self.embeddings(input_ids)
            batch_size, seq_length = input_ids.shape[:2]
        else:
            batch_size, seq_length = inputs_embeds.shape[:2]

        # Process vision input if available: vision embeddings are prepended to the text sequence
        if pixel_values is not None and self.vision_embed is not None:
            vision_embeds = self.vision_embed(pixel_values)
            inputs_embeds = torch.cat([vision_embeds, inputs_embeds], dim=1)
            seq_length = inputs_embeds.shape[1]
            # Extend the padding mask so it also covers the prepended vision tokens
            if attention_mask is not None:
                vision_mask = torch.ones(
                    batch_size, vision_embeds.shape[1], dtype=attention_mask.dtype, device=attention_mask.device
                )
                attention_mask = torch.cat([vision_mask, attention_mask], dim=1)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        # Prepare attention mask: expand the 2D padding mask to an additive 4D mask
        if attention_mask is not None:
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=inputs_embeds.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(inputs_embeds.dtype).min

        freqs_cis = self.freqs_cis.to(inputs_embeds.device)

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                freqs_cis=freqs_cis,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_decoder_cache, all_hidden_states, all_self_attns]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class SapnousT1ForCausalLM(SapnousT1PreTrainedModel):
    """Sapnous-T1 Model for Causal Language Modeling with vision support."""

    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.model = SapnousT1Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.embeddings

    def set_input_embeddings(self, value: nn.Module):
        self.model.embeddings = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        # With a cache, only the newest token needs to be fed to the model
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if position_ids is None:
            position_ids = (attention_mask.long().cumsum(-1) - 1) if attention_mask is not None else None
            if past_key_values and position_ids is not None:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "pixel_values": kwargs.get("pixel_values", None),
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""Labels for computing the causal language modeling (next-token prediction) loss."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def tie_weights(self):
        """Tie the weights between the input embeddings and the output embeddings."""
        self.lm_head.weight = self.model.embeddings.weight


# Register the model
AutoModelForCausalLM.register(SapnousT1Config, SapnousT1ForCausalLM)
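

# Hedged usage sketch (illustrative, not part of the library): shows how the classes
# defined above would typically be exercised for a text-only forward pass. It assumes
# SapnousT1Config() can be constructed with its defaults; the batch and sequence sizes
# below are arbitrary example values.
if __name__ == "__main__":
    config = SapnousT1Config()  # assumption: defaults describe a valid model
    model = SapnousT1ForCausalLM(config)
    model.eval()

    # Dummy batch: one sequence of 8 token ids drawn uniformly from the vocabulary
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)

    # Logits have shape (batch_size, sequence_length, vocab_size)
    print(out.logits.shape)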