File size: 11,268 Bytes
5838aa1 2e2a204 5838aa1 2e2a204 5838aa1 2e2a204 5838aa1 2e2a204 5838aa1 2e2a204 5838aa1 2e2a204 5838aa1 2e2a204 5838aa1 2e2a204 5838aa1 2e2a204 5838aa1 2e2a204 5838aa1 2e2a204 5838aa1 2e2a204 5838aa1 2e2a204 5838aa1 2e2a204 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 |
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union
from transformers import PreTrainedModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from .configuration_sapnous import SapnousT1Config
from .attention_sapnous import SapnousAttention, SapnousBlock, SapnousVisionEmbeddings, precompute_freqs_cis
class SapnousT1PreTrainedModel(PreTrainedModel):
    """Base class for all Sapnous-T1 models.

    Connects ``SapnousT1Config`` to the Hugging Face ``PreTrainedModel`` machinery and defines how
    freshly constructed sub-modules are initialized.
    """

    config_class = SapnousT1Config
    base_model_prefix = "sapnous"

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        # PreTrainedModel.__init__ already stores the config; the explicit assignment is kept for clarity.
        self.config = config

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize ``module``'s parameters in place.

        Linear and embedding weights are drawn from N(0, std) with ``std = config.initializer_range``;
        linear biases are zeroed; LayerNorm is reset to weight 1 / bias 0 (when those parameters exist).

        Args:
            module: a sub-module handed over by the Hugging Face weight-initialization hooks
                (e.g. from ``post_init``).
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            # normal_ overwrote the zero row nn.Embedding keeps for its padding index; restore it.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) or LayerNorm(bias=False) leave these parameters as None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        elif isinstance(module, SapnousAttention):
            # Re-draw the four attention projection weights (same order as before: q, k, v, o).
            # NOTE(review): if these projections are nn.Linear they are also initialized by the Linear
            # branch when visited on their own -- confirm whether this branch is still needed.
            for projection in (module.q_proj, module.k_proj, module.v_proj, module.o_proj):
                projection.weight.data.normal_(mean=0.0, std=std)
class SapnousT1Model(SapnousT1PreTrainedModel):
    """Base Transformer model with advanced attention mechanisms and optional vision support.

    Pipeline: token embeddings -> ``config.num_hidden_layers`` ``SapnousBlock`` layers -> final LayerNorm.
    When the config carries a ``vision_config``, ``pixel_values`` are embedded by
    ``SapnousVisionEmbeddings`` and prepended to the token embeddings along the sequence axis.
    """

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([SapnousBlock(config) for _ in range(config.num_hidden_layers)])
        # NOTE(review): a LayerNorm configured with ``rms_norm_eps`` -- the name suggests an RMSNorm may
        # have been intended. Kept as LayerNorm so the parameter layout of existing checkpoints is unchanged.
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Vision support: only built when the config has a truthy ``vision_config``.
        self.vision_embed = SapnousVisionEmbeddings(config) if getattr(config, 'vision_config', None) else None
        # Initialize weights and apply final processing
        self.post_init()
        # Compute and cache RoPE frequencies for the per-head dimension. This is a plain attribute, not a
        # registered buffer, so ``model.to(device)`` does not move it; ``forward`` moves it on demand.
        self.freqs_cis = precompute_freqs_cis(
            self.config.hidden_size // self.config.num_attention_heads,
            self.config.max_position_embeddings,
            self.config.rope_theta,
        )

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding table."""
        return self.embeddings

    def set_input_embeddings(self, value: nn.Module):
        """Replace the token embedding table (e.g. after resizing the vocabulary)."""
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the embeddings, the decoder stack and the final norm.

        Args:
            input_ids: token ids, shape (batch, seq_len). Mutually exclusive with ``inputs_embeds``.
            attention_mask: 0/1 mask (1 = attend), reshaped to (batch, -1) and turned into an additive bias.
                If vision tokens are prepended in this call and the mask has exactly the text length, it is
                extended with ones for the vision tokens.
            position_ids: positions passed to every layer; defaults to ``arange(seq_len)`` over the full
                (vision + text) sequence.
            past_key_values: per-layer cache entries, handed to the matching ``SapnousBlock``.
            inputs_embeds: precomputed token embeddings, shape (batch, seq_len, ...).
            pixel_values: image input for ``SapnousVisionEmbeddings``; ignored when the model has no
                vision embedder.
            use_cache, output_attentions, output_hidden_states, return_dict: default to the config values.

        Returns:
            ``BaseModelOutputWithPast``, or (when ``return_dict`` is False) a tuple of its non-None fields in
            the order last_hidden_state, past_key_values, hidden_states, attentions.

        Raises:
            ValueError: if both, or neither, of ``input_ids`` and ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Process text input
        if input_ids is not None:
            inputs_embeds = self.embeddings(input_ids)
        batch_size, seq_length = inputs_embeds.shape[:2]
        text_length = seq_length

        # Process vision input if available: vision tokens go in front of the text tokens.
        if pixel_values is not None and self.vision_embed is not None:
            vision_embeds = self.vision_embed(pixel_values)
            inputs_embeds = torch.cat([vision_embeds, inputs_embeds], dim=1)
            seq_length = inputs_embeds.shape[1]

        if position_ids is None:
            # NOTE(review): positions always start at 0, even when past_key_values is given, so cached
            # decoding relies on explicit position_ids (prepare_inputs_for_generation derives them from the
            # attention mask). Confirm whether an offset by the cached length is wanted here.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        # Prepare attention mask: (batch, mask_len) 0/1 mask -> additive bias of shape (batch, 1, 1, mask_len).
        if attention_mask is not None:
            attention_mask = attention_mask.view(batch_size, -1)
            if seq_length > text_length and attention_mask.shape[-1] == text_length:
                # The mask only covers the text tokens, but vision tokens were prepended above: mark them
                # as attendable so the mask length matches the concatenated sequence.
                vision_mask = attention_mask.new_ones(batch_size, seq_length - text_length)
                attention_mask = torch.cat([vision_mask, attention_mask], dim=-1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=inputs_embeds.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(inputs_embeds.dtype).min

        if self.freqs_cis.device != inputs_embeds.device:
            # Keep the moved copy so the RoPE table is transferred once rather than on every call.
            self.freqs_cis = self.freqs_cis.to(inputs_embeds.device)
        freqs_cis = self.freqs_cis

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            past_key_value = past_key_values[idx] if past_key_values is not None else None
            layer_outputs = decoder_layer(
                hidden_states,
                freqs_cis=freqs_cis,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]
            # Layer outputs are (hidden, [attn_weights], [present_cache]); the cache index shifts by one
            # when attention weights are also returned.
            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [
                hidden_states,
                next_decoder_cache,
                all_hidden_states,
                all_self_attns,
            ] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class SapnousT1ForCausalLM(SapnousT1PreTrainedModel):
    """Sapnous-T1 model with a language-modeling head for causal LM, with optional vision input."""

    # lm_head.weight may be absent from checkpoints; it is normally tied to the input embeddings
    # (see tie_weights below).
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config: SapnousT1Config):
        super().__init__(config)
        self.model = SapnousT1Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the base model's token embedding table."""
        return self.model.embeddings

    def set_input_embeddings(self, value: nn.Module):
        """Replace the base model's token embedding table."""
        self.model.embeddings = value

    def get_output_embeddings(self) -> nn.Module:
        """Return the vocabulary projection (``lm_head``)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        """Replace the vocabulary projection (``lm_head``)."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        """Build the ``forward`` keyword arguments for one ``generate`` step.

        With a non-empty cache only the newest token is fed, and only its position id is kept.
        Position ids default to the running count of attended tokens in ``attention_mask`` minus one;
        padded slots get position 1 instead of -1 (the usual Hugging Face convention). Those slots are
        masked out by the attention mask anyway.

        ``pixel_values`` are forwarded only on the first (uncached) step. On later steps the image tokens
        were already processed into the cache, and sending the pixels again would make the base model
        prepend the vision tokens a second time.
        """
        has_cache = bool(past_key_values)
        if has_cache:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if position_ids is None and attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
        # Without an attention mask there may be no position ids at all; let the model build defaults.
        if has_cache and position_ids is not None:
            position_ids = position_ids[:, -1].unsqueeze(-1)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "pixel_values": None if has_cache else kwargs.get("pixel_values", None),
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""Run the base model and project its hidden states to vocabulary logits.

        Args:
            labels: target token ids, shape (batch, seq_len). When given, the causal (next-token)
                cross-entropy loss is computed: the logit at position t is scored against ``labels[t + 1]``.
                ``nn.CrossEntropyLoss`` defaults apply (``ignore_index=-100``, mean reduction). If the base
                model prepended vision tokens, so there are more logit positions than labels, only the
                trailing ``labels.size(-1)`` positions are scored.
            All other arguments are forwarded unchanged to ``SapnousT1Model.forward``.

        Returns:
            ``CausalLMOutputWithPast``. When ``return_dict`` is False, the tuple
            ``(loss, logits, *model_outputs[1:])`` is returned instead, with ``loss`` omitted if no labels
            were given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # The base model puts vision tokens (if any) in front of the text, so logits can be longer than
            # labels. Align on the trailing (text) positions.
            num_prefix = logits.size(-2) - labels.size(-1)
            logits_for_loss = logits[..., num_prefix:, :] if num_prefix > 0 else logits
            shift_logits = logits_for_loss[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def tie_weights(self):
        """Tie the weights between the input embeddings and the output embeddings.

        ``lm_head`` shares the embedding matrix object itself, so the two always stay identical.
        """
        # NOTE(review): this unconditionally ties and ignores ``config.tie_word_embeddings``. It is left
        # as-is because checkpoints saved with tied weights may omit lm_head.weight entirely.
        self.lm_head.weight = self.model.embeddings.weight
# Map SapnousT1Config to SapnousT1ForCausalLM in the AutoModelForCausalLM registry at import time.
AutoModelForCausalLM.register(config_class=SapnousT1Config, model_class=SapnousT1ForCausalLM)
|