# Copyright (c) 2024 The HuggingFace Inc. team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
#
# This modified file is released under the same license.

import torch
from torch import nn

from transformers.activations import ACT2FN

from modeling.siglip.configuration_siglip import SiglipVisionConfig as _SiglipVisionConfig
from modeling.siglip.modeling_siglip import SiglipAttention, SiglipPreTrainedModel

from flash_attn import flash_attn_varlen_func


class SiglipVisionConfig(_SiglipVisionConfig):
    r"""
    This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
    Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input images.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rope (`bool`, *optional*, defaults to `True`):
            Whether to use 2D rotary position embeddings instead of learned absolute position embeddings.
    Example:

    ```python
    >>> from transformers import SiglipVisionConfig, SiglipVisionModel

    >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipVisionConfig()

    >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
    >>> model = SiglipVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "siglip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        rope=True,
        **kwargs,
    ):
        super().__init__(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_channels=num_channels,
            image_size=image_size,
            patch_size=patch_size,
            hidden_act=hidden_act,
            layer_norm_eps=layer_norm_eps,
            attention_dropout=attention_dropout,
            **kwargs,
        )
        self.rope = rope


class RotaryEmbedding2D(torch.nn.Module):
    def __init__(self, dim, max_h, max_w, base=10000):
        super().__init__()
        freq = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
        inv_freq = 1.0 / (base ** freq)

        grid_h = torch.arange(0, max_h)
        grid_h = grid_h.to(inv_freq.dtype)
        grid_h = grid_h[:, None].repeat(1, max_w)

        grid_w = torch.arange(0, max_w)
        grid_w = grid_w.to(inv_freq.dtype)
        grid_w = grid_w[None, :].repeat(max_h, 1)

        # Precompute cos/sin tables for the height and width axes; each table is flattened to
        # shape (max_h * max_w, dim) so it can be indexed directly with flattened position ids.
        cos_h, sin_h = self._forward_one_side(grid_h, inv_freq)
        cos_w, sin_w = self._forward_one_side(grid_w, inv_freq)

        self.register_buffer("cos_h", cos_h)
        self.register_buffer("sin_h", sin_h)
        self.register_buffer("cos_w", cos_w)
        self.register_buffer("sin_w", sin_w)

    def _forward_one_side(self, grid, inv_freq):
        freqs = grid[..., None] * inv_freq[None, None, :]
        emb = torch.cat((freqs, freqs), dim=-1).flatten(0, 1)
        return emb.cos(), emb.sin()


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    # unsqueeze due to the head dimension
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        if not config.rope:
            self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def convert_conv2d_to_linear(self, config, meta=False):
        # Replace the Conv2d patch embedding with an equivalent Linear layer so that pre-extracted,
        # flattened patches of shape (num_patches, num_channels * patch_size**2) can be embedded directly.
        if meta:
            linear_patch_embedding = nn.Linear(
                config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True, device='meta'
            )
        else:
            linear_patch_embedding = nn.Linear(
                config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True
            )
        W = self.patch_embedding.weight.permute(0, 2, 3, 1).reshape(
            self.embed_dim, config.num_channels * self.patch_size ** 2
        )
        linear_patch_embedding.weight.data = W
        linear_patch_embedding.bias.data = self.patch_embedding.bias.data
        del self.patch_embedding
        self.patch_embedding = linear_patch_embedding
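
    # NOTE (illustrative sketch, not from the original file): after convert_conv2d_to_linear, each row of
    # packed_pixel_values is expected to be one patch flattened in (patch_row, patch_col, channel) order,
    # matching the permute(0, 2, 3, 1) applied to the Conv2d weight above. The data pipeline that produces
    # packed_pixel_values is not part of this module; one hedged way such rows could be built from a single
    # image tensor `img` of shape (C, H, W) is:
    #
    #     P = self.patch_size
    #     patches = img.unfold(1, P, P).unfold(2, P, P)   # (C, H//P, W//P, P, P)
    #     patches = patches.permute(1, 2, 3, 4, 0)        # (H//P, W//P, P, P, C)
    #     packed = patches.reshape(-1, P * P * img.shape[0])
    #
    # Feeding `packed` through the Linear patch embedding then matches the original Conv2d output
    # up to numerical precision.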

    def forward(
        self,
        packed_pixel_values: torch.FloatTensor,
        packed_flattened_position_ids: torch.LongTensor,
    ) -> torch.Tensor:
        patch_embeds = self.patch_embedding(packed_pixel_values)
        if not self.config.rope:
            embeddings = patch_embeds + self.position_embedding(packed_flattened_position_ids)
        else:
            embeddings = patch_embeds
        return embeddings


class SiglipFlashAttention2(SiglipAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        total_q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(total_q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(total_q_len, self.num_heads, self.head_dim)
        value_states = value_states.view(total_q_len, self.num_heads, self.head_dim)

        if self.config.rope:
            # 2D rotary embedding: the first half of each head rotates with the height position,
            # the second half with the width position.
            qh, qw = query_states[:, :, :self.head_dim // 2], query_states[:, :, self.head_dim // 2:]
            kh, kw = key_states[:, :, :self.head_dim // 2], key_states[:, :, self.head_dim // 2:]
            qh, kh = apply_rotary_pos_emb(qh, kh, cos_h, sin_h)
            qw, kw = apply_rotary_pos_emb(qw, kw, cos_w, sin_w)
            query_states = torch.cat([qh, qw], dim=-1)
            key_states = torch.cat([kh, kw], dim=-1)

        attn_output = flash_attn_varlen_func(
            query_states.to(torch.bfloat16),
            key_states.to(torch.bfloat16),
            value_states.to(torch.bfloat16),
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=False,
        )

        attn_output = self.out_proj(attn_output.reshape(total_q_len, -1))
        return attn_output


class SiglipMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = SiglipFlashAttention2(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
    ) -> torch.Tensor:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            cos_h=cos_h,
            sin_h=sin_h,
            cos_w=cos_w,
            sin_w=sin_w,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
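

# NOTE (illustrative sketch, not from the original file): the encoder operates on one packed sequence of
# patches from several images, following the flash-attn "varlen" convention: cu_seqlens is an int32 tensor
# of cumulative per-image patch counts with a leading zero, and max_seqlen is the largest per-image patch
# count, so attention never crosses image boundaries. A hedged example for two images with n1 and n2 patches:
#
#     seqlens = torch.tensor([n1, n2])
#     cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)  # [0, n1, n1 + n2]
#     max_seqlen = int(seqlens.max())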


class SiglipEncoder(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states, cu_seqlens, max_seqlen,
                cos_h=cos_h, sin_h=sin_h, cos_w=cos_w, sin_w=sin_w,
            )
        return hidden_states


class SiglipVisionTransformer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        if config.rope:
            max_size = config.image_size // config.patch_size
            dim_head = config.hidden_size // config.num_attention_heads
            self.rope = RotaryEmbedding2D(dim_head // 2, max_size, max_size)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        packed_pixel_values: torch.Tensor,
        packed_flattened_position_ids: torch.LongTensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            packed_pixel_values=packed_pixel_values,
            packed_flattened_position_ids=packed_flattened_position_ids,
        )

        extra_inputs = {}
        if self.config.rope:
            extra_inputs.update(
                cos_h=self.rope.cos_h[packed_flattened_position_ids],
                sin_h=self.rope.sin_h[packed_flattened_position_ids],
                cos_w=self.rope.cos_w[packed_flattened_position_ids],
                sin_w=self.rope.sin_w[packed_flattened_position_ids],
            )

        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, **extra_inputs
        )
        last_hidden_state = self.post_layernorm(last_hidden_state)

        return last_hidden_state


class SiglipVisionModel(SiglipPreTrainedModel):
    config_class = SiglipVisionConfig
    main_input_name = "packed_pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)

        self.vision_model = SiglipVisionTransformer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        packed_pixel_values: torch.Tensor,
        packed_flattened_position_ids: torch.LongTensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        return self.vision_model(
            packed_pixel_values=packed_pixel_values,
            packed_flattened_position_ids=packed_flattened_position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
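

if __name__ == "__main__":
    # Hedged smoke-test sketch, not part of the original module: it assumes the intended usage is to
    # convert the Conv2d patch embedding to its Linear form and then feed pre-flattened patches from
    # differently sized images as a single packed sequence. The small config values below are
    # illustrative only. Requires a CUDA device, since flash_attn_varlen_func only runs on GPU in bf16/fp16.
    assert torch.cuda.is_available(), "flash-attn kernels require a CUDA device"

    config = SiglipVisionConfig(
        hidden_size=256, intermediate_size=512, num_hidden_layers=2,
        num_attention_heads=4, image_size=224, patch_size=16,
    )
    model = SiglipVisionModel(config).cuda().to(torch.bfloat16).eval()
    model.vision_model.embeddings.convert_conv2d_to_linear(config)

    patch_dim = config.num_channels * config.patch_size ** 2
    patches_per_side = config.image_size // config.patch_size  # 14

    # Two images: one at full resolution (14x14 patches) and one smaller (7x7 patches).
    sides = [14, 7]
    pixel_chunks, pos_chunks, seqlens = [], [], []
    for side in sides:
        pixel_chunks.append(torch.randn(side * side, patch_dim, device="cuda", dtype=torch.bfloat16))
        rows = torch.arange(side).repeat_interleave(side)
        cols = torch.arange(side).repeat(side)
        pos_chunks.append(rows * patches_per_side + cols)  # indices into the (max_h * max_w) RoPE tables
        seqlens.append(side * side)

    packed_pixel_values = torch.cat(pixel_chunks)
    packed_flattened_position_ids = torch.cat(pos_chunks).cuda()
    seqlens = torch.tensor(seqlens)
    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32).cuda()
    max_seqlen = int(seqlens.max())

    with torch.no_grad():
        out = model(packed_pixel_values, packed_flattened_position_ids, cu_seqlens, max_seqlen)
    print(out.shape)  # expected: (245, 256), i.e. (total patches across images, hidden_size)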