import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute the complex rotation tensor (cos + i*sin) for rotary position embeddings."""
    # Inverse frequencies theta^(-2k/dim) for each pair of dimensions.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    # Outer product of positions and frequencies, shape (end, dim // 2).
    freqs = torch.outer(t, freqs)
    # polar(1, angle) = cos(angle) + i * sin(angle).
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to a (bsz, n_heads, seq_len, head_dim) tensor."""
    # Pair up the last dimension and view it as complex numbers: (..., head_dim) -> (..., head_dim // 2).
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Broadcast freqs_cis of shape (seq_len, head_dim // 2) over the batch and head dimensions.
    freqs_cis = freqs_cis.view(1, 1, *freqs_cis.shape)
    x_rotated = x_complex * freqs_cis
    # Back to interleaved real pairs, restoring the input dtype.
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)
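
# A minimal sketch of how the two helpers above combine (illustrative only; the shapes
# mirror how SapnousAttention calls them below):
#
#   freqs_cis = precompute_freqs_cis(dim=64, end=128)   # complex tensor, shape (128, 32)
#   q = torch.randn(2, 8, 16, 64)                       # (bsz, n_heads, seq_len, head_dim)
#   q_rot = apply_rotary_emb(q, freqs_cis[:16])         # same shape as q, positions 0..15
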
class SapnousAttention(nn.Module):
    """Multi-head attention with rotary position embeddings and sliding window attention."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.sliding_window = config.sliding_window if config.use_sliding_window else None

        if (self.head_dim * self.num_attention_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_attention_heads "
                f"(got {self.hidden_size} and {self.num_attention_heads})"
            )

        # Grouped-query attention: queries use num_attention_heads heads; keys/values use
        # num_key_value_heads heads and are repeated to match at attention time.
        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)

        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        # (bsz, seq_len, n_heads * head_dim) -> (bsz, n_heads, seq_len, head_dim)
        return tensor.view(bsz, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2)

    def _kv_shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        # Same reshape for the (possibly fewer) key/value heads.
        return tensor.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = self._shape(query_states, q_len, bsz)
        key_states = self._kv_shape(key_states, q_len, bsz)
        value_states = self._kv_shape(value_states, q_len, bsz)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        # Rotary embeddings are applied to the new query/key positions only; with a cache,
        # those are the last q_len absolute positions. Assumes 1-D position_ids of length q_len.
        if position_ids is None:
            position_ids = torch.arange(kv_seq_len - q_len, kv_seq_len, device=hidden_states.device)
        freqs = freqs_cis.to(hidden_states.device)[position_ids]
        query_states = apply_rotary_emb(query_states, freqs)
        key_states = apply_rotary_emb(key_states, freqs)

        if past_key_value is not None:
            # Prepend the cached keys/values along the sequence dimension.
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # Grouped-query attention: repeat each key/value head to match the query heads.
        key_states = torch.repeat_interleave(key_states, self.num_key_value_groups, dim=1)
        value_states = torch.repeat_interleave(value_states, self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Sliding window attention: each query may only attend to keys within a window of
        # size sliding_window centered on its absolute position.
        if self.sliding_window is not None and kv_seq_len > self.sliding_window:
            window_mask = torch.ones_like(attn_weights, dtype=torch.bool)
            for i in range(q_len):
                pos = kv_seq_len - q_len + i  # absolute position of query i
                window_start = max(0, pos - self.sliding_window // 2)
                window_end = min(kv_seq_len, pos + self.sliding_window // 2)
                window_mask[:, :, i, window_start:window_end] = False
            attn_weights = attn_weights.masked_fill(window_mask, float('-inf'))

        if self.config.scoring_func == "softmax":
            # Causal mask, offset by the cache length so query i only attends to keys at or
            # before its absolute position.
            causal_mask = torch.triu(
                torch.ones((q_len, kv_seq_len), dtype=torch.bool, device=attn_weights.device),
                diagonal=kv_seq_len - q_len + 1,
            )
            attn_weights = attn_weights.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
            attn_weights = F.softmax(attn_weights, dim=-1)
        else:
            # Alternative scoring: ReLU attention, normalized to sum to one.
            attn_weights = F.relu(attn_weights)
            attn_weights = attn_weights / (attn_weights.sum(dim=-1, keepdim=True) + 1e-6)

        attn_weights = self.attention_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value_states)

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
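
    # Incremental decoding sketch (illustrative; `attn` is a SapnousAttention instance and
    # `freqs_cis` comes from precompute_freqs_cis with end >= the total sequence length):
    #
    #   out, _, cache = attn(prompt_states, freqs_cis, use_cache=True)      # prefill
    #   out, _, cache = attn(next_states, freqs_cis,                        # one new token,
    #                        past_key_value=cache, use_cache=True)          # reusing cached K/V
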
class SapnousBlock(nn.Module):
    """Transformer block: self-attention and a SiLU feed-forward network, each behind a pre-norm residual."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = SapnousAttention(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Feed-forward network: hidden -> intermediate -> hidden with SiLU activation.
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
            nn.SiLU(),
            nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Self-attention with a pre-norm residual connection.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Feed-forward network with a pre-norm residual connection.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
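
    # How a decoder would typically chain these blocks (a sketch; the full model class is
    # not part of this file, so the loop below is an assumption about the caller):
    #
    #   for layer in self.layers:
    #       hidden_states = layer(hidden_states, freqs_cis, attention_mask=attention_mask)[0]
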
class SapnousVisionEmbeddings(nn.Module):
    """Vision embeddings for multimodal support (ViT-style patch embedding)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        # 16x16 patch embedding for 224x224 RGB inputs: (224 / 16)^2 = 196 patches plus a CLS token.
        self.patch_embed = nn.Conv2d(3, self.hidden_size, kernel_size=16, stride=16)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
        self.pos_embed = nn.Parameter(torch.zeros(1, (224 // 16) ** 2 + 1, self.hidden_size))

        self.norm = nn.LayerNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.dropout = nn.Dropout(config.attention_dropout)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        B = pixel_values.shape[0]

        # (B, 3, 224, 224) -> (B, hidden_size, 14, 14) -> (B, 196, hidden_size)
        x = self.patch_embed(pixel_values)
        x = x.flatten(2).transpose(1, 2)

        # Prepend the CLS token and add learned position embeddings.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        x = self.norm(x)
        x = self.dropout(x)

        return x
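

if __name__ == "__main__":
    # Quick smoke test (illustrative only). The config object below is a stand-in built from
    # SimpleNamespace; the real project presumably ships its own config class, and every field
    # value here is an assumption chosen just to make the shapes line up.
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=256,
        num_attention_heads=8,
        num_key_value_heads=4,
        intermediate_size=512,
        max_position_embeddings=128,
        rope_theta=10000.0,
        use_sliding_window=False,
        sliding_window=None,
        attention_dropout=0.0,
        rms_norm_eps=1e-6,
        scoring_func="softmax",
    )

    block = SapnousBlock(config)
    freqs_cis = precompute_freqs_cis(
        config.hidden_size // config.num_attention_heads, config.max_position_embeddings
    )

    hidden = torch.randn(2, 16, config.hidden_size)
    out = block(hidden, freqs_cis)
    print(out[0].shape)  # expected: torch.Size([2, 16, 256])

    vision = SapnousVisionEmbeddings(config)
    tokens = vision(torch.randn(2, 3, 224, 224))
    print(tokens.shape)  # expected: torch.Size([2, 197, 256])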