# coding=utf-8
# Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute the frequency tensor for complex rotation."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to the input tensor."""
    # View the last dimension as (real, imaginary) pairs.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Broadcast freqs_cis of shape (seq_len, head_dim // 2) over the batch and head dims.
    freqs_cis = freqs_cis.view(1, 1, *freqs_cis.shape)
    x_rotated = x_complex * freqs_cis
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)


class SapnousAttention(nn.Module):
    """Multi-head attention with rotary position embeddings and sliding window attention."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.sliding_window = config.sliding_window if config.use_sliding_window else None

        if (self.head_dim * self.num_attention_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_attention_heads "
                f"(got {self.hidden_size} and {self.num_attention_heads})"
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)

        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        return tensor.view(bsz, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2)

    def _kv_shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        return tensor.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = self._shape(query_states, q_len, bsz)
        key_states = self._kv_shape(key_states, q_len, bsz)
        value_states = self._kv_shape(value_states, q_len, bsz)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        # Apply rotary position embeddings to the query and key states at their absolute
        # positions (offset by the cached length when a KV cache is present).
        if position_ids is None:
            position_ids = torch.arange(kv_seq_len - q_len, kv_seq_len, device=hidden_states.device)
        freqs = freqs_cis[position_ids]
        query_states = apply_rotary_emb(query_states, freqs)
        key_states = apply_rotary_emb(key_states, freqs)

        if past_key_value is not None:
            # Reuse cached k, v for self-attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # Repeat k/v heads if n_kv_heads < n_heads (grouped-query attention)
        key_states = torch.repeat_interleave(key_states, self.num_key_value_groups, dim=1)
        value_states = torch.repeat_interleave(value_states, self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Sliding window attention if configured
        if self.sliding_window is not None and kv_seq_len > self.sliding_window:
            # Mask out keys outside a window centred on each query's absolute position
            window_mask = torch.ones_like(attn_weights, dtype=torch.bool)
            for i in range(q_len):
                pos = kv_seq_len - q_len + i
                window_start = max(0, pos - self.sliding_window // 2)
                window_end = min(kv_seq_len, pos + self.sliding_window // 2)
                window_mask[:, :, i, window_start:window_end] = False
            attn_weights = attn_weights.masked_fill(window_mask, float("-inf"))

        # Causal mask for autoregressive generation
        if self.config.scoring_func == "softmax":
            # The diagonal offset accounts for cached key positions preceding the query block.
            causal_mask = torch.triu(
                torch.ones((q_len, kv_seq_len), dtype=torch.bool),
                diagonal=kv_seq_len - q_len + 1,
            )
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
            attn_weights = attn_weights.masked_fill(causal_mask.to(attn_weights.device), float("-inf"))
            attn_weights = F.softmax(attn_weights, dim=-1)
        else:
            # Alternative scoring functions (e.g., RoPE-only, cosine similarity)
            attn_weights = F.relu(attn_weights)
            attn_weights = attn_weights / (attn_weights.sum(dim=-1, keepdim=True) + 1e-6)

        attn_weights = self.attention_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class SapnousBlock(nn.Module):
    """Transformer block with attention, layer norm, and feed-forward network."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = SapnousAttention(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
            nn.SiLU(),
            nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)

        return outputs


class SapnousVisionEmbeddings(nn.Module):
    """Vision embeddings for multimodal support."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        # Vision embedding layers (16x16 patches of a 224x224 RGB image plus a CLS token)
        self.patch_embed = nn.Conv2d(3, self.hidden_size, kernel_size=16, stride=16)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
        self.pos_embed = nn.Parameter(torch.zeros(1, (224 // 16) ** 2 + 1, self.hidden_size))

        # Layer normalization and dropout
        self.norm = nn.LayerNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.dropout = nn.Dropout(config.attention_dropout)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        B = pixel_values.shape[0]

        # Create patch embeddings
        x = self.patch_embed(pixel_values)
        x = x.flatten(2).transpose(1, 2)  # B, N, C

        # Add cls token and position embeddings
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        # Apply normalization and dropout
        x = self.norm(x)
        x = self.dropout(x)

        return x
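

# ---------------------------------------------------------------------------
# Minimal usage sketch (smoke test), not part of the model definition. The
# SimpleNamespace below is a hypothetical stand-in for the real Sapnous config
# class: its field names are inferred from the attributes accessed above, and
# the sizes are chosen only to keep the example small and fast on CPU.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=64,
        num_attention_heads=4,
        num_key_value_heads=2,
        max_position_embeddings=128,
        rope_theta=10000.0,
        use_sliding_window=False,
        sliding_window=None,
        attention_dropout=0.0,
        rms_norm_eps=1e-6,
        intermediate_size=256,
        scoring_func="softmax",
    )

    # Text path: run one transformer block over a short random sequence.
    block = SapnousBlock(config)
    hidden = torch.randn(2, 16, config.hidden_size)
    freqs_cis = precompute_freqs_cis(
        config.hidden_size // config.num_attention_heads,
        config.max_position_embeddings,
        config.rope_theta,
    )
    out = block(hidden, freqs_cis)
    print(out[0].shape)  # expected: torch.Size([2, 16, 64])

    # Vision path: embed a batch of 224x224 RGB images into patch tokens.
    vision = SapnousVisionEmbeddings(config)
    pixels = torch.randn(2, 3, 224, 224)
    print(vision(pixels).shape)  # expected: torch.Size([2, 197, 64])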