nvedant07 committed
Commit f9972a2 · verified · 1 Parent(s): 58634eb

Upload 3 files


Added inference
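For reference, a minimal inference sketch against this upload. This is a hedged example: it assumes the checkpoint is published with an auto_map so it loads via trust_remote_code, that it runs on a CUDA device in bfloat16, and the repo id below is a placeholder. The helpers _prepare_input and generate are defined in model.py.

import torch
from transformers import AutoModelForCausalLM

# Placeholder repo id; substitute the actual model repository.
model = AutoModelForCausalLM.from_pretrained(
    "<org>/<hat-checkpoint>",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda")
model.eval()

# _prepare_input applies the Llama chat template and splits the prompt into
# byte-level words; generate then decodes word by word using the HAT caches.
input_ids, cumulative_word_lengths = model._prepare_input("Explain byte-level language models.")
output = model.generate(
    input_ids=input_ids,
    max_new_tokens=128,
    cumulative_seq_lengths_per_word=cumulative_word_lengths,
    use_cache=True,
)
print(output.completion_text)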

Files changed (3)
  1. __init__.py +0 -0
  2. model.py +916 -0
  3. splitter.py +45 -0
__init__.py ADDED
File without changes
model.py ADDED
@@ -0,0 +1,916 @@
1
+ import itertools
2
+ from collections.abc import Sequence
3
+ from importlib.metadata import PackageNotFoundError, version
4
+ from typing import Callable
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
10
+ from transformers import PreTrainedModel
11
+ from transformers.cache_utils import Cache, DynamicCache
12
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
+ from transformers.models.llama.modeling_llama import (
14
+ LlamaDecoderLayer,
15
+ LlamaRotaryEmbedding,
16
+ )
17
+ from transformers.utils import ModelOutput
18
+
19
+ from .config import (
20
+ CrossAttentionConfig,
21
+ DecoderHATModelConfig,
22
+ EncoderHATModelConfig,
23
+ HATArchitectureConfig,
24
+ TransformerHATModelConfig,
25
+ )
26
+ from .splitter import HATSplitter
27
+
28
+ try:
29
+ transformers_version = version("transformers")
30
+ if transformers_version != "4.46.3":
31
+ print(f"Warning: Expected transformers version 4.46.3, but found {transformers_version}. Outputs might be different.")
32
+ except PackageNotFoundError:
33
+ print("transformers is not installed")
34
+
35
+
36
+ def sample_argmax(logits: torch.Tensor) -> torch.Tensor:
37
+ return torch.argmax(logits, dim=-1)[:, -1]
38
+
39
+
40
+ LLAMA_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
41
+ You are a helpful assistant. You give engaging, well-structured answers to user inquiries.<|eot_id|><|start_header_id|>user<|end_header_id|>
42
+ {input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
43
+
44
+
45
+ class HATCache(Cache):
46
+ encoder_cache: DynamicCache
47
+ backbone_cache: DynamicCache
48
+ decoder_cache: DynamicCache
49
+
50
+ def __init__(self, *args, **kwargs):
51
+ super().__init__(*args, **kwargs)
52
+ self.encoder_cache = DynamicCache()
53
+ self.backbone_cache = DynamicCache()
54
+ self.decoder_cache = DynamicCache()
55
+
56
+ def get_backbone_cache(self) -> DynamicCache:
57
+ return self.backbone_cache
58
+
59
+ def get_decoder_cache(self) -> DynamicCache:
60
+ return self.decoder_cache
61
+
62
+ def get_encoder_cache(self) -> DynamicCache:
63
+ return self.encoder_cache
64
+
65
+
66
+ def rotate_half(x):
67
+ """Rotates half the hidden dims of the input."""
68
+ x1 = x[..., : x.shape[-1] // 2]
69
+ x2 = x[..., x.shape[-1] // 2 :]
70
+ return torch.cat((-x2, x1), dim=-1)
71
+
72
+
73
+ def apply_rotary_pos_emb(q, k, q_cos=None, q_sin=None, k_cos=None, k_sin=None, unsqueeze_dim=1):
74
+ """Applies Rotary Position Embedding to the query and key tensors,
75
+ allowing the query and key to have different sequence lengths.
76
+ Args:
77
+ q (`torch.Tensor`): The query tensor.
78
+ k (`torch.Tensor`): The key tensor.
79
+ q_cos (`torch.Tensor`): The cosine part of the rotary embedding.
80
+ q_sin (`torch.Tensor`): The sine part of the rotary embedding.
81
+ k_cos (`torch.Tensor`): The cosine part of the rotary embedding.
82
+ k_sin (`torch.Tensor`): The sine part of the rotary embedding.
83
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
84
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze
85
+ cos[position_ids] and sin[position_ids] so that they can be properly
86
+ broadcasted to the dimensions of q and k. For example, note
87
+ that cos[position_ids] and sin[position_ids] have the shape
88
+ [batch_size, seq_len, head_dim]. Then, if q and
89
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting
90
+ unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids]
91
+ broadcastable to the shapes of q and k. Similarly, if q and k have
92
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
93
+ Returns:
94
+ `tuple(torch.Tensor)` comprising of the query and key
95
+ tensors rotated using the Rotary Position Embedding.
96
+ """
97
+
98
+ q_cos = q_cos.unsqueeze(unsqueeze_dim)
99
+ q_sin = q_sin.unsqueeze(unsqueeze_dim)
100
+ k_cos = k_cos.unsqueeze(unsqueeze_dim)
101
+ k_sin = k_sin.unsqueeze(unsqueeze_dim)
102
+ q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
103
+ k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
104
+
105
+ return q_embed, k_embed
106
+
107
+
108
+ class HATBackbone(nn.Module):
109
+ def __init__(self, config: TransformerHATModelConfig, *args, **kwargs):
110
+ super().__init__(*args, **kwargs)
111
+
112
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
113
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
114
+
115
+ def forward(
116
+ self,
117
+ hidden_states: torch.Tensor,
118
+ position_ids: torch.Tensor | None = None,
119
+ past_key_values: DynamicCache | None = None,
120
+ use_cache: bool | None = False,
121
+ ) -> BaseModelOutputWithPast:
122
+ if use_cache and past_key_values is None:
123
+ past_key_values = DynamicCache()
124
+
125
+ if position_ids is None:
126
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
127
+ position_ids = torch.arange(
128
+ past_seen_tokens,
129
+ past_seen_tokens + hidden_states.shape[1],
130
+ device=hidden_states.device,
131
+ ).unsqueeze(0)
132
+
133
+ # create position embeddings to be shared across the decoder layers
134
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
135
+
136
+ for backbone_layer in self.layers:
137
+ layer_outputs = backbone_layer(
138
+ hidden_states,
139
+ position_ids=position_ids,
140
+ past_key_value=past_key_values,
141
+ use_cache=use_cache,
142
+ position_embeddings=position_embeddings,
143
+ )
144
+ hidden_states = layer_outputs[0]
145
+
146
+ return CausalLMOutputWithPast(
147
+ hidden_states=hidden_states,
148
+ past_key_values=past_key_values if use_cache else None,
149
+ )
150
+
151
+
152
+ class HATDecoderConnector(nn.Module):
153
+ def __init__(self, backbone_hidden_dim: int, *args, **kwargs):
154
+ super().__init__(*args, **kwargs)
155
+ self.first_word_embedding = torch.nn.Parameter(
156
+ torch.empty(
157
+ 1,
158
+ 1,
159
+ backbone_hidden_dim,
160
+ device="cuda",
161
+ dtype=torch.bfloat16,
162
+ )
163
+ )
164
+
165
+ def forward(
166
+ self,
167
+ backbone_activations: torch.Tensor,
168
+ ):
169
+ # Shift backbone activations right by one word: overwrite the last word slot with
+ # the learned first-word embedding, then roll so the bytes of word i cross-attend
+ # to the backbone output for word i-1.
+ activations = backbone_activations.clone()
170
+ activations[:, -1:, :] = self.first_word_embedding
171
+ activations = torch.roll(activations, shifts=1, dims=1)
172
+ return activations
173
+
174
+
175
+ class RMSNorm(nn.Module):
176
+ def __init__(self, dimensions: int, eps: float, device: torch.device, dtype: torch.dtype = torch.bfloat16, norm_in_fp32: bool = False):
177
+ super().__init__()
178
+ self.eps = eps
179
+ self.weight = torch.nn.Parameter(torch.ones(dimensions, dtype=dtype).to(device))
180
+ self.norm_in_fp32 = norm_in_fp32
181
+
182
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
183
+ original_dtype = x.dtype
184
+ if self.norm_in_fp32:
185
+ x = x.float()
186
+
187
+ out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
188
+
189
+ if out.dtype != original_dtype:
190
+ out = out.to(original_dtype)
191
+
192
+ return out * self.weight
193
+
194
+
195
+ class HATDecoderBlock(nn.Module):
196
+ def __init__(
197
+ self,
198
+ add_cross_attention: bool,
199
+ config: DecoderHATModelConfig,
200
+ layer_idx: int,
201
+ *args,
202
+ **kwargs,
203
+ ):
204
+ super().__init__(*args, **kwargs)
205
+ self.add_cross_attention = add_cross_attention
206
+ self.config = config
207
+ self.llama_layer = LlamaDecoderLayer(config, layer_idx)
208
+ self.llama_layer.self_attn.sliding_window = config.sliding_window
209
+ if add_cross_attention:
210
+ self.cross_attention = HATCrossAttention(
211
+ hidden_size=config.cross_attention_config.hidden_size,
212
+ hidden_size_kv=config.cross_attention_config.hidden_size_kv,
213
+ hidden_size_q=config.cross_attention_config.hidden_size_q,
214
+ config=config,
215
+ cross_attention_config=config.cross_attention_config,
216
+ )
217
+
218
+ self.query_norm = RMSNorm(
219
+ config.cross_attention_config.hidden_size_q,
220
+ eps=config.rms_norm_eps,
221
+ device=torch.device("cuda"),
222
+ dtype=torch.bfloat16,
223
+ norm_in_fp32=False,
224
+ )
225
+
226
+ self.kv_norm = RMSNorm(
227
+ config.cross_attention_config.hidden_size_kv,
228
+ eps=config.rms_norm_eps,
229
+ device=torch.device("cuda"),
230
+ dtype=torch.bfloat16,
231
+ norm_in_fp32=False,
232
+ )
233
+
234
+ def apply_norm(self, activations):
235
+ return self.query_norm(activations), self.kv_norm(activations)
236
+
237
+ def forward(
238
+ self,
239
+ encoder_activations,
240
+ backbone_activations,
241
+ byte_position_ids,
242
+ word_position_ids,
243
+ cumulative_seq_lengths_per_word,
244
+ position_embeddings,
245
+ past_key_values,
246
+ use_cache,
247
+ ):
248
+ if self.add_cross_attention:
249
+ kv_activations = self.kv_norm(backbone_activations)
250
+ q_activations = self.query_norm(encoder_activations)
251
+
252
+ activations = self.cross_attention.forward(
253
+ q_activations=q_activations,
254
+ kv_activations=kv_activations,
255
+ position_ids_q=byte_position_ids,
256
+ position_ids_kv=word_position_ids,
257
+ cumulative_seq_q=cumulative_seq_lengths_per_word,
258
+ cumulative_seq_kv=torch.arange(0, kv_activations.size(1) + 1, device=encoder_activations.device, dtype=torch.int32),
259
+ causal=False,
260
+ )
261
+ encoder_activations = encoder_activations + activations
262
+
263
+ return self.llama_layer.forward(
264
+ hidden_states=encoder_activations,
265
+ position_ids=byte_position_ids,
266
+ position_embeddings=position_embeddings,
267
+ past_key_value=past_key_values,
268
+ use_cache=use_cache,
269
+ )[0]
270
+
271
+
272
+ class HATDecoder(nn.Module):
273
+ def __init__(self, config: DecoderHATModelConfig, *args, **kwargs):
274
+ super().__init__()
275
+
276
+ self.decoder_layers = nn.Sequential()
277
+ for layer_idx in range(config.num_hidden_layers):
278
+ add_cross_attention = config.cross_attn_every_layer or layer_idx == 0
279
+ self.decoder_layers.add_module(
280
+ str(layer_idx),
281
+ HATDecoderBlock(
282
+ add_cross_attention,
283
+ config,
284
+ layer_idx,
285
+ ),
286
+ )
287
+
288
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
289
+
290
+ def forward(
291
+ self,
292
+ backbone_activations: torch.Tensor,
293
+ activations: torch.Tensor,
294
+ cumulative_seq_lengths_per_word: torch.Tensor | None = None,
295
+ byte_position_ids: torch.Tensor | None = None,
296
+ word_position_ids: torch.Tensor | None = None,
297
+ past_key_values: DynamicCache | None = None,
298
+ use_cache: bool | None = False,
299
+ ) -> BaseModelOutputWithPast:
300
+ if use_cache and past_key_values is None:
301
+ past_key_values = DynamicCache()
302
+
303
+ if byte_position_ids is None:
304
+ past_seen_bytes = past_key_values.get_seq_length() if past_key_values is not None else 0
305
+ byte_position_ids = torch.arange(
306
+ past_seen_bytes,
307
+ past_seen_bytes + activations.size(1),
308
+ device=activations.device,
309
+ dtype=torch.int32,
310
+ ).unsqueeze(0)
311
+
312
+ if cumulative_seq_lengths_per_word is None:
313
+ cumulative_seq_lengths_per_word = torch.tensor([0, byte_position_ids.size(1)], dtype=byte_position_ids.dtype, device=byte_position_ids.device)
314
+
315
+ if word_position_ids is None:
316
+ raise ValueError("word_position_ids must be provided")  # TODO: derive word_position_ids when not provided
317
+
318
+ position_embeddings = self.rotary_emb(activations, byte_position_ids)
319
+
320
+ for _, layer in enumerate(self.decoder_layers):
321
+ activations = layer(
322
+ encoder_activations=activations,
323
+ backbone_activations=backbone_activations,
324
+ position_embeddings=position_embeddings,
325
+ cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
326
+ byte_position_ids=byte_position_ids,
327
+ word_position_ids=word_position_ids,
328
+ past_key_values=past_key_values,
329
+ use_cache=use_cache,
330
+ )
331
+
332
+ return BaseModelOutputWithPast(
333
+ last_hidden_state=activations,
334
+ past_key_values=past_key_values if use_cache else None,
335
+ )
336
+
337
+
338
+ class HATCrossAttention(nn.Module):
339
+ def __init__(
340
+ self,
341
+ hidden_size: int,
342
+ hidden_size_q: int,
343
+ hidden_size_kv: int,
344
+ config: EncoderHATModelConfig | DecoderHATModelConfig,
345
+ cross_attention_config: CrossAttentionConfig,
346
+ dtype: torch.dtype = torch.bfloat16,
347
+ ):
348
+ super().__init__()
349
+ self.hidden_size = hidden_size
350
+ self.hidden_size_q = hidden_size_q
351
+ self.hidden_size_kv = hidden_size_kv
352
+ self.num_heads = cross_attention_config.num_attention_heads
353
+ self.num_key_value_heads = cross_attention_config.attention_num_kv_heads
354
+ self.num_repeat_kv = cross_attention_config.num_attention_heads // cross_attention_config.attention_num_kv_heads
355
+ self.head_dim = hidden_size // self.num_heads
356
+
357
+ self.q_proj = nn.Linear(
358
+ in_features=hidden_size_q,
359
+ out_features=hidden_size,
360
+ dtype=dtype,
361
+ bias=False,
362
+ )
363
+
364
+ self.k_proj = nn.Linear(
365
+ in_features=hidden_size_kv,
366
+ out_features=hidden_size // self.num_repeat_kv,
367
+ dtype=dtype,
368
+ bias=False,
369
+ )
370
+
371
+ self.v_proj = nn.Linear(
372
+ in_features=hidden_size_kv,
373
+ out_features=hidden_size // self.num_repeat_kv,
374
+ dtype=dtype,
375
+ bias=False,
376
+ )
377
+
378
+ self.o_proj = nn.Linear(in_features=hidden_size, out_features=hidden_size_q, dtype=dtype, bias=False)
379
+
380
+ rope_theta = config.rope_theta
381
+ rope_type = config.rope_scaling["rope_type"]
382
+
383
+ self.rotary_emb = LlamaRotaryEmbedding(dim=self.head_dim, base=rope_theta, rope_type=rope_type)
384
+
385
+ def forward(
386
+ self,
387
+ q_activations: torch.Tensor,
388
+ kv_activations: torch.Tensor,
389
+ position_ids_q: torch.Tensor,
390
+ position_ids_kv: torch.Tensor,
391
+ cumulative_seq_kv: torch.Tensor,
392
+ cumulative_seq_q: torch.Tensor,
393
+ causal: bool = True,
394
+ use_cache: bool = False,
395
+ past_key_value: DynamicCache | None = None,
396
+ ):
397
+ q_len = cumulative_seq_q[-1]
398
+
399
+ bsz, _, _ = kv_activations.size()
400
+ query_states = self.q_proj(q_activations)
401
+ key_states = self.k_proj(kv_activations)
402
+ value_states = self.v_proj(kv_activations)
403
+
404
+ # TODO get rid of the double rearrange, this is just for compatibility with scaling
405
+ query_states = rearrange(query_states, "bsz seq_len (h d) -> bsz h seq_len d", h=self.num_heads)
406
+ key_states = rearrange(
407
+ key_states,
408
+ "bsz seq_len (h d) -> bsz h seq_len d",
409
+ h=self.num_key_value_heads,
410
+ )
411
+ value_states = rearrange(
412
+ value_states,
413
+ "bsz seq_len (h d) -> bsz h seq_len d",
414
+ h=self.num_key_value_heads,
415
+ )
416
+
417
+ # WIP: Should word_position_ids respect document boundaries?
418
+ q_cos, q_sin = self.rotary_emb(query_states, position_ids_q)
419
+ k_cos, k_sin = self.rotary_emb(key_states, position_ids_kv)
420
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, q_cos=q_cos, q_sin=q_sin, k_cos=k_cos, k_sin=k_sin)
421
+
422
+ query_states = rearrange(query_states, "bsz h seq_len d -> (bsz seq_len) h d")
423
+ key_states = rearrange(key_states, "bsz h seq_len d -> (bsz seq_len) h d")
424
+ value_states = rearrange(value_states, "bsz h seq_len d -> (bsz seq_len) h d")
425
+
426
+ attn_output = flash_attn_varlen_func(
427
+ query_states,
428
+ key_states,
429
+ value_states,
430
+ cu_seqlens_q=cumulative_seq_q,
431
+ cu_seqlens_k=cumulative_seq_kv,
432
+ max_seqlen_q=self._get_max_seqlen(cumulative_seq_q),
433
+ max_seqlen_k=self._get_max_seqlen(cumulative_seq_kv),
434
+ causal=False,
435
+ )
436
+
437
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
438
+
439
+ attn_output = self.o_proj(attn_output)
440
+ return attn_output
441
+
442
+ def _get_max_seqlen(self, cumulative_word_lengths: torch.Tensor):
443
+ diffs = cumulative_word_lengths[1:] - cumulative_word_lengths[:-1]
444
+ return int(diffs.max().item())
445
+
446
+
447
+ class HATEncoderConnector(nn.Module):
448
+ def __init__(
449
+ self,
450
+ config: EncoderHATModelConfig,
451
+ backbone_hidden_size: int,
452
+ dtype: torch.dtype = torch.bfloat16,
453
+ *args,
454
+ **kwargs,
455
+ ):
456
+ super().__init__(*args, **kwargs)
457
+ self.latent_query = torch.nn.Parameter(
458
+ torch.empty(
459
+ 1,
460
+ 1,
461
+ backbone_hidden_size,
462
+ device="cuda",
463
+ dtype=dtype,
464
+ )
465
+ )
466
+
467
+ self.cross_attention_encoder_connector = HATCrossAttention(
468
+ hidden_size=config.cross_attention_config.hidden_size,
469
+ hidden_size_q=backbone_hidden_size,
470
+ hidden_size_kv=config.hidden_size,
471
+ config=config,
472
+ cross_attention_config=config.cross_attention_config,
473
+ )
474
+
475
+ def forward(
476
+ self,
477
+ hidden_states: torch.Tensor,
478
+ cumulative_seq_lengths_per_word: torch.Tensor,
479
+ word_position_ids: torch.Tensor,
480
+ byte_position_ids: torch.Tensor,
481
+ ):
482
+ q_len = cumulative_seq_lengths_per_word.shape[0] - 1
483
+ latent_query_repeated = self.latent_query.expand(-1, q_len, -1)
484
+ cumulative_seq_lengths_q = torch.arange(
485
+ start=0,
486
+ end=latent_query_repeated.shape[1] + 1,
487
+ step=1,
488
+ device=self.latent_query.device,
489
+ dtype=torch.int32,
490
+ )
491
+ word_embeddings = self.cross_attention_encoder_connector.forward(
492
+ q_activations=latent_query_repeated,
493
+ kv_activations=hidden_states,
494
+ position_ids_q=word_position_ids,
495
+ position_ids_kv=byte_position_ids,
496
+ cumulative_seq_q=cumulative_seq_lengths_q,
497
+ cumulative_seq_kv=cumulative_seq_lengths_per_word,
498
+ )
499
+ return word_embeddings
500
+
501
+
502
+ class HATEncoder(nn.Module):
503
+ def __init__(
504
+ self,
505
+ config: EncoderHATModelConfig,
506
+ dtype: torch.dtype = torch.bfloat16,
507
+ *args,
508
+ **kwargs,
509
+ ):
510
+ super().__init__(*args, **kwargs)
511
+ self.embedding_layer = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype)
512
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
513
+ for layer in self.layers:
514
+ layer.self_attn.sliding_window = config.sliding_window
515
+
516
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
517
+
518
+ self.word_window_size = config.cross_attention_config.word_window_size
519
+
520
+ def forward(
521
+ self,
522
+ input_ids: torch.Tensor,
523
+ cumulative_seq_lengths_per_word: torch.Tensor | None = None,
524
+ byte_position_ids: torch.Tensor | None = None,
525
+ word_position_ids: torch.Tensor | None = None, # TODO: Remove
526
+ past_key_values: DynamicCache | None = None,
527
+ use_cache: bool | None = False,
528
+ ):
529
+ input_embeds = self.embedding_layer(input_ids)
530
+
531
+ if cumulative_seq_lengths_per_word is None:
532
+ cumulative_seq_lengths_per_word = torch.tensor([0, input_embeds.shape[1]], dtype=torch.int32, device=input_ids.device)
533
+
534
+ if use_cache and past_key_values is None:
535
+ past_key_values = DynamicCache()
536
+
537
+ if byte_position_ids is None:
538
+ past_seen_bytes = past_key_values.get_seq_length() if past_key_values is not None else 0
539
+ byte_position_ids = torch.arange(
540
+ past_seen_bytes,
541
+ past_seen_bytes + input_embeds.shape[1],
542
+ device=input_embeds.device,
543
+ ).unsqueeze(0)
544
+
545
+ if word_position_ids is None:
546
+ raise ValueError("word_position_ids must be provided")  # TODO: derive word_position_ids when not provided
547
+
548
+ hidden_states = input_embeds
549
+
550
+ # create position embeddings to be shared across the decoder layers
551
+ position_embeddings = self.rotary_emb(hidden_states, byte_position_ids)
552
+
553
+ for layer in self.layers:
554
+ layer_outputs = layer(
555
+ hidden_states,
556
+ position_ids=byte_position_ids,
557
+ past_key_value=past_key_values,
558
+ use_cache=use_cache,
559
+ position_embeddings=position_embeddings,
560
+ )
561
+ hidden_states = layer_outputs[0]
562
+
563
+ return CausalLMOutputWithPast(
564
+ hidden_states=hidden_states,
565
+ past_key_values=past_key_values if use_cache else None,
566
+ )
567
+
568
+
569
+ class HATForCausalLM(PreTrainedModel):
570
+ config_class = HATArchitectureConfig
571
+ _supports_flash_attn_2 = True
572
+ _supports_cache_class = True
573
+
574
+ def __init__(self, config: HATArchitectureConfig, *args, **kwargs):
575
+ super().__init__(config, *args, **kwargs)
576
+ self.config = config
577
+ self.eos_token_id = config.eos_token_id
578
+ self.encoder = HATEncoder(config.encoder_config)
579
+ self.encoder_connector = HATEncoderConnector(config.encoder_config, config.backbone_config.hidden_size)
580
+ self.backbone = HATBackbone(config.backbone_config)
581
+ self.decoder_connector = HATDecoderConnector(config.backbone_config.hidden_size)
582
+ self.decoder = HATDecoder(config.decoder_config)
583
+ self.splitter = HATSplitter(special_token_dict=config.special_token_dict, max_word_size=config.max_word_size)
584
+ self.layer_norm = RMSNorm(config.decoder_config.hidden_size, eps=config.decoder_config.rms_norm_eps, device=torch.device("cuda"), dtype=torch.bfloat16, norm_in_fp32=False)
585
+ self.lm_head = nn.Linear(
586
+ in_features=config.decoder_config.hidden_size,
587
+ out_features=config.decoder_config.vocab_size,
588
+ dtype=torch.bfloat16,
589
+ bias=False,
590
+ )
591
+
592
+ def forward(
593
+ self,
594
+ input_ids: torch.Tensor,
595
+ byte_position_ids: torch.Tensor,
596
+ cumulative_seq_lengths_per_word: torch.Tensor | None = None,
597
+ word_position_ids: torch.Tensor | None = None,
598
+ past_key_values: HATCache | None = None,
599
+ use_cache: bool = False,
600
+ ):
601
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
602
+
603
+ if past_key_values is None and use_cache:
604
+ past_key_values = HATCache()
605
+
606
+ encoder_past_key_values = past_key_values.get_encoder_cache() if past_key_values is not None else None
607
+ backbone_past_key_values = past_key_values.get_backbone_cache() if past_key_values is not None else None
608
+ decoder_past_key_values = past_key_values.get_decoder_cache() if past_key_values is not None else None
609
+
610
+ encoder_output: BaseModelOutputWithPast = self.encoder(
611
+ input_ids=input_ids,
612
+ cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
613
+ byte_position_ids=byte_position_ids,
614
+ word_position_ids=word_position_ids,
615
+ past_key_values=encoder_past_key_values,
616
+ use_cache=use_cache,
617
+ )
618
+ byte_level_activations = encoder_output.hidden_states
619
+
620
+ encoder_connector_output = self.encoder_connector(
621
+ byte_level_activations,
622
+ cumulative_seq_lengths_per_word,
623
+ word_position_ids,
624
+ byte_position_ids,
625
+ )
626
+ backbone_output: CausalLMOutputWithPast = self.backbone(
627
+ hidden_states=encoder_connector_output,
628
+ position_ids=word_position_ids,
629
+ past_key_values=backbone_past_key_values,
630
+ use_cache=use_cache,
631
+ )
632
+
633
+ predictive_word_embeddings = self.decoder_connector.forward(backbone_activations=backbone_output.hidden_states)
634
+
635
+ decoder_output = self.decoder.forward(
636
+ activations=byte_level_activations,
637
+ backbone_activations=predictive_word_embeddings,
638
+ cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
639
+ byte_position_ids=byte_position_ids,
640
+ word_position_ids=word_position_ids,
641
+ past_key_values=decoder_past_key_values,
642
+ use_cache=use_cache,
643
+ )
644
+
645
+ decoder_output = self.layer_norm(decoder_output.last_hidden_state)
646
+ logits = self.lm_head(decoder_output)
647
+
648
+ loss = None
649
+
650
+ return CausalLMOutputWithPast(
651
+ loss=loss,
652
+ logits=logits,
653
+ past_key_values=past_key_values if use_cache else None,
654
+ hidden_states=backbone_output.hidden_states,
655
+ attentions=None,
656
+ )
657
+
658
+ def _append_byte(self, words: list[list[int]], token: int) -> list[list[int]]:
659
+ extended_last_word = words.pop() + [token]
660
+ try:
661
+ text = self.splitter.decode(extended_last_word, skip_special_tokens=False)
662
+ list_of_bytes = self.splitter.encode(text)
663
+ words.extend([list(word_in_bytes) for word_in_bytes in list_of_bytes])
664
+ except UnicodeDecodeError:
665
+ # if decoding fails, the token cannot be part of a new word since it is not a valid
666
+ # utf-8 end byte and we append it to the current word
667
+ words.append(extended_last_word)
668
+ return words
669
+
670
+ def _complete_word(
671
+ self,
672
+ input_ids: torch.Tensor,
673
+ byte_position_ids: torch.Tensor,
674
+ backbone_word_prediction: torch.Tensor,
675
+ word_position_id: torch.Tensor,
676
+ encoder_cache: DynamicCache,
677
+ decoder_cache: DynamicCache,
678
+ sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
679
+ ):
680
+ """Generate byte tokens until we hit the first byte of a new word."""
681
+ words = [input_ids.squeeze(0).tolist()]
682
+ byte_encoder_activations = []
683
+ completion_logits = []
684
+
685
+ while True:
686
+ encoder_output = self.encoder.forward(
687
+ input_ids,
688
+ byte_position_ids=None,
689
+ word_position_ids=word_position_id,
690
+ past_key_values=encoder_cache,
691
+ use_cache=True,
692
+ )
693
+ byte_encoder_activations.append(encoder_output.hidden_states)
694
+ decoder_output = self.decoder.forward(
695
+ backbone_word_prediction,
696
+ encoder_output.hidden_states,
697
+ byte_position_ids=None,
698
+ word_position_ids=word_position_id,
699
+ past_key_values=decoder_cache,
700
+ use_cache=True,
701
+ )
702
+ decoder_output = self.layer_norm(decoder_output.last_hidden_state)
703
+ logits = self.lm_head(decoder_output)
704
+ completion_logits.append(logits[0, -1:, :])
705
+ next_byte = int(sample_fn(logits).item())
706
+ words = self._append_byte(words, next_byte)
707
+ if len(words) > 1 or next_byte == self.eos_token_id:
708
+ break
709
+ input_ids = torch.tensor([[next_byte]], dtype=input_ids.dtype, device=input_ids.device)
710
+
711
+ byte_encoder_activations = torch.cat(byte_encoder_activations, dim=1)
712
+ num_kv = encoder_cache.get_seq_length()
713
+ byte_position_ids = torch.arange(num_kv + 1 - byte_encoder_activations.shape[1], num_kv + 1, device=input_ids.device, dtype=torch.long).unsqueeze(0)
714
+ completed_word_embedding = self.encoder_connector.forward(
715
+ byte_encoder_activations,
716
+ cumulative_seq_lengths_per_word=torch.tensor([0, byte_encoder_activations.size(1)], dtype=torch.int32, device=input_ids.device),
717
+ word_position_ids=word_position_id,
718
+ byte_position_ids=byte_position_ids,
719
+ )
720
+
721
+ completion = sum(words, [])[-len(completion_logits) :]
722
+ first_byte_of_next_word = words[1]
723
+ return completion, completed_word_embedding, first_byte_of_next_word, byte_position_ids[:, -1].item() + 1, completion_logits
724
+
725
+ def generate(
726
+ self,
727
+ input_ids: torch.Tensor,
728
+ max_new_tokens: int,
729
+ cumulative_seq_lengths_per_word: torch.Tensor,
730
+ byte_position_ids: torch.Tensor | None = None,
731
+ word_position_ids: torch.Tensor | None = None,
732
+ sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
733
+ use_cache: bool = True,
734
+ stop_sequences: Sequence[str] | None = None,
735
+ ):
736
+ if use_cache:
737
+ completion_text, completion_logits = self._generate_cached(input_ids, max_new_tokens, cumulative_seq_lengths_per_word, byte_position_ids, word_position_ids, sample_fn, stop_sequences=stop_sequences)
738
+ else:
739
+ completion_text, completion_logits = self._generate_uncached(input_ids, max_new_tokens, cumulative_seq_lengths_per_word, byte_position_ids, word_position_ids, sample_fn, stop_sequences=stop_sequences)
740
+
741
+ # remove the stop sequence if present
742
+ if stop_sequences is not None:
743
+ stop_sequences = sorted(stop_sequences, key=lambda i: len(i), reverse=True)
744
+ for stop_sequence in stop_sequences:
745
+ if stop_sequence in completion_text:
746
+ completion_text_left = completion_text.split(stop_sequence)[0]
747
+ completion_text_removed = completion_text[len(completion_text_left) :]
748
+
749
+ completion_logits = completion_logits[: -len(list(bytes(completion_text_removed.encode("UTF-8"))))]
750
+ completion_text = completion_text_left
751
+ break
752
+
753
+ return ModelOutput(
754
+ completion_text=completion_text,
755
+ input_ids=input_ids,
756
+ completion_logits=completion_logits,
757
+ )
758
+
759
+ @torch.no_grad()
760
+ def _generate_cached(
761
+ self,
762
+ input_ids: torch.Tensor,
763
+ max_new_tokens: int,
764
+ cumulative_seq_lengths_per_word: torch.Tensor,
765
+ byte_position_ids: torch.Tensor | None = None,
766
+ word_position_ids: torch.Tensor | None = None,
767
+ sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
768
+ stop_sequences: Sequence[str] | None = None,
769
+ ):
770
+ max_total_bytes = max_new_tokens + input_ids.shape[1]
771
+ if byte_position_ids is None:
772
+ byte_position_ids = torch.arange(0, cumulative_seq_lengths_per_word[-1].item(), device=input_ids.device, dtype=torch.int32).unsqueeze(0)
773
+
774
+ if word_position_ids is None:
775
+ word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)
776
+
777
+ last_word_start, last_word_end = (
778
+ cumulative_seq_lengths_per_word[-2],
779
+ cumulative_seq_lengths_per_word[-1],
780
+ )
781
+ # Populate cache with everything except last word
782
+ initial_forward_output = self.forward(
783
+ input_ids=input_ids[:, :last_word_start],
784
+ cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word[:-1],
785
+ byte_position_ids=byte_position_ids[:, :last_word_start],
786
+ word_position_ids=word_position_ids[:, :-1],
787
+ past_key_values=None,
788
+ use_cache=True,
789
+ )
790
+
791
+ completion_bytes = []
792
+ completion_logits = []
793
+ input_ids = input_ids[:, last_word_start:last_word_end]
794
+ next_byte_id = last_word_end
795
+ byte_position_ids = byte_position_ids[:, last_word_start:last_word_end]
796
+ word_position_id = word_position_ids[:, -1].unsqueeze(-1)
797
+ backbone_last_hidden_state = initial_forward_output.hidden_states[:, -1:, :]
798
+ while next_byte_id < max_total_bytes:
799
+ completion, completed_word_embedding, first_byte_of_next_word, next_byte_id, next_completion_logits = self._complete_word(
800
+ input_ids=input_ids,
801
+ byte_position_ids=byte_position_ids,
802
+ backbone_word_prediction=backbone_last_hidden_state,
803
+ word_position_id=word_position_id,
804
+ encoder_cache=initial_forward_output.past_key_values.get_encoder_cache(),
805
+ decoder_cache=initial_forward_output.past_key_values.get_decoder_cache(),
806
+ sample_fn=sample_fn,
807
+ )
808
+ completion_logits.extend(next_completion_logits)
809
+ completion_bytes.extend(completion)
810
+
811
+ if self.eos_token_id in completion_bytes:
812
+ completion_bytes = completion_bytes[: completion_bytes.index(self.eos_token_id)]
813
+ break
814
+
815
+ if stop_sequences is not None:
816
+ try:
817
+ completion_text_tmp = self.splitter.decode(completion_bytes)
818
+ if any(stop_sequence in completion_text_tmp for stop_sequence in stop_sequences):
819
+ break
820
+ except Exception as e:
821
+ print("Cannot compare stop sequence", e)
822
+
823
+ backbone_output = self.backbone.forward(
824
+ hidden_states=completed_word_embedding,
825
+ position_ids=None,
826
+ past_key_values=initial_forward_output.past_key_values.get_backbone_cache(),
827
+ use_cache=True,
828
+ )
829
+ backbone_last_hidden_state = backbone_output.hidden_states[:, -1, :].unsqueeze(1)
830
+
831
+ input_ids = torch.tensor([first_byte_of_next_word], dtype=input_ids.dtype, device=input_ids.device)
832
+ byte_position_ids = torch.tensor([[next_byte_id]], dtype=input_ids.dtype, device=input_ids.device)
833
+ word_position_id = word_position_id + 1
834
+
835
+ completion_bytes.extend(first_byte_of_next_word)
836
+ completion_bytes = completion_bytes[:max_new_tokens]
837
+ completion_logits = torch.cat(completion_logits[:max_new_tokens], dim=0)
838
+ completion_text = self.splitter.decode(completion_bytes)
839
+
840
+ return completion_text, completion_logits
841
+
842
+ @torch.no_grad()
843
+ def _generate_uncached(
844
+ self,
845
+ input_ids: torch.Tensor,
846
+ max_new_tokens: int,
847
+ cumulative_seq_lengths_per_word: torch.Tensor,
848
+ byte_position_ids: torch.Tensor | None = None,
849
+ word_position_ids: torch.Tensor | None = None,
850
+ sample_fn=sample_argmax,
851
+ stop_sequences: Sequence[str] | None = None,
852
+ ):
853
+ if byte_position_ids is None:
854
+ byte_position_ids = torch.arange(0, cumulative_seq_lengths_per_word[-1].item(), device=input_ids.device, dtype=torch.int32).unsqueeze(0)
855
+
856
+ if word_position_ids is None:
857
+ word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)
858
+
859
+ word_list = []
860
+ for i in range(1, cumulative_seq_lengths_per_word.shape[0]):
861
+ start_idx = cumulative_seq_lengths_per_word[i - 1]
862
+ end_idx = cumulative_seq_lengths_per_word[i]
863
+ word_list.append(input_ids[:, start_idx:end_idx].squeeze(0).tolist())
864
+
865
+ completion_bytes = []
866
+ for _ in range(max_new_tokens):
867
+ output = self.forward(
868
+ input_ids=input_ids,
869
+ cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
870
+ byte_position_ids=byte_position_ids,
871
+ word_position_ids=word_position_ids,
872
+ past_key_values=None,
873
+ )
874
+
875
+ next_byte = int(sample_fn(output.logits).item())
876
+ completion_bytes.append(next_byte)
877
+ if next_byte == self.eos_token_id:
878
+ break
879
+ word_list = self._append_byte(word_list, next_byte)
880
+
881
+ input_ids = torch.tensor(sum(word_list, []), dtype=torch.long, device=input_ids.device).unsqueeze(0)
882
+ cumulative_seq_lengths_per_word = torch.tensor([0] + list(itertools.accumulate(len(word) for word in word_list if len(word) > 0)), dtype=torch.int32, device=input_ids.device)
883
+ byte_position_ids = torch.arange(0, input_ids.shape[1], device=input_ids.device, dtype=torch.int32).unsqueeze(0)
884
+ word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)
885
+
886
+ if stop_sequences is not None:
887
+ try:
888
+ completion_text_tmp = self.splitter.decode(completion_bytes)
889
+ if any(completion_text_tmp.endswith(stop_sequence) for stop_sequence in stop_sequences):
890
+ break
891
+ except Exception as e:
892
+ print("Cannot compare stop sequence", e)
893
+
894
+ completion_text = self.splitter.decode(completion_bytes)
895
+ completion_logits = output.logits[0, -len(completion_bytes) :, :]
896
+
897
+ return completion_text, completion_logits
898
+
899
+ def _prepare_input(self, input_str: str, add_llama_template: bool = True, device: torch.device | None = None) -> tuple[torch.Tensor, torch.Tensor]:
900
+ if add_llama_template:
901
+ input_str = LLAMA_TEMPLATE.format(input=input_str)
902
+
903
+ if device is None:
904
+ assert torch.cuda.is_available(), "CUDA is not available"
905
+ device = torch.device("cuda")
906
+ input_ids_list = []
907
+ cumulative_per_word_lengths_list = [0]
908
+
909
+ words = self.splitter.encode(input_str)
910
+ for word in words:
911
+ input_ids_list.extend(word)
912
+ word_length = len(word)
913
+ cumulative_per_word_lengths_list.append(cumulative_per_word_lengths_list[-1] + word_length)
914
+ input_ids = torch.tensor(input_ids_list, device=device, dtype=torch.int32).unsqueeze(0)
915
+ cumulative_per_word_lengths = torch.tensor(cumulative_per_word_lengths_list, device=device, dtype=torch.int32)
916
+ return input_ids, cumulative_per_word_lengths
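Note on the sampling hook: sample_argmax above is greedy decoding at the last position, and both generate and _complete_word accept a sample_fn with that contract, so a stochastic sampler can be passed in. Illustrative sketch only (not part of the uploaded files); it reuses model, input_ids, and cumulative_word_lengths from the loading sketch above.

import torch

def sample_temperature(logits: torch.Tensor, temperature: float = 0.8) -> torch.Tensor:
    # Same contract as sample_argmax: take (batch, seq, vocab) logits and return
    # a 1-D tensor with one sampled byte id per batch element for the last position.
    probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

output = model.generate(
    input_ids=input_ids,
    max_new_tokens=128,
    cumulative_seq_lengths_per_word=cumulative_word_lengths,
    sample_fn=sample_temperature,
)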
splitter.py ADDED
@@ -0,0 +1,45 @@
1
+ import re
2
+ from hat_splitter import HATSplitter as RustHATSplitter
3
+
4
+
5
+ class HATSplitter:
6
+ def __init__(self, special_token_dict: dict | None = None, max_word_size: int = 128):
7
+ self.hat_splitter = RustHATSplitter()
8
+ self.max_word_size = max_word_size
9
+ self.special_token_dict = special_token_dict or {}  # avoid None so the lookups below always work
10
+ self.special_token_replace: dict[int, list[int]] = {
11
+ token: list(text.encode("utf-8")) for text, token in self.special_token_dict.items()
12
+ }
13
+ self.special_token_pattern = (
14
+ re.compile(rf"({'|'.join(map(re.escape, special_token_dict.keys()))})")
15
+ if special_token_dict
16
+ else re.compile(r"(?!)")
17
+ )
18
+
19
+
20
+ def encode(self, text: str) -> list[list[int]]:
21
+ chunks = []
22
+ for str_chunk in self.special_token_pattern.split(text):
23
+ if str_chunk:
24
+ if str_chunk in self.special_token_dict:
25
+ chunks.append([self.special_token_dict[str_chunk]])
26
+ else:
27
+ chunks.extend(list(chunk) for chunk in self.hat_splitter.split_with_limit(str_chunk, self.max_word_size))
28
+ return chunks
29
+
30
+ def decode(self, token_ids: list[int], errors: str = "replace", skip_special_tokens: bool = False) -> str:
31
+ assert isinstance(token_ids, list), "token_ids must be a list"
32
+ assert all(isinstance(token_id, int) for token_id in token_ids), "token_ids must be a list of integers"
33
+
34
+ new_token_ids: list[int]
35
+ if skip_special_tokens:
36
+ new_token_ids = [token_id for token_id in token_ids if token_id not in self.special_token_replace]
37
+ else:
38
+ new_token_ids = []
39
+ for token in token_ids:
40
+ if token in self.special_token_replace:
41
+ new_token_ids.extend(self.special_token_replace[token])
42
+ else:
43
+ new_token_ids.append(token)
44
+
45
+ return bytes(new_token_ids).decode("utf-8", errors=errors)
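A quick round-trip with HATSplitter for orientation. Hedged sketch: it assumes the hat_splitter package is installed, that the script runs from this repo so splitter.py is importable, and that special-token ids sit outside the 0-255 byte range (as the decode replacement table implies); the exact word chunking comes from the Rust splitter.

from splitter import HATSplitter

splitter = HATSplitter(special_token_dict={"<|eot_id|>": 300}, max_word_size=128)

# encode returns one list of byte ids per word (special tokens become single-id words).
words = splitter.encode("Hello world<|eot_id|>")
flat = [byte_id for word in words for byte_id in word]

# decode maps byte ids back to UTF-8 text; skip_special_tokens drops id 300 here.
print(splitter.decode(flat, skip_special_tokens=True))  # -> "Hello world"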