Atah Alam committed on
Commit 5838aa1 · 1 Parent(s): 2e2a204

Updated py files

__init__.py CHANGED
@@ -1,7 +1,21 @@
- from .configuration_sapnous import SapnousConfig
- from .modeling_sapnous import SapnousModel
  from typing import TYPE_CHECKING
  from transformers.utils import _LazyModule

  _import_structure = {
      "configuration_sapnous": ["SAPNOUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "SapnousT1Config"],
@@ -13,4 +27,13 @@ if TYPE_CHECKING:
      from .modeling_sapnous import SapnousT1Model, SapnousT1ForCausalLM
  else:
      import sys
-     sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)

+ # coding=utf-8
+ # Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
  from typing import TYPE_CHECKING
  from transformers.utils import _LazyModule
+ from transformers.models.auto import CONFIG_MAPPING, MODEL_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING
+ from transformers.models.auto import AutoConfig, AutoModel, AutoModelForCausalLM

  _import_structure = {
      "configuration_sapnous": ["SAPNOUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "SapnousT1Config"],

      from .modeling_sapnous import SapnousT1Model, SapnousT1ForCausalLM
  else:
      import sys
+     sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
+
+ # Register model in auto classes
+ CONFIG_MAPPING["sapnous_t1"] = SapnousT1Config
+ MODEL_MAPPING["sapnous_t1"] = SapnousT1Model
+ MODEL_FOR_CAUSAL_LM_MAPPING["sapnous_t1"] = SapnousT1ForCausalLM
+
+ AutoConfig.register("sapnous_t1", SapnousT1Config)
+ AutoModel.register(SapnousT1Config, SapnousT1Model)
+ AutoModelForCausalLM.register(SapnousT1Config, SapnousT1ForCausalLM)
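The registrations above are what make the custom architecture resolvable through the transformers auto classes. Below is a minimal usage sketch, assuming the package imports cleanly so that `SapnousT1Config`, `SapnousT1Model`, and `SapnousT1ForCausalLM` are in scope when the registration lines run; the sizes are illustrative toy values, not the released checkpoint's.

```python
from transformers import AutoConfig, AutoModelForCausalLM

import sapnous  # hypothetical package name; importing it triggers the registrations above

# Build a tiny config through the auto machinery; "sapnous_t1" now resolves to SapnousT1Config.
config = AutoConfig.for_model(
    "sapnous_t1",
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    vocab_size=1024,
    max_position_embeddings=256,
)

# The MODEL_FOR_CAUSAL_LM mapping routes this config to SapnousT1ForCausalLM.
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__, model.config.model_type)
```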
attention_sapnous.py ADDED
@@ -0,0 +1,235 @@
1
+ # coding=utf-8
2
+ # Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import math
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from typing import Optional, Tuple
20
+
21
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
22
+ """Precompute the frequency tensor for complex rotation."""
23
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
24
+ t = torch.arange(end, device=freqs.device)
25
+ freqs = torch.outer(t, freqs)
26
+ return torch.polar(torch.ones_like(freqs), freqs)
27
+
28
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
29
+ """Apply rotary position embeddings to the input tensor."""
30
+ x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
31
+ freqs_cis = freqs_cis.view(1, *freqs_cis.shape)
32
+ x_rotated = x_complex * freqs_cis
33
+ return torch.view_as_real(x_rotated).flatten(-2)
34
+
35
+ class SapnousAttention(nn.Module):
36
+ """Multi-head attention with rotary position embeddings and sliding window attention."""
37
+ def __init__(self, config):
38
+ super().__init__()
39
+ self.config = config
40
+ self.hidden_size = config.hidden_size
41
+ self.num_attention_heads = config.num_attention_heads
42
+ self.head_dim = self.hidden_size // self.num_attention_heads
43
+ self.num_key_value_heads = config.num_key_value_heads
44
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
45
+ self.max_position_embeddings = config.max_position_embeddings
46
+ self.rope_theta = config.rope_theta
47
+ self.sliding_window = config.sliding_window if config.use_sliding_window else None
48
+
49
+ if (self.head_dim * self.num_attention_heads) != self.hidden_size:
50
+ raise ValueError(
51
+ f"hidden_size must be divisible by num_attention_heads (got {self.hidden_size} and {self.num_attention_heads})"
52
+ )
53
+
54
+ self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
55
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
56
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
57
+ self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)
58
+
59
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
60
+
61
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
62
+ return tensor.view(bsz, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
63
+
64
+ def _kv_shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
65
+ return tensor.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
66
+
67
+ def forward(
68
+ self,
69
+ hidden_states: torch.Tensor,
70
+ freqs_cis: torch.Tensor,
71
+ attention_mask: Optional[torch.Tensor] = None,
72
+ position_ids: Optional[torch.LongTensor] = None,
73
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
74
+ output_attentions: bool = False,
75
+ use_cache: bool = False,
76
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
77
+ bsz, q_len, _ = hidden_states.size()
78
+
79
+ query_states = self.q_proj(hidden_states)
80
+ key_states = self.k_proj(hidden_states)
81
+ value_states = self.v_proj(hidden_states)
82
+
83
+ query_states = self._shape(query_states, q_len, bsz)
84
+ key_states = self._kv_shape(key_states, q_len, bsz)
85
+ value_states = self._kv_shape(value_states, q_len, bsz)
86
+
87
+ kv_seq_len = key_states.shape[-2]
88
+ if past_key_value is not None:
89
+ kv_seq_len += past_key_value[0].shape[-2]
90
+
91
+ # Apply rotary position embeddings
92
+ if position_ids is None:
93
+ position_ids = torch.arange(kv_seq_len, device=hidden_states.device)
94
+ cos, sin = freqs_cis[position_ids]
95
+ query_states, key_states = apply_rotary_emb(query_states, cos), apply_rotary_emb(key_states, sin)
96
+
97
+ if past_key_value is not None:
98
+ # Reuse k, v, self_attention
99
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
100
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
101
+
102
+ past_key_value = (key_states, value_states) if use_cache else None
103
+
104
+ # Repeat k/v heads if n_kv_heads < n_heads
105
+ key_states = torch.repeat_interleave(key_states, self.num_key_value_groups, dim=1)
106
+ value_states = torch.repeat_interleave(value_states, self.num_key_value_groups, dim=1)
107
+
108
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
109
+
110
+ if attention_mask is not None:
111
+ attn_weights = attn_weights + attention_mask
112
+
113
+ # Sliding window attention if configured
114
+ if self.sliding_window is not None and kv_seq_len > self.sliding_window:
115
+ # Create sliding window mask
116
+ window_mask = torch.ones_like(attn_weights, dtype=torch.bool)
117
+ for i in range(q_len):
118
+ window_start = max(0, i - self.sliding_window // 2)
119
+ window_end = min(kv_seq_len, i + self.sliding_window // 2)
120
+ window_mask[:, :, i, window_start:window_end] = False
121
+ attn_weights = attn_weights.masked_fill(window_mask, float('-inf'))
122
+
123
+ # Causal mask for autoregressive generation
124
+ if self.config.scoring_func == "softmax":
125
+ causal_mask = torch.triu(torch.ones((q_len, kv_seq_len), dtype=torch.bool), diagonal=1)
126
+ causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
127
+ attn_weights = attn_weights.masked_fill(causal_mask.to(attn_weights.device), float('-inf'))
128
+ attn_weights = F.softmax(attn_weights, dim=-1)
129
+ else:
130
+ # Alternative scoring functions (e.g., RoPE-only, cosine similarity)
131
+ attn_weights = F.relu(attn_weights)
132
+ attn_weights = attn_weights / (attn_weights.sum(dim=-1, keepdim=True) + 1e-6)
133
+
134
+ attn_weights = self.attention_dropout(attn_weights)
135
+ attn_output = torch.matmul(attn_weights, value_states)
136
+
137
+ attn_output = attn_output.transpose(1, 2).contiguous()
138
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
139
+
140
+ attn_output = self.o_proj(attn_output)
141
+
142
+ if not output_attentions:
143
+ attn_weights = None
144
+
145
+ return attn_output, attn_weights, past_key_value
146
+
147
+ class SapnousBlock(nn.Module):
148
+ """Transformer block with attention, layer norm, and feed-forward network."""
149
+ def __init__(self, config):
150
+ super().__init__()
151
+ self.hidden_size = config.hidden_size
152
+ self.self_attn = SapnousAttention(config)
153
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
154
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
155
+
156
+ self.mlp = nn.Sequential(
157
+ nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
158
+ nn.SiLU(),
159
+ nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
160
+ )
161
+
162
+ def forward(
163
+ self,
164
+ hidden_states: torch.Tensor,
165
+ freqs_cis: torch.Tensor,
166
+ attention_mask: Optional[torch.Tensor] = None,
167
+ position_ids: Optional[torch.LongTensor] = None,
168
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
169
+ output_attentions: bool = False,
170
+ use_cache: bool = False,
171
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
172
+ # Self Attention
173
+ residual = hidden_states
174
+ hidden_states = self.input_layernorm(hidden_states)
175
+
176
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
177
+ hidden_states=hidden_states,
178
+ freqs_cis=freqs_cis,
179
+ attention_mask=attention_mask,
180
+ position_ids=position_ids,
181
+ past_key_value=past_key_value,
182
+ output_attentions=output_attentions,
183
+ use_cache=use_cache,
184
+ )
185
+ hidden_states = residual + hidden_states
186
+
187
+ # Fully Connected
188
+ residual = hidden_states
189
+ hidden_states = self.post_attention_layernorm(hidden_states)
190
+ hidden_states = self.mlp(hidden_states)
191
+ hidden_states = residual + hidden_states
192
+
193
+ outputs = (hidden_states,)
194
+
195
+ if output_attentions:
196
+ outputs += (self_attn_weights,)
197
+
198
+ if use_cache:
199
+ outputs += (present_key_value,)
200
+
201
+ return outputs
202
+
203
+ class SapnousVisionEmbeddings(nn.Module):
204
+ """Vision embeddings for multimodal support."""
205
+ def __init__(self, config):
206
+ super().__init__()
207
+ self.config = config
208
+ self.hidden_size = config.hidden_size
209
+
210
+ # Vision embedding layers
211
+ self.patch_embed = nn.Conv2d(3, self.hidden_size, kernel_size=16, stride=16)
212
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
213
+ self.pos_embed = nn.Parameter(torch.zeros(1, (224 // 16) ** 2 + 1, self.hidden_size))
214
+
215
+ # Layer normalization and dropout
216
+ self.norm = nn.LayerNorm(self.hidden_size, eps=config.rms_norm_eps)
217
+ self.dropout = nn.Dropout(config.attention_dropout)
218
+
219
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
220
+ B = pixel_values.shape[0]
221
+
222
+ # Create patch embeddings
223
+ x = self.patch_embed(pixel_values)
224
+ x = x.flatten(2).transpose(1, 2) # B, N, C
225
+
226
+ # Add cls token and position embeddings
227
+ cls_tokens = self.cls_token.expand(B, -1, -1)
228
+ x = torch.cat((cls_tokens, x), dim=1)
229
+ x = x + self.pos_embed
230
+
231
+ # Apply normalization and dropout
232
+ x = self.norm(x)
233
+ x = self.dropout(x)
234
+
235
+ return x
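Since the rest of this file builds on the two RoPE helpers defined at the top, a quick shape-check sketch may help; it assumes the file is importable as a top-level module (e.g. run from the repository root) and uses toy dimensions.

```python
import torch
from attention_sapnous import precompute_freqs_cis, apply_rotary_emb  # import path assumed

bsz, n_heads, seq_len, head_dim = 2, 4, 16, 64
q = torch.randn(bsz, n_heads, seq_len, head_dim)

# One complex rotation factor per (position, frequency) pair: shape (seq_len, head_dim // 2)
freqs_cis = precompute_freqs_cis(head_dim, seq_len)

# The rotation is applied pairwise over the last dimension and preserves the tensor shape.
q_rot = apply_rotary_emb(q, freqs_cis)
assert q_rot.shape == (bsz, n_heads, seq_len, head_dim)
```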
configuration_sapnous.py CHANGED
@@ -1,11 +1,25 @@
  from transformers.configuration_utils import PretrainedConfig
  from transformers.utils import logging
- from transformers import AutoConfig # ✅ Correct Import

  logger = logging.get_logger(__name__)

  SAPNOUS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-     "Sapnous-AI/Sapnous-6B": "https://huggingface.co/Sapnous-AI/Sapnous-6B/resolve/main/config.json",
  }

  class SapnousT1Config(PretrainedConfig):

+ # coding=utf-8
+ # Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
  from transformers.configuration_utils import PretrainedConfig
  from transformers.utils import logging
+ from transformers import AutoConfig

  logger = logging.get_logger(__name__)

  SAPNOUS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+     "Sapnous-AI/Sapnous-VR-6B": "https://huggingface.co/Sapnous-AI/Sapnous-VR-6B/resolve/main/config.json",
  }

  class SapnousT1Config(PretrainedConfig):
convert_to_gguf.py CHANGED
@@ -1,7 +1,24 @@
1
  import os
2
  import torch
 
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from ctransformers import AutoModelForCausalLM as GGUFModel
 
5
 
6
  def convert_to_gguf(model_path, output_path):
7
  # Load the model and tokenizer with vision-language support
@@ -16,29 +33,86 @@ def convert_to_gguf(model_path, output_path):
16
  trust_remote_code=True
17
  )
18
 
19
- # Save in GGUF format
20
  model.save_pretrained(output_path, safe_serialization=True)
21
  tokenizer.save_pretrained(output_path)
22
 
23
- # Convert to GGUF using ctransformers with SapnousT1 architecture settings
24
  gguf_model = GGUFModel.from_pretrained(
25
  output_path,
26
- model_type='llama', # Base architecture type
27
  gpu_layers=0, # CPU only for conversion
28
  config={
29
- 'context_length': 32768, # Match model's sliding window size
30
- 'attention_type': 'multiquery', # For efficient attention
31
- 'num_attention_heads': 40, # Match model's head count
32
- 'num_key_value_heads': 8, # Match model's KV head count
33
- 'hidden_size': 5120, # Match model's hidden size
34
- 'intermediate_size': 20480, # Match model's intermediate size
35
- 'max_position_embeddings': 128000 # Match model's max positions
36
  }
37
  )
38
 
39
  print(f"Model converted and saved to {output_path}")
40
  return gguf_model
41
 
42
  if __name__ == '__main__':
43
  model_path = os.path.dirname(os.path.abspath(__file__))
44
  output_path = os.path.join(model_path, 'gguf_model')
 
1
+ # coding=utf-8
2
+ # Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
  import os
16
  import torch
17
+ import json
18
+ from pathlib import Path
19
  from transformers import AutoModelForCausalLM, AutoTokenizer
20
  from ctransformers import AutoModelForCausalLM as GGUFModel
21
+ from models.sapnous import SapnousT1Config
22
 
23
  def convert_to_gguf(model_path, output_path):
24
  # Load the model and tokenizer with vision-language support
 
33
  trust_remote_code=True
34
  )
35
 
36
+ # Get model configuration
37
+ config = model.config
38
+ if not isinstance(config, SapnousT1Config):
39
+ raise ValueError("Model must be a SapnousT1 model")
40
+
41
+ # Save in intermediate format
42
  model.save_pretrained(output_path, safe_serialization=True)
43
  tokenizer.save_pretrained(output_path)
44
 
45
+ # Convert to GGUF using custom SapnousT1 architecture settings
46
  gguf_model = GGUFModel.from_pretrained(
47
  output_path,
48
+ model_type='sapnous_t1', # Custom architecture type
49
  gpu_layers=0, # CPU only for conversion
50
  config={
51
+ 'context_length': config.sliding_window,
52
+ 'attention_type': 'multihead', # Custom attention implementation
53
+ 'num_attention_heads': config.num_attention_heads,
54
+ 'num_key_value_heads': config.num_key_value_heads,
55
+ 'hidden_size': config.hidden_size,
56
+ 'intermediate_size': config.intermediate_size,
57
+ 'max_position_embeddings': config.max_position_embeddings,
58
+ 'vocab_size': config.vocab_size,
59
+ 'num_hidden_layers': config.num_hidden_layers,
60
+ 'rms_norm_eps': config.rms_norm_eps,
61
+ 'rope_theta': config.rope_theta,
62
+ # Vision model parameters
63
+ 'vision_config': {
64
+ 'hidden_size': config.vision_hidden_size,
65
+ 'num_hidden_layers': config.vision_layers,
66
+ 'num_attention_heads': config.vision_heads,
67
+ 'intermediate_size': config.vision_intermediate_size,
68
+ 'patch_size': config.patch_size,
69
+ 'image_size': config.image_size
70
+ }
71
  }
72
  )
73
 
74
  print(f"Model converted and saved to {output_path}")
75
  return gguf_model
76
 
77
+ def convert_to_hf(gguf_path, output_path):
78
+ """Convert GGUF model back to Hugging Face format"""
79
+ # Load GGUF model configuration
80
+ config_path = Path(gguf_path) / "config.json"
81
+ with open(config_path, 'r') as f:
82
+ gguf_config = json.load(f)
83
+
84
+ # Create SapnousT1 configuration
85
+ config = SapnousT1Config(
86
+ vocab_size=gguf_config['vocab_size'],
87
+ hidden_size=gguf_config['hidden_size'],
88
+ num_hidden_layers=gguf_config['num_hidden_layers'],
89
+ num_attention_heads=gguf_config['num_attention_heads'],
90
+ num_key_value_heads=gguf_config['num_key_value_heads'],
91
+ intermediate_size=gguf_config['intermediate_size'],
92
+ max_position_embeddings=gguf_config['max_position_embeddings'],
93
+ rms_norm_eps=gguf_config['rms_norm_eps'],
94
+ rope_theta=gguf_config['rope_theta'],
95
+ # Vision configuration
96
+ vision_hidden_size=gguf_config['vision_config']['hidden_size'],
97
+ vision_layers=gguf_config['vision_config']['num_hidden_layers'],
98
+ vision_heads=gguf_config['vision_config']['num_attention_heads'],
99
+ vision_intermediate_size=gguf_config['vision_config']['intermediate_size'],
100
+ patch_size=gguf_config['vision_config']['patch_size'],
101
+ image_size=gguf_config['vision_config']['image_size']
102
+ )
103
+
104
+ # Load GGUF model
105
+ gguf_model = GGUFModel.from_pretrained(gguf_path)
106
+
107
+ # Convert weights to HF format
108
+ model = AutoModelForCausalLM.from_config(config)
109
+ model.load_state_dict(gguf_model.state_dict())
110
+
111
+ # Save converted model
112
+ model.save_pretrained(output_path)
113
+ print(f"Model converted back to Hugging Face format at {output_path}")
114
+ return model
115
+
116
  if __name__ == '__main__':
117
  model_path = os.path.dirname(os.path.abspath(__file__))
118
  output_path = os.path.join(model_path, 'gguf_model')
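A hypothetical call sequence for the two converters above; the repo id and paths are placeholders, and running it requires the actual model weights plus the `transformers` and `ctransformers` packages.

```python
from convert_to_gguf import convert_to_gguf, convert_to_hf  # import path assumed

# Forward conversion: HF checkpoint -> GGUF (placeholder paths).
gguf_model = convert_to_gguf("Sapnous-AI/Sapnous-VR-6B", "./gguf_model")

# Reverse conversion: GGUF -> HF format, rebuilding a SapnousT1Config from the saved config.json.
hf_model = convert_to_hf("./gguf_model", "./hf_roundtrip")
```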
model.py CHANGED
@@ -1,3 +1,17 @@
  from transformers import PreTrainedModel, AutoConfig
  import torch
  import torch.nn as nn

+ # coding=utf-8
+ # Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
  from transformers import PreTrainedModel, AutoConfig
  import torch
  import torch.nn as nn
modeling_sapnous.py CHANGED
@@ -1,53 +1,271 @@
 
1
  import torch
2
  import torch.nn as nn
 
 
3
  from transformers import PreTrainedModel, AutoModelForCausalLM
4
- from configuration_sapnous import SapnousT1Config # Ensure this file is correct
 
 
5
 
6
  class SapnousT1PreTrainedModel(PreTrainedModel):
7
  """Base class for all Sapnous-T1 models."""
8
  config_class = SapnousT1Config
 
9
 
10
  def __init__(self, config: SapnousT1Config):
11
  super().__init__(config)
12
  self.config = config
13
 
14
  def _init_weights(self, module):
15
- """Initialize weights if required."""
 
16
  if isinstance(module, nn.Linear):
17
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
18
  if module.bias is not None:
19
  module.bias.data.zero_()
 
 
21
  class SapnousT1Model(SapnousT1PreTrainedModel):
22
- """Base Transformer Model"""
23
  def __init__(self, config: SapnousT1Config):
24
  super().__init__(config)
 
25
  self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
26
- self.encoder = nn.TransformerEncoder(
27
- nn.TransformerEncoderLayer(
28
- d_model=config.hidden_size,
29
- nhead=config.num_attention_heads
30
- ),
31
- num_layers=config.num_hidden_layers
 
  )
33
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
34
 
35
- def forward(self, input_ids):
36
- x = self.embeddings(input_ids)
37
- x = self.encoder(x)
38
- return self.lm_head(x)
 
39
 
40
  class SapnousT1ForCausalLM(SapnousT1PreTrainedModel):
41
- """Sapnous-T1 Model for Causal LM (Text Generation)"""
 
 
42
  def __init__(self, config: SapnousT1Config):
43
  super().__init__(config)
44
  self.model = SapnousT1Model(config)
45
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
46
 
47
- def forward(self, input_ids):
48
- hidden_states = self.model(input_ids)
49
  logits = self.lm_head(hidden_states)
50
- return logits
51
 
52
- # Register the model properly
53
  AutoModelForCausalLM.register(SapnousT1Config, SapnousT1ForCausalLM)
 
1
+ import math
2
  import torch
3
  import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from typing import Optional, Tuple, List, Union
6
  from transformers import PreTrainedModel, AutoModelForCausalLM
7
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
8
+ from .configuration_sapnous import SapnousT1Config
9
+ from .attention_sapnous import SapnousAttention, SapnousBlock, SapnousVisionEmbeddings, precompute_freqs_cis
10
 
11
  class SapnousT1PreTrainedModel(PreTrainedModel):
12
  """Base class for all Sapnous-T1 models."""
13
  config_class = SapnousT1Config
14
+ base_model_prefix = "sapnous"
15
 
16
  def __init__(self, config: SapnousT1Config):
17
  super().__init__(config)
18
  self.config = config
19
 
20
  def _init_weights(self, module):
21
+ """Initialize weights using the model's initialization configuration."""
22
+ std = self.config.initializer_range
23
  if isinstance(module, nn.Linear):
24
+ module.weight.data.normal_(mean=0.0, std=std)
25
  if module.bias is not None:
26
  module.bias.data.zero_()
27
+ elif isinstance(module, nn.Embedding):
28
+ module.weight.data.normal_(mean=0.0, std=std)
29
+ elif isinstance(module, nn.LayerNorm):
30
+ module.bias.data.zero_()
31
+ module.weight.data.fill_(1.0)
32
+ elif isinstance(module, SapnousAttention):
33
+ module.q_proj.weight.data.normal_(mean=0.0, std=std)
34
+ module.k_proj.weight.data.normal_(mean=0.0, std=std)
35
+ module.v_proj.weight.data.normal_(mean=0.0, std=std)
36
+ module.o_proj.weight.data.normal_(mean=0.0, std=std)
37
 
38
  class SapnousT1Model(SapnousT1PreTrainedModel):
39
+ """Base Transformer Model with advanced attention mechanisms and optional vision support."""
40
  def __init__(self, config: SapnousT1Config):
41
  super().__init__(config)
42
+
43
  self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
44
+ self.layers = nn.ModuleList([SapnousBlock(config) for _ in range(config.num_hidden_layers)])
45
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
46
+
47
+ # Vision support
48
+ self.vision_embed = SapnousVisionEmbeddings(config) if getattr(config, 'vision_config', None) else None
49
+
50
+ # Initialize weights and apply final processing
51
+ self.post_init()
52
+
53
+ # Compute and cache RoPE frequencies
54
+ self.freqs_cis = precompute_freqs_cis(
55
+ self.config.hidden_size // self.config.num_attention_heads,
56
+ self.config.max_position_embeddings,
57
+ self.config.rope_theta,
58
  )
 
59
 
60
+ def get_input_embeddings(self) -> nn.Module:
61
+ return self.embeddings
62
+
63
+ def set_input_embeddings(self, value: nn.Module):
64
+ self.embeddings = value
65
+
66
+ def forward(
67
+ self,
68
+ input_ids: Optional[torch.LongTensor] = None,
69
+ attention_mask: Optional[torch.Tensor] = None,
70
+ position_ids: Optional[torch.LongTensor] = None,
71
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
72
+ inputs_embeds: Optional[torch.FloatTensor] = None,
73
+ pixel_values: Optional[torch.FloatTensor] = None,
74
+ use_cache: Optional[bool] = None,
75
+ output_attentions: Optional[bool] = None,
76
+ output_hidden_states: Optional[bool] = None,
77
+ return_dict: Optional[bool] = None,
78
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
79
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
80
+ output_hidden_states = (
81
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
82
+ )
83
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
84
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
85
+
86
+ if input_ids is not None and inputs_embeds is not None:
87
+ raise ValueError("You cannot specify both input_ids and inputs_embeds")
88
+
89
+ # Process text input
90
+ if input_ids is not None:
91
+ inputs_embeds = self.embeddings(input_ids)
92
+ batch_size, seq_length = input_ids.shape[:2]
93
+ else:
94
+ batch_size, seq_length = inputs_embeds.shape[:2]
95
+
96
+ # Process vision input if available
97
+ if pixel_values is not None and self.vision_embed is not None:
98
+ vision_embeds = self.vision_embed(pixel_values)
99
+ inputs_embeds = torch.cat([vision_embeds, inputs_embeds], dim=1)
100
+ seq_length = inputs_embeds.shape[1]
101
+
102
+ if position_ids is None:
103
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
104
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
105
+ position_ids = position_ids.unsqueeze(0)
106
+
107
+ # Prepare attention mask
108
+ if attention_mask is not None:
109
+ attention_mask = attention_mask.view(batch_size, -1)
110
+ attention_mask = attention_mask[:, None, None, :]
111
+ attention_mask = attention_mask.to(dtype=inputs_embeds.dtype)
112
+ attention_mask = (1.0 - attention_mask) * torch.finfo(inputs_embeds.dtype).min
113
+
114
+ freqs_cis = self.freqs_cis.to(inputs_embeds.device)
115
+
116
+ hidden_states = inputs_embeds
117
+ all_hidden_states = () if output_hidden_states else None
118
+ all_self_attns = () if output_attentions else None
119
+ next_decoder_cache = () if use_cache else None
120
+
121
+ for idx, decoder_layer in enumerate(self.layers):
122
+ if output_hidden_states:
123
+ all_hidden_states += (hidden_states,)
124
+
125
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
126
+
127
+ layer_outputs = decoder_layer(
128
+ hidden_states,
129
+ freqs_cis=freqs_cis,
130
+ attention_mask=attention_mask,
131
+ position_ids=position_ids,
132
+ past_key_value=past_key_value,
133
+ output_attentions=output_attentions,
134
+ use_cache=use_cache,
135
+ )
136
+
137
+ hidden_states = layer_outputs[0]
138
+
139
+ if use_cache:
140
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
141
+
142
+ if output_attentions:
143
+ all_self_attns += (layer_outputs[1],)
144
+
145
+ hidden_states = self.norm(hidden_states)
146
+
147
+ if output_hidden_states:
148
+ all_hidden_states += (hidden_states,)
149
+
150
+ if not return_dict:
151
+ return tuple(v for v in [
152
+ hidden_states,
153
+ next_decoder_cache,
154
+ all_hidden_states,
155
+ all_self_attns,
156
+ ] if v is not None)
157
+
158
+ return BaseModelOutputWithPast(
159
+ last_hidden_state=hidden_states,
160
+ past_key_values=next_decoder_cache,
161
+ hidden_states=all_hidden_states,
162
+ attentions=all_self_attns,
163
+ )
164
 
165
  class SapnousT1ForCausalLM(SapnousT1PreTrainedModel):
166
+ """Sapnous-T1 Model for Causal Language Modeling with vision support."""
167
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
168
+
169
  def __init__(self, config: SapnousT1Config):
170
  super().__init__(config)
171
  self.model = SapnousT1Model(config)
172
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
173
 
174
+ # Initialize weights and apply final processing
175
+ self.post_init()
176
+
177
+ def get_input_embeddings(self) -> nn.Module:
178
+ return self.model.embeddings
179
+
180
+ def set_input_embeddings(self, value: nn.Module):
181
+ self.model.embeddings = value
182
+
183
+ def get_output_embeddings(self) -> nn.Module:
184
+ return self.lm_head
185
+
186
+ def set_output_embeddings(self, new_embeddings: nn.Module):
187
+ self.lm_head = new_embeddings
188
+
189
+ def prepare_inputs_for_generation(
190
+ self,
191
+ input_ids: torch.LongTensor,
192
+ past_key_values: Optional[List[Tuple[torch.Tensor]]] = None,
193
+ attention_mask: Optional[torch.Tensor] = None,
194
+ **kwargs,
195
+ ) -> dict:
196
+ if past_key_values:
197
+ input_ids = input_ids[:, -1:]
198
+
199
+ position_ids = kwargs.get("position_ids", None)
200
+ if position_ids is None:
201
+ position_ids = (attention_mask.long().cumsum(-1) - 1) if attention_mask is not None else None
202
+ if past_key_values:
203
+ position_ids = position_ids[:, -1].unsqueeze(-1)
204
+
205
+ return {
206
+ "input_ids": input_ids,
207
+ "attention_mask": attention_mask,
208
+ "position_ids": position_ids,
209
+ "past_key_values": past_key_values,
210
+ "use_cache": kwargs.get("use_cache"),
211
+ "pixel_values": kwargs.get("pixel_values", None),
212
+ }
213
+
214
+ def forward(
215
+ self,
216
+ input_ids: Optional[torch.LongTensor] = None,
217
+ attention_mask: Optional[torch.Tensor] = None,
218
+ position_ids: Optional[torch.LongTensor] = None,
219
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
220
+ inputs_embeds: Optional[torch.FloatTensor] = None,
221
+ pixel_values: Optional[torch.FloatTensor] = None,
222
+ labels: Optional[torch.LongTensor] = None,
223
+ use_cache: Optional[bool] = None,
224
+ output_attentions: Optional[bool] = None,
225
+ output_hidden_states: Optional[bool] = None,
226
+ return_dict: Optional[bool] = None,
227
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
228
+ r"""Labels for computing the masked language modeling loss."""
229
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
230
+
231
+ outputs = self.model(
232
+ input_ids=input_ids,
233
+ attention_mask=attention_mask,
234
+ position_ids=position_ids,
235
+ past_key_values=past_key_values,
236
+ inputs_embeds=inputs_embeds,
237
+ pixel_values=pixel_values,
238
+ use_cache=use_cache,
239
+ output_attentions=output_attentions,
240
+ output_hidden_states=output_hidden_states,
241
+ return_dict=return_dict,
242
+ )
243
+
244
+ hidden_states = outputs[0]
245
  logits = self.lm_head(hidden_states)
 
246
 
247
+ loss = None
248
+ if labels is not None:
249
+ shift_logits = logits[..., :-1, :].contiguous()
250
+ shift_labels = labels[..., 1:].contiguous()
251
+ loss_fct = nn.CrossEntropyLoss()
252
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
253
+
254
+ if not return_dict:
255
+ output = (logits,) + outputs[1:]
256
+ return ((loss,) + output) if loss is not None else output
257
+
258
+ return CausalLMOutputWithPast(
259
+ loss=loss,
260
+ logits=logits,
261
+ past_key_values=outputs.past_key_values,
262
+ hidden_states=outputs.hidden_states,
263
+ attentions=outputs.attentions,
264
+ )
265
+
266
+ def tie_weights(self):
267
+ """Tie the weights between the input embeddings and the output embeddings."""
268
+ self.lm_head.weight = self.model.embeddings.weight
269
+
270
+ # Register the model
271
  AutoModelForCausalLM.register(SapnousT1Config, SapnousT1ForCausalLM)
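The loss in `SapnousT1ForCausalLM.forward` is the standard next-token objective: logits at position t are scored against the label at t+1. A self-contained toy illustration of that shift:

```python
import torch
import torch.nn as nn

vocab_size, seq_len = 8, 5
logits = torch.randn(1, seq_len, vocab_size)          # model output
labels = torch.randint(0, vocab_size, (1, seq_len))   # target token ids

shift_logits = logits[..., :-1, :].contiguous()       # drop the last position
shift_labels = labels[..., 1:].contiguous()           # drop the first label

loss = nn.CrossEntropyLoss()(
    shift_logits.view(-1, vocab_size), shift_labels.view(-1)
)
print(f"toy causal LM loss: {loss.item():.4f}")
```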
models/sapnous/__init__.py ADDED
@@ -0,0 +1,41 @@
1
+ # coding=utf-8
2
+ # Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import TYPE_CHECKING
16
+ from transformers.utils import _LazyModule
17
+ from transformers.models.auto import CONFIG_MAPPING, MODEL_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING
18
+ from transformers.models.auto import AutoConfig, AutoModel, AutoModelForCausalLM
19
+
20
+ _import_structure = {
21
+ "configuration_sapnous": ["SAPNOUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "SapnousT1Config"],
22
+ "modeling_sapnous": ["SapnousT1Model", "SapnousT1ForCausalLM"],
23
+ "tokenization_sapnous": ["SapnousT1Tokenizer"],
24
+ }
25
+
26
+ if TYPE_CHECKING:
27
+ from .configuration_sapnous import SAPNOUS_PRETRAINED_CONFIG_ARCHIVE_MAP, SapnousT1Config
28
+ from .modeling_sapnous import SapnousT1Model, SapnousT1ForCausalLM
29
+ from .tokenization_sapnous import SapnousT1Tokenizer
30
+ else:
31
+ import sys
32
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
33
+
34
+ # Import configuration and models
35
+ from .configuration_sapnous import SapnousT1Config
36
+ from .modeling_sapnous import SapnousT1Model, SapnousT1ForCausalLM
37
+
38
+ # Register model in auto classes
39
+ CONFIG_MAPPING["sapnous_t1"] = SapnousT1Config
40
+ MODEL_MAPPING["sapnous_t1"] = SapnousT1Model
41
+ MODEL_FOR_CAUSAL_LM_MAPPING["sapnous_t1"] = SapnousT1ForCausalLM
models/sapnous/configuration_sapnous.py ADDED
@@ -0,0 +1,131 @@
1
+ # coding=utf-8
2
+ # Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from transformers.configuration_utils import PretrainedConfig
16
+ from transformers.utils import logging
17
+ from transformers import AutoConfig
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+ SAPNOUS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
22
+ "Sapnous-AI/Sapnous-6B": "https://huggingface.co/Sapnous-AI/Sapnous-6B/resolve/main/config.json",
23
+ }
24
+
25
+ class SapnousT1Config(PretrainedConfig):
26
+ """Configuration class for Sapnous-T1 model with vision-language capabilities.
27
+
28
+ This configuration class handles both text and vision modalities, supporting multimodal
29
+ tasks like image understanding, video processing, and vision-language reasoning.
30
+ """
31
+
32
+ model_type = "sapnous_t1"
33
+
34
+ def __init__(
35
+ self,
36
+ # Text model parameters
37
+ vocab_size=151936,
38
+ hidden_size=5120,
39
+ intermediate_size=20480,
40
+ num_hidden_layers=36,
41
+ num_attention_heads=40,
42
+ num_key_value_heads=8,
43
+ hidden_act="silu",
44
+ max_position_embeddings=128000,
45
+ initializer_range=0.02,
46
+ rms_norm_eps=1e-6,
47
+ use_cache=True,
48
+ pad_token_id=None,
49
+ bos_token_id=151643,
50
+ eos_token_id=151645,
51
+ tie_word_embeddings=True,
52
+
53
+ # Vision model parameters
54
+ vision_start_token_id=151652,
55
+ vision_end_token_id=151653,
56
+ vision_token_id=151654,
57
+ image_token_id=151655,
58
+ video_token_id=151656,
59
+ vision_config=None,
60
+ patch_size=14,
61
+ image_size=224,
62
+ num_channels=3,
63
+ vision_layers=24,
64
+ vision_heads=16,
65
+ vision_hidden_size=1024,
66
+ vision_intermediate_size=4096,
67
+ vision_act="gelu",
68
+ vision_layer_norm_eps=1e-5,
69
+ vision_dropout=0.0,
70
+ vision_attention_dropout=0.0,
71
+ vision_embedding_dropout=0.0,
72
+
73
+ # Cross-attention parameters
74
+ num_cross_attention_layers=12,
75
+ cross_attention_heads=16,
76
+ cross_attention_dropout=0.0,
77
+ use_cross_attention=True,
78
+
79
+ # Positional encoding and attention parameters
80
+ rope_theta=1000000.0,
81
+ sliding_window=32768,
82
+ use_sliding_window=False,
83
+ max_window_layers=70,
84
+ attention_dropout=0.0,
85
+ rope_scaling=None,
86
+ scoring_func="softmax",
87
+
88
+ # Training parameters
89
+ aux_loss_alpha=0.001,
90
+ seq_aux=True,
91
+ **kwargs
92
+ ):
93
+ super().__init__(
94
+ pad_token_id=pad_token_id,
95
+ bos_token_id=bos_token_id,
96
+ eos_token_id=eos_token_id,
97
+ tie_word_embeddings=tie_word_embeddings,
98
+ **kwargs,
99
+ )
100
+
101
+ self.vocab_size = vocab_size
102
+ self.max_position_embeddings = max_position_embeddings
103
+ self.hidden_size = hidden_size
104
+ self.intermediate_size = intermediate_size
105
+ self.num_hidden_layers = num_hidden_layers
106
+ self.num_attention_heads = num_attention_heads
107
+ self.num_key_value_heads = num_key_value_heads
108
+ self.hidden_act = hidden_act
109
+ self.initializer_range = initializer_range
110
+ self.rms_norm_eps = rms_norm_eps
111
+ self.use_cache = use_cache
112
+ self.vision_start_token_id = vision_start_token_id
113
+ self.vision_end_token_id = vision_end_token_id
114
+ self.vision_token_id = vision_token_id
115
+ self.image_token_id = image_token_id
116
+ self.video_token_id = video_token_id
117
+ self.vision_config = vision_config
118
+ self.rope_theta = rope_theta
119
+ self.sliding_window = sliding_window
120
+ self.use_sliding_window = use_sliding_window
121
+ self.max_window_layers = max_window_layers
122
+ self.attention_dropout = attention_dropout
123
+ self.rope_scaling = rope_scaling
124
+ self.scoring_func = scoring_func
125
+ self.aux_loss_alpha = aux_loss_alpha
126
+ self.seq_aux = seq_aux
127
+
128
+ model_type = "sapnous_t1"
129
+ keys_to_ignore_at_inference = ["past_key_values"]
130
+
131
+ AutoConfig.register("sapnous_t1", SapnousT1Config)
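A minimal instantiation sketch for the config class above, assuming the `models.sapnous` package laid out in this commit is importable from the working directory; the overrides are illustrative, everything else falls back to the defaults shown.

```python
from models.sapnous.configuration_sapnous import SapnousT1Config  # import path assumed

config = SapnousT1Config(
    hidden_size=1024,
    intermediate_size=4096,
    num_hidden_layers=8,
    num_attention_heads=16,
    num_key_value_heads=4,
    use_sliding_window=True,
    sliding_window=4096,
)

print(config.model_type)                 # "sapnous_t1"
print(config.num_key_value_heads, config.rope_theta)
```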
models/sapnous/modeling_sapnous.py ADDED
@@ -0,0 +1,535 @@
1
+ # coding=utf-8
2
+ # Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
23
+ from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.utils import logging
25
+ from transformers import AutoModelForCausalLM, AutoModel
26
+
27
+ from .configuration_sapnous import SapnousT1Config
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ class SapnousT1Attention(nn.Module):
32
+ def __init__(self, config: SapnousT1Config):
33
+ super().__init__()
34
+ self.config = config
35
+ self.hidden_size = config.hidden_size
36
+ self.num_heads = config.num_attention_heads
37
+ self.head_dim = self.hidden_size // self.num_heads
38
+ self.num_key_value_heads = config.num_key_value_heads
39
+ self.max_position_embeddings = config.max_position_embeddings
40
+ self.rope_theta = config.rope_theta
41
+
42
+ if (self.head_dim * self.num_heads) != self.hidden_size:
43
+ raise ValueError(
44
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})."
45
+ )
46
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
47
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
48
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
49
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
50
+
51
+ self.dropout = nn.Dropout(config.attention_dropout)
52
+
53
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
54
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
55
+
56
+ def forward(
57
+ self,
58
+ hidden_states: torch.Tensor,
59
+ attention_mask: Optional[torch.Tensor] = None,
60
+ position_ids: Optional[torch.LongTensor] = None,
61
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
62
+ output_attentions: bool = False,
63
+ use_cache: bool = False,
64
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
65
+ bsz, q_len, _ = hidden_states.size()
66
+
67
+ query_states = self.q_proj(hidden_states)
68
+ key_states = self.k_proj(hidden_states)
69
+ value_states = self.v_proj(hidden_states)
70
+
71
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
72
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
73
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
74
+
75
+ kv_seq_len = key_states.shape[-2]
76
+ if past_key_value is not None:
77
+ kv_seq_len += past_key_value[0].shape[-2]
78
+
79
+ if past_key_value is not None:
80
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
81
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
82
+
83
+ past_key_value = (key_states, value_states) if use_cache else None
84
+
85
+ # repeat k/v heads if n_kv_heads < n_heads
86
+ key_states = self._repeat_kv(key_states)
87
+ value_states = self._repeat_kv(value_states)
88
+
89
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
90
+
91
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
92
+ raise ValueError(
93
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
94
+ f" {attn_weights.size()}"
95
+ )
96
+
97
+ if attention_mask is not None:
98
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
99
+ raise ValueError(
100
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
101
+ )
102
+ attn_weights = attn_weights + attention_mask
103
+
104
+ # upcast attention to fp32
105
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
106
+ attn_weights = self.dropout(attn_weights)
107
+
108
+ attn_output = torch.matmul(attn_weights, value_states)
109
+
110
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
111
+ raise ValueError(
112
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
113
+ f" {attn_output.size()}"
114
+ )
115
+
116
+ attn_output = attn_output.transpose(1, 2).contiguous()
117
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
118
+
119
+ attn_output = self.o_proj(attn_output)
120
+
121
+ if not output_attentions:
122
+ attn_weights = None
123
+
124
+ return attn_output, attn_weights, past_key_value
125
+
126
+ def _repeat_kv(self, hidden_states: torch.Tensor) -> torch.Tensor:
127
+ if self.num_key_value_heads != self.num_heads:
128
+ hidden_states = hidden_states.repeat_interleave(self.num_heads // self.num_key_value_heads, dim=1)
129
+ return hidden_states
130
+
131
+ class SapnousT1MLP(nn.Module):
132
+ def __init__(self, config: SapnousT1Config):
133
+ super().__init__()
134
+ self.config = config
135
+ self.hidden_size = config.hidden_size
136
+ self.intermediate_size = config.intermediate_size
137
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
138
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
139
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
140
+ self.act_fn = nn.SiLU()
141
+
142
+ def forward(self, x):
143
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
144
+
145
+ class SapnousT1DecoderLayer(nn.Module):
146
+ def __init__(self, config: SapnousT1Config):
147
+ super().__init__()
148
+ self.hidden_size = config.hidden_size
149
+ self.self_attn = SapnousT1Attention(config=config)
150
+ self.mlp = SapnousT1MLP(config)
151
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
152
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
153
+
154
+ def forward(
155
+ self,
156
+ hidden_states: torch.Tensor,
157
+ attention_mask: Optional[torch.Tensor] = None,
158
+ position_ids: Optional[torch.LongTensor] = None,
159
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
160
+ output_attentions: Optional[bool] = False,
161
+ use_cache: Optional[bool] = False,
162
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
163
+ residual = hidden_states
164
+
165
+ hidden_states = self.input_layernorm(hidden_states)
166
+
167
+ # Self Attention
168
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
169
+ hidden_states=hidden_states,
170
+ attention_mask=attention_mask,
171
+ position_ids=position_ids,
172
+ past_key_value=past_key_value,
173
+ output_attentions=output_attentions,
174
+ use_cache=use_cache,
175
+ )
176
+ hidden_states = residual + hidden_states
177
+
178
+ # Fully Connected
179
+ residual = hidden_states
180
+ hidden_states = self.post_attention_layernorm(hidden_states)
181
+ hidden_states = self.mlp(hidden_states)
182
+ hidden_states = residual + hidden_states
183
+
184
+ outputs = (hidden_states,)
185
+
186
+ if output_attentions:
187
+ outputs += (self_attn_weights,)
188
+
189
+ if use_cache:
190
+ outputs += (present_key_value,)
191
+
192
+ return outputs
193
+
194
+ class SapnousT1PreTrainedModel(PreTrainedModel):
195
+ config_class = SapnousT1Config
196
+ base_model_prefix = "model"
197
+ supports_gradient_checkpointing = True
198
+ _no_split_modules = ["SapnousT1DecoderLayer"]
199
+
200
+ def _init_weights(self, module):
201
+ std = self.config.initializer_range
202
+ if isinstance(module, nn.Linear):
203
+ module.weight.data.normal_(mean=0.0, std=std)
204
+ if module.bias is not None:
205
+ module.bias.data.zero_()
206
+ elif isinstance(module, nn.Embedding):
207
+ module.weight.data.normal_(mean=0.0, std=std)
208
+
209
+ def _set_gradient_checkpointing(self, module, value=False):
210
+ if isinstance(module, SapnousT1Model):
211
+ module.gradient_checkpointing = value
212
+
213
+ class SapnousT1Model(SapnousT1PreTrainedModel):
214
+ def __init__(self, config: SapnousT1Config):
215
+ super().__init__(config)
216
+ self.padding_idx = config.pad_token_id
217
+ self.vocab_size = config.vocab_size
218
+
219
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
220
+
221
+ self.layers = nn.ModuleList([SapnousT1DecoderLayer(config) for _ in range(config.num_hidden_layers)])
222
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
223
+
224
+ self.gradient_checkpointing = False
225
+ # Initialize weights and apply final processing
226
+ self.post_init()
227
+
228
+ def get_input_embeddings(self):
229
+ return self.embed_tokens
230
+
231
+ def set_input_embeddings(self, value):
232
+ self.embed_tokens = value
233
+
234
+ def forward(
235
+ self,
236
+ input_ids: torch.LongTensor = None,
237
+ attention_mask: Optional[torch.Tensor] = None,
238
+ position_ids: Optional[torch.LongTensor] = None,
239
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
240
+ inputs_embeds: Optional[torch.FloatTensor] = None,
241
+ use_cache: Optional[bool] = None,
242
+ output_attentions: Optional[bool] = None,
243
+ output_hidden_states: Optional[bool] = None,
244
+ return_dict: Optional[bool] = None,
245
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
246
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
247
+ output_hidden_states = (
248
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
249
+ )
250
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
251
+
252
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
253
+
254
+ # retrieve input_ids and inputs_embeds
255
+ if input_ids is not None and inputs_embeds is not None:
256
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
257
+ elif input_ids is not None:
258
+ batch_size, seq_length = input_ids.shape
259
+ elif inputs_embeds is not None:
260
+ batch_size, seq_length, _ = inputs_embeds.shape
261
+ else:
262
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
263
+
264
+ seq_length_with_past = seq_length
265
+ past_key_values_length = 0
266
+
267
+ if past_key_values is not None:
268
+ past_key_values_length = past_key_values[0][0].shape[2]
269
+ seq_length_with_past = seq_length_with_past + past_key_values_length
270
+
271
+ if position_ids is None:
272
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
273
+ position_ids = torch.arange(
274
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
275
+ )
276
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
277
+ else:
278
+ position_ids = position_ids.view(-1, seq_length).long()
279
+
280
+ if inputs_embeds is None:
281
+ inputs_embeds = self.embed_tokens(input_ids)
282
+
283
+ if attention_mask is not None:
284
+ if batch_size <= 0:
285
+ raise ValueError("batch_size has to be defined and > 0")
286
+ attention_mask = self._prepare_decoder_attention_mask(
287
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
288
+ )
289
+
290
+ hidden_states = inputs_embeds
291
+
292
+ if self.gradient_checkpointing and self.training:
293
+ if use_cache:
294
+ logger.warning_once(
295
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
296
+ )
297
+ use_cache = False
298
+
299
+ # decoder layers
300
+ all_hidden_states = () if output_hidden_states else None
301
+ all_self_attns = () if output_attentions else None
302
+ next_decoder_cache = () if use_cache else None
303
+
304
+ for idx, decoder_layer in enumerate(self.layers):
305
+ if output_hidden_states:
306
+ all_hidden_states += (hidden_states,)
307
+
308
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
309
+
310
+ if self.gradient_checkpointing and self.training:
311
+
312
+ def create_custom_forward(module):
313
+ def custom_forward(*inputs):
314
+ # None for past_key_value
315
+ return module(*inputs, past_key_value, output_attentions)
316
+
317
+ return custom_forward
318
+
319
+ layer_outputs = torch.utils.checkpoint.checkpoint(
320
+ create_custom_forward(decoder_layer),
321
+ hidden_states,
322
+ attention_mask,
323
+ position_ids,
324
+ )
325
+ else:
326
+ layer_outputs = decoder_layer(
327
+ hidden_states,
328
+ attention_mask=attention_mask,
329
+ position_ids=position_ids,
330
+ past_key_value=past_key_value,
331
+ output_attentions=output_attentions,
332
+ use_cache=use_cache,
333
+ )
334
+
335
+ hidden_states = layer_outputs[0]
336
+
337
+ if use_cache:
338
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
339
+
340
+ if output_attentions:
341
+ all_self_attns += (layer_outputs[1],)
342
+
343
+ hidden_states = self.norm(hidden_states)
344
+
345
+ # add hidden states from the last decoder layer
346
+ if output_hidden_states:
347
+ all_hidden_states += (hidden_states,)
348
+
349
+ next_cache = next_decoder_cache if use_cache else None
350
+ if not return_dict:
351
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
352
+ return BaseModelOutputWithPast(
353
+ last_hidden_state=hidden_states,
354
+ past_key_values=next_cache,
355
+ hidden_states=all_hidden_states,
356
+ attentions=all_self_attns,
357
+ )
358
+
359
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
360
+ # create causal mask
361
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
362
+ combined_attention_mask = None
363
+ if input_shape[-1] > 1:
364
+ combined_attention_mask = _make_causal_mask(
365
+ input_shape,
366
+ inputs_embeds.dtype,
367
+ device=inputs_embeds.device,
368
+ past_key_values_length=past_key_values_length,
369
+ )
370
+
371
+ if attention_mask is not None:
372
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
373
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
374
+ combined_attention_mask = (
375
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
376
+ )
377
+
378
+ return combined_attention_mask
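+ # Illustrative note (assuming the usual _make_causal_mask/_expand_mask helpers):
+ # for input_shape (1, 4), no padding and past_key_values_length 0, the combined
+ # mask has shape (1, 1, 4, 4) with large negative values above the diagonal, so
+ # position i can only attend to positions <= i.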
379
+
380
+ class SapnousT1ForCausalLM(SapnousT1PreTrainedModel):
381
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
382
+
383
+ def __init__(self, config):
384
+ super().__init__(config)
385
+ self.model = SapnousT1Model(config)
386
+ self.vocab_size = config.vocab_size
387
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
388
+
389
+ # Initialize weights and apply final processing
390
+ self.post_init()
391
+
392
+ def get_input_embeddings(self):
393
+ return self.model.embed_tokens
394
+
395
+ def set_input_embeddings(self, value):
396
+ self.model.embed_tokens = value
397
+
398
+ def get_output_embeddings(self):
399
+ return self.lm_head
400
+
401
+ def set_output_embeddings(self, new_embeddings):
402
+ self.lm_head = new_embeddings
403
+
404
+ def prepare_inputs_for_generation(
405
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
406
+ ):
407
+ if past_key_values:
408
+ input_ids = input_ids[:, -1:]
409
+
410
+ position_ids = kwargs.get("position_ids", None)
411
+ if attention_mask is not None and position_ids is None:
412
+ # create position_ids on the fly for batch generation
413
+ position_ids = attention_mask.long().cumsum(-1) - 1
414
+ position_ids.masked_fill_(attention_mask == 0, 1)
415
+ if past_key_values:
416
+ position_ids = position_ids[:, -1].unsqueeze(-1)
417
+
418
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
419
+ if inputs_embeds is not None and past_key_values is None:
420
+ model_inputs = {"inputs_embeds": inputs_embeds}
421
+ else:
422
+ model_inputs = {"input_ids": input_ids}
423
+
424
+ model_inputs.update(
425
+ {
426
+ "position_ids": position_ids,
427
+ "past_key_values": past_key_values,
428
+ "use_cache": kwargs.get("use_cache"),
429
+ "attention_mask": attention_mask,
430
+ }
431
+ )
432
+ return model_inputs
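+ # Illustrative note: the first generation step receives the full prompt; once
+ # past_key_values are populated, only the newest token (input_ids[:, -1:]) and
+ # its position id are passed, and the cached key/value states supply the rest.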
433
+
434
+ def forward(
435
+ self,
436
+ input_ids: torch.LongTensor = None,
437
+ attention_mask: Optional[torch.Tensor] = None,
438
+ position_ids: Optional[torch.LongTensor] = None,
439
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
440
+ inputs_embeds: Optional[torch.FloatTensor] = None,
441
+ labels: Optional[torch.LongTensor] = None,
442
+ use_cache: Optional[bool] = None,
443
+ output_attentions: Optional[bool] = None,
444
+ output_hidden_states: Optional[bool] = None,
445
+ return_dict: Optional[bool] = None,
446
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
447
+ r"""
448
+ Args:
449
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`)
450
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*)
451
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*)
452
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*)
453
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
454
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*)
455
+ use_cache (`bool`, *optional*)
456
+ output_attentions (`bool`, *optional*)
457
+ output_hidden_states (`bool`, *optional*)
458
+ return_dict (`bool`, *optional*)
459
+
460
+ Returns:
461
+
462
+ Example:
463
+
464
+ ```python
465
+ >>> from transformers import AutoTokenizer, SapnousT1ForCausalLM
466
+
467
+ >>> model = SapnousT1ForCausalLM.from_pretrained("Sapnous-AI/Sapnous-VR-6B")
468
+ >>> tokenizer = AutoTokenizer.from_pretrained("Sapnous-AI/Sapnous-VR-6B")
469
+
470
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
471
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
472
+
473
+ >>> # Generate
474
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
475
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
476
+ "Hey, are you conscious? Can you talk to me? Yes, I am an AI language model capable of engaging in conversation."
477
+ ```"""
478
+
479
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
480
+ output_hidden_states = (
481
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
482
+ )
483
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
484
+
485
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
486
+ outputs = self.model(
487
+ input_ids=input_ids,
488
+ attention_mask=attention_mask,
489
+ position_ids=position_ids,
490
+ past_key_values=past_key_values,
491
+ inputs_embeds=inputs_embeds,
492
+ use_cache=use_cache,
493
+ output_attentions=output_attentions,
494
+ output_hidden_states=output_hidden_states,
495
+ return_dict=return_dict,
496
+ )
497
+
498
+ hidden_states = outputs[0]
499
+ logits = self.lm_head(hidden_states)
500
+
501
+ loss = None
502
+ if labels is not None:
503
+ # Shift so that tokens < n predict n
504
+ shift_logits = logits[..., :-1, :].contiguous()
505
+ shift_labels = labels[..., 1:].contiguous()
506
+ # Flatten the tokens
507
+ loss_fct = CrossEntropyLoss()
508
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
509
+ shift_labels = shift_labels.view(-1)
510
+ # Enable model parallelism
511
+ shift_labels = shift_labels.to(shift_logits.device)
512
+ loss = loss_fct(shift_logits, shift_labels)
513
+
514
+ if not return_dict:
515
+ output = (logits,) + outputs[1:]
516
+ return (loss,) + output if loss is not None else output
517
+
518
+ return CausalLMOutputWithPast(
519
+ loss=loss,
520
+ logits=logits,
521
+ past_key_values=outputs.past_key_values,
522
+ hidden_states=outputs.hidden_states,
523
+ attentions=outputs.attentions,
524
+ )
525
+
526
+ def _reorder_cache(self, past_key_values, beam_idx):
527
+ reordered_past = ()
528
+ for layer_past in past_key_values:
529
+ reordered_past += (
530
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
531
+ )
532
+ return reordered_past
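+ # Illustrative note: generate() calls _reorder_cache during beam search so that
+ # each layer's cached key/value tensors follow their beam when hypotheses are
+ # reordered between decoding steps.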
533
+
534
+ AutoModel.register(SapnousT1Config, SapnousT1Model)
535
+ AutoModelForCausalLM.register(SapnousT1Config, SapnousT1ForCausalLM)
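For orientation, a minimal usage sketch of the causal-LM class added above: it assumes the `Sapnous-AI/Sapnous-VR-6B` checkpoint named in the docstring example and that this modeling code is picked up via `trust_remote_code=True`; passing `labels` exercises the shift-by-one cross-entropy computed in `forward`.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Sapnous-AI/Sapnous-VR-6B"  # checkpoint id taken from the docstring example
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("Hello, world!", return_tensors="pt")
# Reusing input_ids as labels triggers the internal next-token cross-entropy loss.
outputs = model(**inputs, labels=inputs["input_ids"])
print(outputs.loss.item(), outputs.logits.shape)
```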
models/sapnous/test_tokenization_sapnous.py ADDED
@@ -0,0 +1,52 @@
1
+ # coding=utf-8
2
+ # Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import unittest
16
+ from transformers import AutoTokenizer
17
+ from .configuration_sapnous import SapnousT1Config
18
+ from .tokenization_sapnous import SapnousT1Tokenizer
19
+
20
+ class TestSapnousTokenizer(unittest.TestCase):
21
+ @classmethod
22
+ def setUpClass(cls):
23
+ cls.tokenizer = SapnousT1Tokenizer(
24
+ vocab_file="vocab.json",
25
+ merges_file="merges.txt"
26
+ )
27
+
28
+ def test_tokenizer_from_pretrained(self):
29
+ tokenizer = AutoTokenizer.from_pretrained(
30
+ "Sapnous-AI/Sapnous-VR-6B",
31
+ trust_remote_code=True
32
+ )
33
+ self.assertIsInstance(tokenizer, SapnousT1Tokenizer)
34
+
35
+ def test_save_load_pretrained(self):
36
+ vocab = self.tokenizer.get_vocab()
37
+ self.assertIsInstance(vocab, dict)
38
+ self.assertGreater(len(vocab), 0)
39
+
40
+ def test_tokenization(self):
41
+ text = "Hello, world!"
42
+ tokens = self.tokenizer.tokenize(text)
43
+ self.assertIsInstance(tokens, list)
44
+ self.assertGreater(len(tokens), 0)
45
+
46
+ def test_special_tokens(self):
47
+ self.assertIsNotNone(self.tokenizer.unk_token)
48
+ self.assertIsNotNone(self.tokenizer.bos_token)
49
+ self.assertIsNotNone(self.tokenizer.eos_token)
50
+
51
+ if __name__ == '__main__':
52
+ unittest.main()
models/sapnous/tokenization_sapnous.py ADDED
@@ -0,0 +1,91 @@
1
+ # coding=utf-8
2
+ # Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import Dict, List, Optional, Tuple
16
+
17
+ from transformers.tokenization_utils import PreTrainedTokenizer
18
+ from transformers.utils import logging
19
+ from transformers import AutoTokenizer
+ from .configuration_sapnous import SapnousT1Config
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ SAPNOUS_PRETRAINED_VOCAB_FILES_MAP = {
24
+ "vocab_file": {
25
+ "Sapnous-AI/Sapnous-VR-6B": "https://huggingface.co/Sapnous-AI/Sapnous-VR-6B/resolve/main/vocab.json",
26
+ },
27
+ "merges_file": {
28
+ "Sapnous-AI/Sapnous-VR-6B": "https://huggingface.co/Sapnous-AI/Sapnous-VR-6B/resolve/main/merges.txt",
29
+ },
30
+ }
31
+
32
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
33
+ "Sapnous-AI/Sapnous-VR-6B": 128000,
34
+ }
35
+
36
+ class SapnousT1Tokenizer(PreTrainedTokenizer):
37
+ vocab_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
38
+ pretrained_vocab_files_map = SAPNOUS_PRETRAINED_VOCAB_FILES_MAP
39
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
40
+ model_input_names = ["input_ids", "attention_mask"]
41
+
42
+ def __init__(
43
+ self,
44
+ vocab_file,
45
+ merges_file,
46
+ errors="replace",
47
+ unk_token="<|endoftext|>",
48
+ bos_token="<|endoftext|>",
49
+ eos_token="<|endoftext|>",
50
+ pad_token=None,
51
+ add_prefix_space=False,
52
+ **kwargs
53
+ ):
54
+ super().__init__(
55
+ errors=errors,
56
+ unk_token=unk_token,
57
+ bos_token=bos_token,
58
+ eos_token=eos_token,
59
+ pad_token=pad_token,
60
+ add_prefix_space=add_prefix_space,
61
+ **kwargs,
62
+ )
63
+
64
+ self.vocab_file = vocab_file
65
+ self.merges_file = merges_file
66
+ self.add_prefix_space = add_prefix_space
67
+
68
+ @property
69
+ def vocab_size(self) -> int:
70
+ return len(self.encoder)
71
+
72
+ def get_vocab(self) -> Dict[str, int]:
73
+ return dict(self.encoder, **self.added_tokens_encoder)
74
+
75
+ def _tokenize(self, text: str) -> List[str]:
76
+ """ Tokenize a string. """
77
+ raise NotImplementedError("Implement in subclass")
78
+
79
+ def _convert_token_to_id(self, token: str) -> int:
80
+ """ Converts a token to an id using the vocab. """
81
+ raise NotImplementedError("Implement in subclass")
82
+
83
+ def _convert_id_to_token(self, index: int) -> str:
84
+ """ Converts an index (integer) to a token. """
85
+ raise NotImplementedError("Implement in subclass")
86
+
87
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, str]:
88
+ """ Save the vocabulary and special tokens file to a directory. """
89
+ raise NotImplementedError("Implement in subclass")
90
+
91
+ AutoTokenizer.register(SapnousT1Config, SapnousT1Tokenizer)
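The `_tokenize`, `_convert_token_to_id`, `_convert_id_to_token`, and `save_vocabulary` hooks above deliberately raise `NotImplementedError`. A hypothetical, minimal subclass, purely to illustrate what a concrete implementation has to supply (whitespace splitting over a plain JSON vocab is a stand-in, not the project's real scheme):

```python
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple


class ToyWhitespaceTokenizer(SapnousT1Tokenizer):
    """Illustration only: whitespace tokens looked up in a JSON vocab."""

    def __init__(self, vocab_file, merges_file, **kwargs):
        # Load the vocab before the base __init__ so get_vocab() works during setup.
        self.encoder: Dict[str, int] = json.loads(Path(vocab_file).read_text(encoding="utf-8"))
        self.decoder = {i: t for t, i in self.encoder.items()}
        super().__init__(vocab_file, merges_file, **kwargs)

    def _tokenize(self, text: str) -> List[str]:
        return text.split()

    def _convert_token_to_id(self, token: str) -> int:
        return self.encoder.get(token, self.encoder.get(self.unk_token, 0))

    def _convert_id_to_token(self, index: int) -> str:
        return self.decoder.get(index, self.unk_token)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        out = Path(save_directory) / f"{filename_prefix or ''}vocab.json"
        out.write_text(json.dumps(self.encoder, ensure_ascii=False), encoding="utf-8")
        return (str(out),)
```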
test_modeling_sapnous.py ADDED
@@ -0,0 +1,92 @@
1
+ # coding=utf-8
2
+ # Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import unittest
16
+ import torch
17
+ from transformers import AutoModelForCausalLM, AutoTokenizer
18
+ from .modeling_sapnous import SapnousT1ForCausalLM
19
+ from .configuration_sapnous import SapnousT1Config
20
+
21
+ class TestSapnousModel(unittest.TestCase):
22
+ @classmethod
23
+ def setUpClass(cls):
24
+ cls.config = SapnousT1Config(
25
+ vocab_size=32000,
26
+ hidden_size=768,
27
+ num_hidden_layers=12,
28
+ num_attention_heads=12,
29
+ intermediate_size=3072
30
+ )
31
+ cls.model = SapnousT1ForCausalLM(cls.config)
32
+
33
+ def test_model_forward(self):
34
+ input_ids = torch.randint(0, self.config.vocab_size, (1, 10))
35
+ outputs = self.model(input_ids)
36
+
37
+ self.assertIsNotNone(outputs)
38
+ self.assertTrue(hasattr(outputs, 'logits'))
39
+ self.assertEqual(outputs.logits.shape, (1, 10, self.config.vocab_size))
40
+
41
+ def test_weight_tying(self):
42
+ self.model.tie_weights()
43
+ self.assertTrue(torch.equal(self.model.lm_head.weight, self.model.get_input_embeddings().weight))
44
+
45
+ def test_auto_model_registration(self):
46
+ model = AutoModelForCausalLM.from_config(self.config)
47
+ self.assertIsInstance(model, SapnousT1ForCausalLM)
48
+
49
+ def test_vision_embeddings(self):
50
+ # Test vision input processing
51
+ batch_size = 1
52
+ pixel_values = torch.randn(batch_size, 3, 224, 224)
53
+ input_ids = torch.randint(0, self.config.vocab_size, (batch_size, 10))
54
+
55
+ outputs = self.model(input_ids=input_ids, pixel_values=pixel_values)
56
+ self.assertIsNotNone(outputs)
57
+ self.assertTrue(hasattr(outputs, 'logits'))
58
+
59
+ # Vision input should increase sequence length
60
+ expected_seq_length = 10 + (224 // 16) ** 2 + 1 # text_len + num_patches + cls_token
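+ # i.e. (224 // 16) ** 2 = 196 patch tokens, so 10 + 196 + 1 = 207 positions in total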
61
+ self.assertEqual(outputs.logits.shape, (batch_size, expected_seq_length, self.config.vocab_size))
62
+
63
+ def test_attention_mask(self):
64
+ # Test attention mask handling
65
+ batch_size = 2
66
+ seq_length = 15
67
+ input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_length))
68
+ attention_mask = torch.ones(batch_size, seq_length)
69
+ attention_mask[:, -5:] = 0 # Mask out last 5 tokens
70
+
71
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
72
+ self.assertIsNotNone(outputs)
73
+ self.assertEqual(outputs.logits.shape, (batch_size, seq_length, self.config.vocab_size))
74
+
75
+ def test_generation_with_vision(self):
76
+ # Test text generation with vision input
77
+ pixel_values = torch.randn(1, 3, 224, 224)
78
+ input_ids = torch.randint(0, self.config.vocab_size, (1, 5))
79
+
80
+ outputs = self.model.generate(
81
+ input_ids=input_ids,
82
+ pixel_values=pixel_values,
83
+ max_length=20,
84
+ num_beams=1
85
+ )
86
+
87
+ self.assertIsInstance(outputs, torch.Tensor)
88
+ self.assertEqual(outputs.dim(), 2)
89
+ self.assertTrue(outputs.size(1) <= 20)
90
+
91
+ if __name__ == '__main__':
92
+ unittest.main()
test_tokenization_sapnous.py ADDED
@@ -0,0 +1,157 @@
1
+ # coding=utf-8
2
+ # Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import unittest
16
+ import torch
17
+ from pathlib import Path
18
+ from transformers import AutoTokenizer
19
+ from .tokenization_sapnous import SapnousTokenizer
20
+
21
+ class TestSapnousTokenizer(unittest.TestCase):
22
+ @classmethod
23
+ def setUpClass(cls):
24
+ # Create temporary vocab and merges files for testing
25
+ cls.temp_dir = Path('test_tokenizer_files')
26
+ cls.temp_dir.mkdir(exist_ok=True)
27
+
28
+ # Create a simple test vocabulary
29
+ cls.vocab_file = cls.temp_dir / 'vocab.json'
30
+ cls.vocab = {
31
+ '<|endoftext|>': 0,
32
+ '<|startoftext|>': 1,
33
+ '<|pad|>': 2,
34
+ '<|vision_start|>': 3,
35
+ '<|vision_end|>': 4,
36
+ '<|image|>': 5,
37
+ '<|video|>': 6,
38
+ 'hello': 7,
39
+ 'world': 8,
40
+ 'test': 9,
41
+ }
42
+ with cls.vocab_file.open('w', encoding='utf-8') as f:
43
+ import json
44
+ json.dump(cls.vocab, f)
45
+
46
+ # Create test merges file
47
+ cls.merges_file = cls.temp_dir / 'merges.txt'
48
+ merges_content = "#version: 0.2\nh e\ne l\nl l\no w\nw o\no r\nr l\nl d"
49
+ cls.merges_file.write_text(merges_content)
50
+
51
+ # Initialize tokenizer
52
+ cls.tokenizer = SapnousTokenizer(
53
+ str(cls.vocab_file),
54
+ str(cls.merges_file),
55
+ )
56
+
57
+ @classmethod
58
+ def tearDownClass(cls):
59
+ # Clean up temporary files
60
+ import shutil
61
+ shutil.rmtree(cls.temp_dir)
62
+
63
+ def test_tokenizer_initialization(self):
64
+ self.assertEqual(self.tokenizer.vocab_size, len(self.vocab))
65
+ self.assertEqual(self.tokenizer.get_vocab(), self.vocab)
66
+
67
+ # Test special tokens
68
+ self.assertEqual(self.tokenizer.unk_token, '<|endoftext|>')
69
+ self.assertEqual(self.tokenizer.bos_token, '<|startoftext|>')
70
+ self.assertEqual(self.tokenizer.eos_token, '<|endoftext|>')
71
+ self.assertEqual(self.tokenizer.pad_token, '<|pad|>')
72
+
73
+ def test_tokenization(self):
74
+ text = "hello world test"
75
+ tokens = self.tokenizer.tokenize(text)
76
+ self.assertIsInstance(tokens, list)
77
+ self.assertTrue(all(isinstance(token, str) for token in tokens))
78
+
79
+ # Test encoding
80
+ input_ids = self.tokenizer.encode(text, add_special_tokens=False)
81
+ self.assertIsInstance(input_ids, list)
82
+ self.assertEqual(len(input_ids), 3) # 'hello', 'world', 'test'
83
+
84
+ # Test decoding
85
+ decoded_text = self.tokenizer.decode(input_ids)
86
+ self.assertEqual(decoded_text.strip(), text)
87
+
88
+ def test_special_tokens_handling(self):
89
+ text = "hello world"
90
+ # Test with special tokens
91
+ tokens_with_special = self.tokenizer.encode(text, add_special_tokens=True)
92
+ self.assertTrue(tokens_with_special[0] == self.tokenizer.bos_token_id)
93
+ self.assertTrue(tokens_with_special[-1] == self.tokenizer.eos_token_id)
94
+
95
+ # Test without special tokens
96
+ tokens_without_special = self.tokenizer.encode(text, add_special_tokens=False)
97
+ self.assertNotEqual(tokens_without_special[0], self.tokenizer.bos_token_id)
98
+ self.assertNotEqual(tokens_without_special[-1], self.tokenizer.eos_token_id)
99
+
100
+ def test_vision_tokens(self):
101
+ # Test vision-specific token methods
102
+ text = "This is an image description"
103
+ vision_text = self.tokenizer.prepare_for_vision(text)
104
+ self.assertTrue(vision_text.startswith('<|vision_start|>'))
105
+ self.assertTrue(vision_text.endswith('<|vision_end|>'))
106
+
107
+ image_text = self.tokenizer.prepare_for_image(text)
108
+ self.assertTrue(image_text.startswith('<|image|>'))
109
+
110
+ video_text = self.tokenizer.prepare_for_video(text)
111
+ self.assertTrue(video_text.startswith('<|video|>'))
112
+
113
+ def test_batch_encoding(self):
114
+ texts = ["hello world", "test hello"]
115
+ batch_encoding = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
116
+
117
+ self.assertIsInstance(batch_encoding["input_ids"], torch.Tensor)
118
+ self.assertIsInstance(batch_encoding["attention_mask"], torch.Tensor)
119
+ self.assertEqual(batch_encoding["input_ids"].shape[0], len(texts))
120
+ self.assertEqual(batch_encoding["attention_mask"].shape[0], len(texts))
121
+
122
+ def test_save_and_load(self):
123
+ # Test saving vocabulary
124
+ save_dir = Path('test_save_tokenizer')
125
+ save_dir.mkdir(exist_ok=True)
126
+
127
+ try:
128
+ vocab_files = self.tokenizer.save_vocabulary(str(save_dir))
129
+ self.assertTrue(all(Path(f).exists() for f in vocab_files))
130
+
131
+ # Test loading saved vocabulary
132
+ loaded_tokenizer = SapnousTokenizer(*vocab_files)
133
+ self.assertEqual(loaded_tokenizer.get_vocab(), self.tokenizer.get_vocab())
134
+
135
+ # Test encoding/decoding with loaded tokenizer
136
+ text = "hello world test"
137
+ original_encoding = self.tokenizer.encode(text)
138
+ loaded_encoding = loaded_tokenizer.encode(text)
139
+ self.assertEqual(original_encoding, loaded_encoding)
140
+ finally:
141
+ # Clean up
142
+ import shutil
143
+ shutil.rmtree(save_dir)
144
+
145
+ def test_auto_tokenizer_registration(self):
146
+ # Test if the tokenizer can be loaded using AutoTokenizer
147
+ config = {
148
+ "model_type": "sapnous",
149
+ "vocab_file": str(self.vocab_file),
150
+ "merges_file": str(self.merges_file)
151
+ }
152
+
153
+ tokenizer = AutoTokenizer.from_pretrained(str(self.temp_dir), **config)
154
+ self.assertIsInstance(tokenizer, SapnousTokenizer)
155
+
156
+ if __name__ == '__main__':
157
+ unittest.main()
tokenization_sapnous.py ADDED
@@ -0,0 +1,197 @@
1
+ # coding=utf-8
2
+ # Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import List, Optional, Tuple, Union
16
+ from transformers.tokenization_utils import PreTrainedTokenizer
17
+ from transformers import AutoTokenizer
+ from .configuration_sapnous import SapnousT1Config
18
+ import json
19
+ import regex as re
20
+ from pathlib import Path
21
+ from typing import Dict, List, Optional, Union
22
+
23
+ BYTES_TO_UNICODE_REGEX = re.compile(r"'([^']+)':\s*([0-9]+)")
24
+
25
+ def bytes_to_unicode():
26
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8 + n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
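+ # Note on bytes_to_unicode(): printable bytes map to themselves and the remaining
+ # bytes are shifted into unused code points, e.g. the space byte 0x20 becomes 'Ġ'
+ # (U+0120), so no BPE token string ever contains raw whitespace.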
36
+
37
+ def get_pairs(word):
38
+ pairs = set()
39
+ prev_char = word[0]
40
+ for char in word[1:]:
41
+ pairs.add((prev_char, char))
42
+ prev_char = char
43
+ return pairs
44
+
45
+ class SapnousTokenizer(PreTrainedTokenizer):
46
+ model_input_names = ["input_ids", "attention_mask"]
47
+
48
+ def __init__(
49
+ self,
50
+ vocab_file: str,
51
+ merges_file: Optional[str] = None,
52
+ unk_token: str = "<|endoftext|>",
53
+ bos_token: str = "<|startoftext|>",
54
+ eos_token: str = "<|endoftext|>",
55
+ pad_token: str = "<|pad|>",
56
+ vision_start_token: str = "<|vision_start|>",
57
+ vision_end_token: str = "<|vision_end|>",
58
+ image_token: str = "<|image|>",
59
+ video_token: str = "<|video|>",
60
+ add_prefix_space: bool = False,
61
+ **kwargs
62
+ ):
63
+ super().__init__(
64
+ unk_token=unk_token,
65
+ bos_token=bos_token,
66
+ eos_token=eos_token,
67
+ pad_token=pad_token,
68
+ **kwargs,
69
+ )
70
+
71
+ self.vocab_file = vocab_file
72
+ self.merges_file = merges_file
73
+ self.add_prefix_space = add_prefix_space
74
+
75
+ self.special_tokens = {
76
+ "unk_token": unk_token,
77
+ "bos_token": bos_token,
78
+ "eos_token": eos_token,
79
+ "pad_token": pad_token,
80
+ "vision_start_token": vision_start_token,
81
+ "vision_end_token": vision_end_token,
82
+ "image_token": image_token,
83
+ "video_token": video_token,
84
+ }
+
+ # Keep the multimodal markers accessible as attributes; the prepare_for_*
+ # helpers below rely on them.
+ self.vision_start_token = vision_start_token
+ self.vision_end_token = vision_end_token
+ self.image_token = image_token
+ self.video_token = video_token
85
+
86
+ with Path(vocab_file).open(encoding="utf-8") as f:
87
+ self.encoder = json.load(f)
88
+ self.decoder = {v: k for k, v in self.encoder.items()}
89
+
90
+ if merges_file:
91
+ with Path(merges_file).open(encoding="utf-8") as f:
92
+ bpe_merges = f.read().strip().split('\n')[1:]
93
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
94
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
95
+ else:
96
+ self.bpe_ranks = {}
97
+
98
+ self.byte_encoder = bytes_to_unicode()
99
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
100
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?\d+| ?[^\s\w\d]+|\s+(?!\S)|\s+""")
101
+
102
+ def bpe(self, token: str) -> str:
103
+ if token in self.special_tokens.values():
104
+ return token
105
+
106
+ word = tuple(token)
107
+ pairs = get_pairs(word)
108
+
109
+ if not pairs:
110
+ return token
111
+
112
+ while True:
113
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
114
+ if bigram not in self.bpe_ranks:
115
+ break
116
+ first, second = bigram
117
+ new_word = []
118
+ i = 0
119
+ while i < len(word):
120
+ try:
121
+ j = word.index(first, i)
122
+ new_word.extend(word[i:j])
123
+ if j < len(word) - 1 and word[j + 1] == second:
124
+ new_word.append(first + second)
125
+ i = j + 2
126
+ else:
127
+ new_word.append(word[j])
128
+ i = j + 1
129
+ except ValueError:
130
+ new_word.extend(word[i:])
131
+ break
132
+ word = tuple(new_word)
133
+ if len(word) == 1:
134
+ break
135
+ pairs = get_pairs(word)
136
+ return ' '.join(word)
137
+
138
+ def _tokenize(self, text: str) -> List[str]:
139
+ if self.add_prefix_space:
140
+ text = ' ' + text
141
+
142
+ bpe_tokens = []
143
+ for token in re.findall(self.pat, text):
144
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
145
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
146
+ return bpe_tokens
147
+
148
+ def _convert_token_to_id(self, token: str) -> int:
149
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
150
+
151
+ def _convert_id_to_token(self, index: int) -> str:
152
+ return self.decoder.get(index, self.unk_token)
153
+
154
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
155
+ text = ''.join(tokens)
156
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
157
+ return text
158
+
159
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
160
+ if not filename_prefix:
161
+ filename_prefix = ""
162
+
163
+ vocab_file = Path(save_directory) / f"{filename_prefix}vocab.json"
164
+ merge_file = Path(save_directory) / f"{filename_prefix}merges.txt"
165
+
166
+ with vocab_file.open('w', encoding='utf-8') as f:
167
+ json.dump(self.encoder, f, ensure_ascii=False)
168
+
169
+ if self.merges_file:
170
+ with merge_file.open('w', encoding='utf-8') as f:
171
+ for merge in self.bpe_ranks:
172
+ f.write(f"{merge[0]} {merge[1]}\n")
173
+ return str(vocab_file), str(merge_file)
174
+
175
+ return (str(vocab_file),)
176
+
177
+ def prepare_for_vision(self, text: str) -> str:
178
+ """Prepare text for vision tasks by adding special tokens."""
179
+ return f"{self.vision_start_token}{text}{self.vision_end_token}"
180
+
181
+ def prepare_for_image(self, text: str) -> str:
182
+ """Prepare text for image tasks."""
183
+ return f"{self.image_token}{text}"
184
+
185
+ def prepare_for_video(self, text: str) -> str:
186
+ """Prepare text for video tasks."""
187
+ return f"{self.video_token}{text}"
188
+
189
+ @property
190
+ def vocab_size(self) -> int:
191
+ return len(self.encoder)
192
+
193
+ def get_vocab(self) -> Dict[str, int]:
194
+ return self.encoder.copy()
195
+
196
+ # Register the tokenizer
197
+ AutoTokenizer.register(SapnousT1Config, slow_tokenizer_class=SapnousTokenizer)
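Finally, a hedged local-usage sketch of the byte-level BPE tokenizer defined in this file; the `vocab.json`/`merges.txt` paths are placeholders for real files, and `prepare_for_vision` simply wraps a caption in the multimodal marker tokens:

```python
from tokenization_sapnous import SapnousTokenizer

# Placeholder paths: point these at an actual vocab.json / merges.txt pair.
tokenizer = SapnousTokenizer("vocab.json", "merges.txt")

ids = tokenizer.encode("hello world", add_special_tokens=False)
print(ids)                    # token ids from the byte-level BPE vocab
print(tokenizer.decode(ids))  # decoded back to text via the byte-level mapping

# Wrap a caption in <|vision_start|> ... <|vision_end|> before pairing it with an image.
print(tokenizer.prepare_for_vision("a photo of a cat"))
```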