Sapnous-VR-6B / attention_sapnous.py
# coding=utf-8
# Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
"""Precompute the frequency tensor for complex rotation."""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs)
return torch.polar(torch.ones_like(freqs), freqs)
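
# Example (illustrative only): with head_dim=128 and a 4096-token context,
#   freqs = precompute_freqs_cis(dim=128, end=4096)
# returns a complex64 tensor of shape (4096, 64) holding exp(i * t * theta_k)
# for each position t and frequency index k.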
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to a (bsz, n_heads, seq_len, head_dim) tensor."""
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # freqs_cis: (seq_len, head_dim // 2); add leading dims so it broadcasts over batch and heads.
    freqs_cis = freqs_cis.view(1, 1, *freqs_cis.shape)
    x_rotated = x_complex * freqs_cis
    # Cast back to the input dtype so fp16/bf16 activations are not silently upcast.
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)
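
# Example (illustrative only): queries of shape (bsz, n_heads, seq_len, head_dim)
# are rotated pairwise in the complex plane; shape and dtype are preserved.
#   q = torch.randn(1, 8, 16, 64)
#   q_rot = apply_rotary_emb(q, precompute_freqs_cis(64, 16))   # (1, 8, 16, 64)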
class SapnousAttention(nn.Module):
"""Multi-head attention with rotary position embeddings and sliding window attention."""
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.sliding_window = config.sliding_window if config.use_sliding_window else None
if (self.head_dim * self.num_attention_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_attention_heads (got {self.hidden_size} and {self.num_attention_heads})"
)
self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)
self.attention_dropout = nn.Dropout(config.attention_dropout)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
return tensor.view(bsz, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
def _kv_shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
return tensor.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
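    # Shape conventions (for reference): _shape yields queries as
    # (bsz, num_attention_heads, seq_len, head_dim), while _kv_shape yields keys and
    # values as (bsz, num_key_value_heads, seq_len, head_dim); forward() later repeats
    # K/V along the head dimension so every query head has a matching K/V head.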
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self._shape(query_states, q_len, bsz)
key_states = self._kv_shape(key_states, q_len, bsz)
value_states = self._kv_shape(value_states, q_len, bsz)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
        # Apply rotary position embeddings. With a KV cache, the current tokens
        # occupy the last q_len positions of the kv_seq_len-long sequence.
        if position_ids is None:
            position_ids = torch.arange(kv_seq_len - q_len, kv_seq_len, device=hidden_states.device)
        freqs = freqs_cis.to(hidden_states.device)[position_ids]
        query_states = apply_rotary_emb(query_states, freqs)
        key_states = apply_rotary_emb(key_states, freqs)
if past_key_value is not None:
# Reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# Repeat k/v heads if n_kv_heads < n_heads
key_states = torch.repeat_interleave(key_states, self.num_key_value_groups, dim=1)
value_states = torch.repeat_interleave(value_states, self.num_key_value_groups, dim=1)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
        # Sliding window attention if configured: each query attends only to keys
        # within half a window of its absolute position (vectorized; absolute
        # positions account for any cached keys).
        if self.sliding_window is not None and kv_seq_len > self.sliding_window:
            query_positions = torch.arange(kv_seq_len - q_len, kv_seq_len, device=attn_weights.device)
            key_positions = torch.arange(kv_seq_len, device=attn_weights.device)
            distance = (key_positions.unsqueeze(0) - query_positions.unsqueeze(1)).abs()
            window_mask = distance > self.sliding_window // 2  # (q_len, kv_seq_len)
            attn_weights = attn_weights.masked_fill(window_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        # Causal mask for autoregressive generation; the diagonal offset accounts
        # for cached key/value positions when q_len < kv_seq_len.
        if self.config.scoring_func == "softmax":
            causal_mask = torch.triu(
                torch.ones((q_len, kv_seq_len), dtype=torch.bool, device=attn_weights.device),
                diagonal=kv_seq_len - q_len + 1,
            )
            attn_weights = attn_weights.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
            attn_weights = F.softmax(attn_weights, dim=-1)
else:
# Alternative scoring functions (e.g., RoPE-only, cosine similarity)
attn_weights = F.relu(attn_weights)
attn_weights = attn_weights / (attn_weights.sum(dim=-1, keepdim=True) + 1e-6)
attn_weights = self.attention_dropout(attn_weights)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
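
# Illustrative KV-cache usage for SapnousAttention (a sketch only; `cfg` stands in
# for the real SapnousConfig and `prompt_states` / `next_token_states` are assumed
# hidden-state tensors of shape (bsz, seq_len, hidden_size)):
#   attn = SapnousAttention(cfg)
#   freqs = precompute_freqs_cis(cfg.hidden_size // cfg.num_attention_heads,
#                                cfg.max_position_embeddings)
#   out, _, cache = attn(prompt_states, freqs, use_cache=True)              # prefill
#   out, _, cache = attn(next_token_states, freqs, past_key_value=cache,
#                        use_cache=True)                                    # decode step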
class SapnousBlock(nn.Module):
"""Transformer block with attention, layer norm, and feed-forward network."""
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = SapnousAttention(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = nn.Sequential(
nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
nn.SiLU(),
nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
)
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
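
# SapnousBlock packs its outputs positionally: (hidden_states,), then the attention
# weights when output_attentions=True, then the (key, value) cache tuple when
# use_cache=True.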
class SapnousVisionEmbeddings(nn.Module):
"""Vision embeddings for multimodal support."""
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
        # Vision embedding layers: 16x16 patches over a 224x224 input
        # (14 * 14 = 196 patch tokens plus one CLS token), all hard-coded below.
        self.patch_embed = nn.Conv2d(3, self.hidden_size, kernel_size=16, stride=16)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
        self.pos_embed = nn.Parameter(torch.zeros(1, (224 // 16) ** 2 + 1, self.hidden_size))
# Layer normalization and dropout
self.norm = nn.LayerNorm(self.hidden_size, eps=config.rms_norm_eps)
self.dropout = nn.Dropout(config.attention_dropout)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
B = pixel_values.shape[0]
# Create patch embeddings
x = self.patch_embed(pixel_values)
x = x.flatten(2).transpose(1, 2) # B, N, C
# Add cls token and position embeddings
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
# Apply normalization and dropout
x = self.norm(x)
x = self.dropout(x)
return x
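
# Minimal smoke test (a sketch only): `types.SimpleNamespace` stands in for the real
# SapnousConfig, and every field value below is an illustrative assumption.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        hidden_size=256,
        num_attention_heads=8,
        num_key_value_heads=4,
        max_position_embeddings=128,
        rope_theta=10000.0,
        use_sliding_window=False,
        sliding_window=None,
        attention_dropout=0.0,
        rms_norm_eps=1e-6,
        intermediate_size=512,
        scoring_func="softmax",
    )
    block = SapnousBlock(cfg)
    freqs = precompute_freqs_cis(cfg.hidden_size // cfg.num_attention_heads,
                                 cfg.max_position_embeddings)
    hidden = torch.randn(2, 16, cfg.hidden_size)
    (out,) = block(hidden, freqs)
    print("block output:", tuple(out.shape))              # (2, 16, 256)

    vision = SapnousVisionEmbeddings(cfg)
    pixels = torch.randn(2, 3, 224, 224)
    print("vision tokens:", tuple(vision(pixels).shape))  # (2, 197, 256)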