# modeling_sapnous.py: Sapnous-T1 model implementation for Sapnous-VR-6B.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union
from transformers import PreTrainedModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from .configuration_sapnous import SapnousT1Config
from .attention_sapnous import SapnousAttention, SapnousBlock, SapnousVisionEmbeddings, precompute_freqs_cis

class SapnousT1PreTrainedModel(PreTrainedModel):
    """Base class for all Sapnous-T1 models."""
    config_class = SapnousT1Config
    base_model_prefix = "sapnous"

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.config = config

    def _init_weights(self, module):
        """Initialize weights using the model's initialization configuration."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, SapnousAttention):
            module.q_proj.weight.data.normal_(mean=0.0, std=std)
            module.k_proj.weight.data.normal_(mean=0.0, std=std)
            module.v_proj.weight.data.normal_(mean=0.0, std=std)
            module.o_proj.weight.data.normal_(mean=0.0, std=std)


class SapnousT1Model(SapnousT1PreTrainedModel):
    """Base transformer model with advanced attention mechanisms and optional vision support."""

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([SapnousBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Vision support: only instantiated when a vision config is present.
        self.vision_embed = SapnousVisionEmbeddings(config) if getattr(config, 'vision_config', None) else None

        # Initialize weights and apply final processing
        self.post_init()

        # Compute and cache RoPE frequencies
        self.freqs_cis = precompute_freqs_cis(
            self.config.hidden_size // self.config.num_attention_heads,
            self.config.max_position_embeddings,
            self.config.rope_theta,
        )
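        # For reference, this cache is assumed to follow the conventional RoPE
        # construction (the exact layout is defined in
        # attention_sapnous.precompute_freqs_cis):
        #   inv_freq_i = rope_theta ** (-2 * i / head_dim),  i = 0 .. head_dim // 2 - 1
        #   freqs_cis[p, i] = exp(1j * p * inv_freq_i)       for each position p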

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings

    def set_input_embeddings(self, value: nn.Module):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")

        # Process text input
        if input_ids is not None:
            inputs_embeds = self.embeddings(input_ids)
            batch_size, seq_length = input_ids.shape[:2]
        else:
            batch_size, seq_length = inputs_embeds.shape[:2]

        # Process vision input if available; vision tokens are prepended to the
        # text sequence.
        if pixel_values is not None and self.vision_embed is not None:
            vision_embeds = self.vision_embed(pixel_values)
            inputs_embeds = torch.cat([vision_embeds, inputs_embeds], dim=1)
            seq_length = inputs_embeds.shape[1]

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            # Offset positions by the cached prefix length so incremental
            # decoding uses the correct rotary phases (this assumes a per-layer
            # cache of shape [batch, num_heads, past_len, head_dim]).
            past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
            position_ids = torch.arange(past_length, past_length + seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)
        # Prepare attention mask: convert a [batch, seq] padding mask into an
        # additive [batch, 1, 1, seq] bias.
        if attention_mask is not None:
            attention_mask = attention_mask.view(batch_size, -1)
            # If vision tokens were prepended, extend the mask so they are
            # always attended to.
            if attention_mask.shape[-1] < seq_length:
                pad = torch.ones(
                    batch_size,
                    seq_length - attention_mask.shape[-1],
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                attention_mask = torch.cat([pad, attention_mask], dim=-1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=inputs_embeds.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(inputs_embeds.dtype).min
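        # Worked example: a padding mask [1, 1, 0] becomes [0.0, 0.0, min] after
        # the transform above, where min = torch.finfo(dtype).min, i.e. an
        # additive bias that pushes attention scores for padded slots to -inf.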

        freqs_cis = self.freqs_cis.to(inputs_embeds.device)
        hidden_states = inputs_embeds

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                freqs_cis=freqs_cis,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]

            # Each layer returns (hidden_states, [attn_weights,] present_key_value),
            # so the cache index depends on whether attentions are also returned.
            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [
                hidden_states,
                next_decoder_cache,
                all_hidden_states,
                all_self_attns,
            ] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
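
# Minimal usage sketch for the base model (a hypothetical tiny config; only the
# config field names are taken from the code above):
#
#   config = SapnousT1Config(vocab_size=1000, hidden_size=64,
#                            num_hidden_layers=2, num_attention_heads=4)
#   model = SapnousT1Model(config)
#   out = model(input_ids=torch.randint(0, 1000, (1, 8)), return_dict=True)
#   out.last_hidden_state.shape  # -> torch.Size([1, 8, 64])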


class SapnousT1ForCausalLM(SapnousT1PreTrainedModel):
    """Sapnous-T1 model for causal language modeling with vision support."""
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.model = SapnousT1Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.embeddings

    def set_input_embeddings(self, value: nn.Module):
        self.model.embeddings = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        # With a populated cache, only the newest token needs to be re-encoded.
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if position_ids is None and attention_mask is not None:
            # Derive positions from the padding mask; padded slots get a dummy
            # position so the indices never go negative.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values and position_ids is not None:
            position_ids = position_ids[:, -1].unsqueeze(-1)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "pixel_values": kwargs.get("pixel_values", None),
        }
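
    # Decode-step example: with a populated cache and input_ids [[5, 7, 9]],
    # only [[9]] is fed forward, with position_ids [[2]] (assuming no padding).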

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
r"""Labels for computing the masked language modeling loss."""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
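
    # Label-shift example: for tokens [t0, t1, t2, t3], the logits at positions
    # 0..2 are scored against labels t1..t3, so every step predicts the next
    # token. Positions labelled -100 are ignored by CrossEntropyLoss's default
    # ignore_index.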

    def tie_weights(self):
        """Tie the weights between the input embeddings and the output embeddings."""
        self.lm_head.weight = self.model.embeddings.weight

# Register the model
AutoModelForCausalLM.register(SapnousT1Config, SapnousT1ForCausalLM)
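

if __name__ == "__main__":
    # Smoke-test sketch, not part of the library API. It assumes SapnousT1Config
    # provides defaults for any field not set here (e.g. rope_theta,
    # rms_norm_eps, initializer_range); the values below are deliberately tiny
    # and hypothetical.
    config = SapnousT1Config(
        vocab_size=128,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        max_position_embeddings=64,
    )
    model = SapnousT1ForCausalLM(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=input_ids, labels=input_ids, return_dict=True)
    print("logits:", tuple(out.logits.shape), "loss:", float(out.loss))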