import itertools
from collections.abc import Sequence
from importlib.metadata import PackageNotFoundError, version
from typing import Callable

import torch
import torch.nn as nn
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from transformers import PreTrainedModel
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaRotaryEmbedding,
)
from transformers.utils import ModelOutput

from .config import (
    CrossAttentionConfig,
    DecoderHATModelConfig,
    EncoderHATModelConfig,
    HATArchitectureConfig,
    TransformerHATModelConfig,
)
from .splitter import HATSplitter

try:
    transformers_version = version("transformers")
    if transformers_version != "4.46.3":
        print(f"Warning: Expected transformers version 4.46.3, but found {transformers_version}. Outputs might be different.")
except PackageNotFoundError:
    print("transformers is not installed")


def sample_argmax(logits: torch.Tensor) -> torch.Tensor:
    return torch.argmax(logits, dim=-1)[:, -1]


LLAMA_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant. You give engaging, well-structured answers to user inquiries.<|eot_id|><|start_header_id|>user<|end_header_id|>

{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""


class HATCache(Cache):
    encoder_cache: DynamicCache
    backbone_cache: DynamicCache
    decoder_cache: DynamicCache

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.encoder_cache = DynamicCache()
        self.backbone_cache = DynamicCache()
        self.decoder_cache = DynamicCache()

    def get_backbone_cache(self) -> DynamicCache:
        return self.backbone_cache

    def get_decoder_cache(self) -> DynamicCache:
        return self.decoder_cache

    def get_encoder_cache(self) -> DynamicCache:
        return self.encoder_cache


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, q_cos=None, q_sin=None, k_cos=None, k_sin=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors, allowing for different sequence lengths.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        q_cos (`torch.Tensor`): The cosine part of the rotary embedding for the queries.
        q_sin (`torch.Tensor`): The sine part of the rotary embedding for the queries.
        k_cos (`torch.Tensor`): The cosine part of the rotary embedding for the keys.
        k_sin (`torch.Tensor`): The sine part of the rotary embedding for the keys.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example,
            note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then,
            if q and k have the shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k
            have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    q_cos = q_cos.unsqueeze(unsqueeze_dim)
    q_sin = q_sin.unsqueeze(unsqueeze_dim)
    k_cos = k_cos.unsqueeze(unsqueeze_dim)
    k_sin = k_sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
    k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
    return q_embed, k_embed


class HATBackbone(nn.Module):
    def __init__(self, config: TransformerHATModelConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor | None = None,
        past_key_values: DynamicCache | None = None,
        use_cache: bool | None = False,
    ) -> BaseModelOutputWithPast:
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if position_ids is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(
                past_seen_tokens,
                past_seen_tokens + hidden_states.shape[1],
                device=hidden_states.device,
            ).unsqueeze(0)

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for backbone_layer in self.layers:
            layer_outputs = backbone_layer(
                hidden_states,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]

        return CausalLMOutputWithPast(
            hidden_states=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )


class HATDecoderConnector(nn.Module):
    def __init__(self, backbone_hidden_dim: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.first_word_embedding = torch.nn.Parameter(
            torch.empty(
                1,
                1,
                backbone_hidden_dim,
                device="cuda",
                dtype=torch.bfloat16,
            )
        )

    def forward(
        self,
        backbone_activations: torch.Tensor,
    ):
        # Shift the backbone outputs one word to the right so that the bytes of word i
        # are conditioned on the backbone state produced after word i - 1; the learned
        # first_word_embedding fills the now-empty first position.
        activations = backbone_activations.clone()
        activations[:, -1:, :] = self.first_word_embedding
        activations = torch.roll(activations, shifts=1, dims=1)
        return activations


class RMSNorm(nn.Module):
    def __init__(self, dimensions: int, eps: float, device: torch.device, dtype: torch.dtype = torch.bfloat16, norm_in_fp32: bool = False):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dimensions, dtype=dtype).to(device))
        self.norm_in_fp32 = norm_in_fp32

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        original_dtype = x.dtype
        if self.norm_in_fp32:
            x = x.float()
        out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        if out.dtype != original_dtype:
            out = out.to(original_dtype)
        return out * self.weight


class HATDecoderBlock(nn.Module):
    def __init__(
        self,
        add_cross_attention: bool,
        config: DecoderHATModelConfig,
        layer_idx: int,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.add_cross_attention = add_cross_attention
        self.config = config
        self.llama_layer = LlamaDecoderLayer(config, layer_idx)
        self.llama_layer.self_attn.sliding_window = config.sliding_window
        if add_cross_attention:
            self.cross_attention = HATCrossAttention(
                hidden_size=config.cross_attention_config.hidden_size,
                hidden_size_kv=config.cross_attention_config.hidden_size_kv,
                hidden_size_q=config.cross_attention_config.hidden_size_q,
                config=config,
                cross_attention_config=config.cross_attention_config,
            )
            self.query_norm = RMSNorm(
                config.cross_attention_config.hidden_size_q,
                eps=config.rms_norm_eps,
                device=torch.device("cuda"),
                dtype=torch.bfloat16,
                norm_in_fp32=False,
            )
            self.kv_norm = RMSNorm(
                config.cross_attention_config.hidden_size_kv,
                eps=config.rms_norm_eps,
                device=torch.device("cuda"),
                dtype=torch.bfloat16,
                norm_in_fp32=False,
            )

    def apply_norm(self, activations):
        return self.query_norm(activations), self.kv_norm(activations)

    def forward(
        self,
        encoder_activations,
        backbone_activations,
        byte_position_ids,
        word_position_ids,
        cumulative_seq_lengths_per_word,
        position_embeddings,
        past_key_values,
        use_cache,
    ):
        if self.add_cross_attention:
            kv_activations = self.kv_norm(backbone_activations)
            q_activations = self.query_norm(encoder_activations)
            activations = self.cross_attention.forward(
                q_activations=q_activations,
                kv_activations=kv_activations,
                position_ids_q=byte_position_ids,
                position_ids_kv=word_position_ids,
                cumulative_seq_q=cumulative_seq_lengths_per_word,
                cumulative_seq_kv=torch.arange(0, kv_activations.size(1) + 1, device=encoder_activations.device, dtype=torch.int32),
                causal=False,
            )
            encoder_activations = encoder_activations + activations

        return self.llama_layer.forward(
            hidden_states=encoder_activations,
            position_ids=byte_position_ids,
            position_embeddings=position_embeddings,
            past_key_value=past_key_values,
            use_cache=use_cache,
        )[0]


class HATDecoder(nn.Module):
    def __init__(self, config: DecoderHATModelConfig, *args, **kwargs):
        super().__init__()
        self.decoder_layers = nn.Sequential()
        for layer_idx in range(config.num_hidden_layers):
            add_cross_attention = config.cross_attn_every_layer or layer_idx == 0
            self.decoder_layers.add_module(
                str(layer_idx),
                HATDecoderBlock(
                    add_cross_attention,
                    config,
                    layer_idx,
                ),
            )
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

    def forward(
        self,
        backbone_activations: torch.Tensor,
        activations: torch.Tensor,
        cumulative_seq_lengths_per_word: torch.Tensor | None = None,
        byte_position_ids: torch.Tensor | None = None,
        word_position_ids: torch.Tensor | None = None,
        past_key_values: DynamicCache | None = None,
        use_cache: bool | None = False,
    ) -> BaseModelOutputWithPast:
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if byte_position_ids is None:
            past_seen_bytes = past_key_values.get_seq_length() if past_key_values is not None else 0
            byte_position_ids = torch.arange(
                past_seen_bytes,
                past_seen_bytes + activations.size(1),
                device=activations.device,
                dtype=torch.int32,
            ).unsqueeze(0)

        if cumulative_seq_lengths_per_word is None:
            cumulative_seq_lengths_per_word = torch.tensor([0, byte_position_ids.size(1)], dtype=byte_position_ids.dtype, device=byte_position_ids.device)

        if word_position_ids is None:
            raise ValueError("word_position_ids must be provided")  # TODO

        position_embeddings = self.rotary_emb(activations, byte_position_ids)

        for _, layer in enumerate(self.decoder_layers):
            activations = layer(
                encoder_activations=activations,
                backbone_activations=backbone_activations,
                position_embeddings=position_embeddings,
                cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
                byte_position_ids=byte_position_ids,
                word_position_ids=word_position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=activations,
            past_key_values=past_key_values if use_cache else None,
        )


class HATCrossAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        hidden_size_q: int,
        hidden_size_kv: int,
        config: EncoderHATModelConfig | DecoderHATModelConfig,
        cross_attention_config: CrossAttentionConfig,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_size_q = hidden_size_q
        self.hidden_size_kv = hidden_size_kv
        self.num_heads = cross_attention_config.num_attention_heads
        self.num_key_value_heads = cross_attention_config.attention_num_kv_heads
        self.num_repeat_kv = cross_attention_config.num_attention_heads // cross_attention_config.attention_num_kv_heads
        self.head_dim = hidden_size // self.num_heads
        self.q_proj = nn.Linear(
            in_features=hidden_size_q,
            out_features=hidden_size,
            dtype=dtype,
            bias=False,
        )
        self.k_proj = nn.Linear(
            in_features=hidden_size_kv,
            out_features=hidden_size // self.num_repeat_kv,
            dtype=dtype,
            bias=False,
        )
        self.v_proj = nn.Linear(
            in_features=hidden_size_kv,
            out_features=hidden_size // self.num_repeat_kv,
            dtype=dtype,
            bias=False,
        )
        self.o_proj = nn.Linear(in_features=hidden_size, out_features=hidden_size_q, dtype=dtype, bias=False)
        rope_theta = config.rope_theta
        rope_type = config.rope_scaling["rope_type"]
        self.rotary_emb = LlamaRotaryEmbedding(dim=self.head_dim, base=rope_theta, rope_type=rope_type)

    def forward(
        self,
        q_activations: torch.Tensor,
        kv_activations: torch.Tensor,
        position_ids_q: torch.Tensor,
        position_ids_kv: torch.Tensor,
        cumulative_seq_kv: torch.Tensor,
        cumulative_seq_q: torch.Tensor,
        causal: bool = True,
        use_cache: bool = False,
        past_key_value: DynamicCache | None = None,
    ):
        q_len = cumulative_seq_q[-1]
        bsz, _, _ = kv_activations.size()
        query_states = self.q_proj(q_activations)
        key_states = self.k_proj(kv_activations)
        value_states = self.v_proj(kv_activations)

        # TODO get rid of the double rearrange, this is just for compatibility with scaling
        query_states = rearrange(query_states, "bsz seq_len (h d) -> bsz h seq_len d", h=self.num_heads)
        key_states = rearrange(
            key_states,
            "bsz seq_len (h d) -> bsz h seq_len d",
            h=self.num_key_value_heads,
        )
        value_states = rearrange(
            value_states,
            "bsz seq_len (h d) -> bsz h seq_len d",
            h=self.num_key_value_heads,
        )

        # WIP: Should word_position_ids respect document boundaries?
        q_cos, q_sin = self.rotary_emb(query_states, position_ids_q)
        k_cos, k_sin = self.rotary_emb(key_states, position_ids_kv)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, q_cos=q_cos, q_sin=q_sin, k_cos=k_cos, k_sin=k_sin)

        query_states = rearrange(query_states, "bsz h seq_len d -> (bsz seq_len) h d")
        key_states = rearrange(key_states, "bsz h seq_len d -> (bsz seq_len) h d")
        value_states = rearrange(value_states, "bsz h seq_len d -> (bsz seq_len) h d")

        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cumulative_seq_q,
            cu_seqlens_k=cumulative_seq_kv,
            max_seqlen_q=self._get_max_seqlen(cumulative_seq_q),
            max_seqlen_k=self._get_max_seqlen(cumulative_seq_kv),
            causal=False,
        )
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _get_max_seqlen(self, cumulative_word_lengths: torch.Tensor):
        diffs = cumulative_word_lengths[1:] - cumulative_word_lengths[:-1]
        return int(diffs.max().item())


class HATEncoderConnector(nn.Module):
    def __init__(
        self,
        config: EncoderHATModelConfig,
        backbone_hidden_size: int,
        dtype: torch.dtype = torch.bfloat16,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.latent_query = torch.nn.Parameter(
            torch.empty(
                1,
                1,
                backbone_hidden_size,
                device="cuda",
                dtype=dtype,
            )
        )
        self.cross_attention_encoder_connector = HATCrossAttention(
            hidden_size=config.cross_attention_config.hidden_size,
            hidden_size_q=backbone_hidden_size,
            hidden_size_kv=config.hidden_size,
            config=config,
            cross_attention_config=config.cross_attention_config,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cumulative_seq_lengths_per_word: torch.Tensor,
        word_position_ids: torch.Tensor,
        byte_position_ids: torch.Tensor,
    ):
        # One latent query per word; cross-attention pools each word's byte activations
        # (delimited by cumulative_seq_lengths_per_word) into a single word embedding.
        q_len = cumulative_seq_lengths_per_word.shape[0] - 1
        latent_query_repeated = self.latent_query.expand(-1, q_len, -1)
        cumulative_seq_lengths_q = torch.arange(
            start=0,
            end=latent_query_repeated.shape[1] + 1,
            step=1,
            device=self.latent_query.device,
            dtype=torch.int32,
        )
        word_embeddings = self.cross_attention_encoder_connector.forward(
            q_activations=latent_query_repeated,
            kv_activations=hidden_states,
            position_ids_q=word_position_ids,
            position_ids_kv=byte_position_ids,
            cumulative_seq_q=cumulative_seq_lengths_q,
            cumulative_seq_kv=cumulative_seq_lengths_per_word,
        )
        return word_embeddings


class HATEncoder(nn.Module):
    def __init__(
        self,
        config: EncoderHATModelConfig,
        dtype: torch.dtype = torch.bfloat16,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.embedding_layer = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        for layer in self.layers:
            layer.self_attn.sliding_window = config.sliding_window
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.word_window_size = config.cross_attention_config.word_window_size

    def forward(
        self,
        input_ids: torch.Tensor,
        cumulative_seq_lengths_per_word: torch.Tensor | None = None,
        byte_position_ids: torch.Tensor | None = None,
        word_position_ids: torch.Tensor | None = None,  # TODO: Remove
        past_key_values: DynamicCache | None = None,
        use_cache: bool | None = False,
    ):
        input_embeds = self.embedding_layer(input_ids)
        if cumulative_seq_lengths_per_word is None:
            cumulative_seq_lengths_per_word = torch.tensor([0, input_embeds.shape[1]], dtype=torch.int32, device=input_ids.device)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if byte_position_ids is None:
            past_seen_bytes = past_key_values.get_seq_length() if past_key_values is not None else 0
            byte_position_ids = torch.arange(
                past_seen_bytes,
                past_seen_bytes + input_embeds.shape[1],
                device=input_embeds.device,
            ).unsqueeze(0)

        if word_position_ids is None:
            raise ValueError("word_position_ids must be provided")  # TODO

        hidden_states = input_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, byte_position_ids)

        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                position_ids=byte_position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]

        return CausalLMOutputWithPast(
            hidden_states=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )


class HATForCausalLM(PreTrainedModel):
    config_class = HATArchitectureConfig
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def __init__(self, config: HATArchitectureConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.config = config
        self.eos_token_id = config.eos_token_id
        self.encoder = HATEncoder(config.encoder_config)
        self.encoder_connector = HATEncoderConnector(config.encoder_config, config.backbone_config.hidden_size)
        self.backbone = HATBackbone(config.backbone_config)
        self.decoder_connector = HATDecoderConnector(config.backbone_config.hidden_size)
        self.decoder = HATDecoder(config.decoder_config)
        self.splitter = HATSplitter(special_token_dict=config.special_token_dict, max_word_size=config.max_word_size)
        self.layer_norm = RMSNorm(config.decoder_config.hidden_size, eps=config.decoder_config.rms_norm_eps, device=torch.device("cuda"), dtype=torch.bfloat16, norm_in_fp32=False)
        self.lm_head = nn.Linear(
            in_features=config.decoder_config.hidden_size,
            out_features=config.decoder_config.vocab_size,
            dtype=torch.bfloat16,
            bias=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        byte_position_ids: torch.Tensor,
        cumulative_seq_lengths_per_word: torch.Tensor | None = None,
        word_position_ids: torch.Tensor | None = None,
        past_key_values: HATCache | None = None,
        use_cache: bool = False,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if past_key_values is None and use_cache:
            past_key_values = HATCache()

        encoder_past_key_values = past_key_values.get_encoder_cache() if past_key_values is not None else None
        backbone_past_key_values = past_key_values.get_backbone_cache() if past_key_values is not None else None
        decoder_past_key_values = past_key_values.get_decoder_cache() if past_key_values is not None else None

        encoder_output: BaseModelOutputWithPast = self.encoder(
            input_ids=input_ids,
            cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
            byte_position_ids=byte_position_ids,
            word_position_ids=word_position_ids,
            past_key_values=encoder_past_key_values,
            use_cache=use_cache,
        )
        byte_level_activations = encoder_output.hidden_states

        encoder_connector_output = self.encoder_connector(
            byte_level_activations,
            cumulative_seq_lengths_per_word,
            word_position_ids,
            byte_position_ids,
        )

        backbone_output: CausalLMOutputWithPast = self.backbone(
            hidden_states=encoder_connector_output,
            position_ids=word_position_ids,
            past_key_values=backbone_past_key_values,
            use_cache=use_cache,
        )

        predictive_word_embeddings = self.decoder_connector.forward(backbone_activations=backbone_output.hidden_states)

        decoder_output = self.decoder.forward(
            activations=byte_level_activations,
            backbone_activations=predictive_word_embeddings,
            cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
            byte_position_ids=byte_position_ids,
            word_position_ids=word_position_ids,
            past_key_values=decoder_past_key_values,
            use_cache=use_cache,
        )
        decoder_output = self.layer_norm(decoder_output.last_hidden_state)
        logits = self.lm_head(decoder_output)
        loss = None

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=backbone_output.hidden_states,
            attentions=None,
        )

    def _append_byte(self, words: list[list[int]], token: int) -> list[list[int]]:
        extended_last_word = words.pop() + [token]
        try:
            text = self.splitter.decode(extended_last_word, skip_special_tokens=False)
            list_of_bytes = self.splitter.encode(text)
            words.extend([list(word_in_bytes) for word_in_bytes in list_of_bytes])
        except UnicodeDecodeError:
            # if decoding fails, the token cannot start a new word since it is not a valid
            # utf-8 end byte, so we append it to the current word
            words.append(extended_last_word)
        return words

    def _complete_word(
        self,
        input_ids: torch.Tensor,
        byte_position_ids: torch.Tensor,
        backbone_word_prediction: torch.Tensor,
        word_position_id: torch.Tensor,
        encoder_cache: DynamicCache,
        decoder_cache: DynamicCache,
        sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
    ):
        """Generate byte tokens until we hit the first byte of a new word."""
        words = [input_ids.squeeze(0).tolist()]
        byte_encoder_activations = []
        completion_logits = []
        while True:
            encoder_output = self.encoder.forward(
                input_ids,
                byte_position_ids=None,
                word_position_ids=word_position_id,
                past_key_values=encoder_cache,
                use_cache=True,
            )
            byte_encoder_activations.append(encoder_output.hidden_states)
            decoder_output = self.decoder.forward(
                backbone_word_prediction,
                encoder_output.hidden_states,
                byte_position_ids=None,
                word_position_ids=word_position_id,
                past_key_values=decoder_cache,
                use_cache=True,
            )
            decoder_output = self.layer_norm(decoder_output.last_hidden_state)
            logits = self.lm_head(decoder_output)
            completion_logits.append(logits[0, -1:, :])
            next_byte = int(sample_fn(logits).item())
            words = self._append_byte(words, next_byte)
            if len(words) > 1 or next_byte == self.eos_token_id:
                break
            input_ids = torch.tensor([[next_byte]], dtype=input_ids.dtype, device=input_ids.device)

        byte_encoder_activations = torch.cat(byte_encoder_activations, dim=1)
        num_kv = encoder_cache.get_seq_length()
        byte_position_ids = torch.arange(num_kv + 1 - byte_encoder_activations.shape[1], num_kv + 1, device=input_ids.device, dtype=torch.long).unsqueeze(0)
        completed_word_embedding = self.encoder_connector.forward(
            byte_encoder_activations,
            cumulative_seq_lengths_per_word=torch.tensor([0, byte_encoder_activations.size(1)], dtype=torch.int32, device=input_ids.device),
            word_position_ids=word_position_id,
            byte_position_ids=byte_position_ids,
        )
        completion = sum(words, [])[-len(completion_logits) :]
        first_byte_of_next_word = words[1]
        return completion, completed_word_embedding, first_byte_of_next_word, byte_position_ids[:, -1].item() + 1, completion_logits

    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        cumulative_seq_lengths_per_word: torch.Tensor,
        byte_position_ids: torch.Tensor | None = None,
        word_position_ids: torch.Tensor | None = None,
        sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
        use_cache: bool = True,
        stop_sequences: Sequence[str] | None = None,
    ):
        if use_cache:
            completion_text, completion_logits = self._generate_cached(
                input_ids,
                max_new_tokens,
                cumulative_seq_lengths_per_word,
                byte_position_ids,
                word_position_ids,
                sample_fn,
                stop_sequences=stop_sequences,
            )
        else:
            completion_text, completion_logits = self._generate_uncached(
                input_ids,
                max_new_tokens,
                cumulative_seq_lengths_per_word,
                byte_position_ids,
                word_position_ids,
                sample_fn,
                stop_sequences=stop_sequences,
            )

        # remove the stop sequence if one exists
        if stop_sequences is not None:
            stop_sequences = sorted(stop_sequences, key=lambda i: len(i), reverse=True)
            for stop_sequence in stop_sequences:
                if stop_sequence in completion_text:
                    completion_text_left = completion_text.split(stop_sequence)[0]
                    completion_text_removed = completion_text[len(completion_text_left) :]
                    completion_logits = completion_logits[: -len(list(bytes(completion_text_removed.encode("UTF-8"))))]
                    completion_text = completion_text_left
                    break

        return ModelOutput(
            completion_text=completion_text,
            input_ids=input_ids,
            completion_logits=completion_logits,
        )

    @torch.no_grad()
    def _generate_cached(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        cumulative_seq_lengths_per_word: torch.Tensor,
        byte_position_ids: torch.Tensor | None = None,
        word_position_ids: torch.Tensor | None = None,
        sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
        stop_sequences: Sequence[str] | None = None,
    ):
        max_total_bytes = max_new_tokens + input_ids.shape[1]
        if byte_position_ids is None:
            byte_position_ids = torch.arange(0, cumulative_seq_lengths_per_word[-1].item(), device=input_ids.device, dtype=torch.int32).unsqueeze(0)
        if word_position_ids is None:
            word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)

        last_word_start, last_word_end = (
            cumulative_seq_lengths_per_word[-2],
            cumulative_seq_lengths_per_word[-1],
        )

        # Populate the cache with everything except the last word
        initial_forward_output = self.forward(
            input_ids=input_ids[:, :last_word_start],
            cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word[:-1],
            byte_position_ids=byte_position_ids[:, :last_word_start],
            word_position_ids=word_position_ids[:, :-1],
            past_key_values=None,
            use_cache=True,
        )

        completion_bytes = []
        completion_logits = []
        input_ids = input_ids[:, last_word_start:last_word_end]
        next_byte_id = last_word_end
        byte_position_ids = byte_position_ids[:, last_word_start:last_word_end]
        word_position_id = word_position_ids[:, -1].unsqueeze(-1)
        backbone_last_hidden_state = initial_forward_output.hidden_states[:, -1:, :]

        while next_byte_id < max_total_bytes:
            completion, completed_word_embedding, first_byte_of_next_word, next_byte_id, next_completion_logits = self._complete_word(
                input_ids=input_ids,
                byte_position_ids=byte_position_ids,
                backbone_word_prediction=backbone_last_hidden_state,
                word_position_id=word_position_id,
                encoder_cache=initial_forward_output.past_key_values.get_encoder_cache(),
                decoder_cache=initial_forward_output.past_key_values.get_decoder_cache(),
                sample_fn=sample_fn,
            )
            completion_logits.extend(next_completion_logits)
            completion_bytes.extend(completion)
            if self.eos_token_id in completion_bytes:
                completion_bytes = completion_bytes[: completion_bytes.index(self.eos_token_id)]
                break
            if stop_sequences is not None:
                try:
                    completion_text_tmp = self.splitter.decode(completion_bytes)
                    if any(stop_sequence in completion_text_tmp for stop_sequence in stop_sequences):
                        break
                except Exception as e:
                    print("Cannot compare stop sequence", e)
            backbone_output = self.backbone.forward(
                hidden_states=completed_word_embedding,
                position_ids=None,
                past_key_values=initial_forward_output.past_key_values.get_backbone_cache(),
                use_cache=True,
            )
            backbone_last_hidden_state = backbone_output.hidden_states[:, -1, :].unsqueeze(1)
            input_ids = torch.tensor([first_byte_of_next_word], dtype=input_ids.dtype, device=input_ids.device)
            byte_position_ids = torch.tensor([[next_byte_id]], dtype=input_ids.dtype, device=input_ids.device)
            word_position_id = word_position_id + 1
            completion_bytes.extend(first_byte_of_next_word)

        completion_bytes = completion_bytes[:max_new_tokens]
        completion_logits = torch.cat(completion_logits[:max_new_tokens], dim=0)
        completion_text = self.splitter.decode(completion_bytes)
        return completion_text, completion_logits

    @torch.no_grad()
    def _generate_uncached(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        cumulative_seq_lengths_per_word: torch.Tensor,
        byte_position_ids: torch.Tensor | None = None,
        word_position_ids: torch.Tensor | None = None,
        sample_fn=sample_argmax,
        stop_sequences: Sequence[str] | None = None,
    ):
        if byte_position_ids is None:
            byte_position_ids = torch.arange(0, cumulative_seq_lengths_per_word[-1].item(), device=input_ids.device, dtype=torch.int32).unsqueeze(0)
        if word_position_ids is None:
            word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)

        word_list = []
        for i in range(1, cumulative_seq_lengths_per_word.shape[0]):
            start_idx = cumulative_seq_lengths_per_word[i - 1]
            end_idx = cumulative_seq_lengths_per_word[i]
            word_list.append(input_ids[:, start_idx:end_idx].squeeze(0).tolist())

        completion_bytes = []
        for _ in range(max_new_tokens):
            output = self.forward(
                input_ids=input_ids,
                cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
                byte_position_ids=byte_position_ids,
                word_position_ids=word_position_ids,
                past_key_values=None,
            )
            next_byte = int(sample_fn(output.logits).item())
            completion_bytes.append(next_byte)
            if next_byte == self.eos_token_id:
                break
            word_list = self._append_byte(word_list, next_byte)
            input_ids = torch.tensor(sum(word_list, []), dtype=torch.long, device=input_ids.device).unsqueeze(0)
            cumulative_seq_lengths_per_word = torch.tensor([0] + list(itertools.accumulate(len(word) for word in word_list if len(word) > 0)), dtype=torch.int32, device=input_ids.device)
            byte_position_ids = torch.arange(0, input_ids.shape[1], device=input_ids.device, dtype=torch.int32).unsqueeze(0)
            word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)
            if stop_sequences is not None:
                try:
                    completion_text_tmp = self.splitter.decode(completion_bytes)
                    if any(completion_text_tmp.endswith(stop_sequence) for stop_sequence in stop_sequences):
                        break
                except Exception as e:
                    print("Cannot compare stop sequence", e)

        completion_text = self.splitter.decode(completion_bytes)
        completion_logits = output.logits[0, -len(completion_bytes) :, :]
        return completion_text, completion_logits

    def _prepare_input(self, input_str: str, add_llama_template: bool = True, device: torch.device | None = None) -> tuple[torch.Tensor, torch.Tensor]:
        if add_llama_template:
            input_str = LLAMA_TEMPLATE.format(input=input_str)
        if device is None:
            assert torch.cuda.is_available(), "CUDA is not available"
            device = torch.device("cuda")
        input_ids_list = []
        cumulative_per_word_lengths_list = [0]
        words = self.splitter.encode(input_str)
        for word in words:
            input_ids_list.extend(word)
            word_length = len(word)
            cumulative_per_word_lengths_list.append(cumulative_per_word_lengths_list[-1] + word_length)
        input_ids = torch.tensor(input_ids_list, device=device, dtype=torch.int32).unsqueeze(0)
        cumulative_per_word_lengths = torch.tensor(cumulative_per_word_lengths_list, device=device, dtype=torch.int32)
        return input_ids, cumulative_per_word_lengths
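

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model definition): shows the
# intended end-to-end flow of `_prepare_input` followed by `generate`. The
# `checkpoint_path` argument is a placeholder assumption; substitute a directory
# containing a trained HAT checkpoint with a HATArchitectureConfig.
# ---------------------------------------------------------------------------
def _example_generate(checkpoint_path: str, prompt: str, max_new_tokens: int = 128) -> str:
    model = HATForCausalLM.from_pretrained(checkpoint_path, torch_dtype=torch.bfloat16).to("cuda")
    # Split the prompt into byte ids plus cumulative per-word byte offsets.
    input_ids, cumulative_per_word_lengths = model._prepare_input(prompt, add_llama_template=True)
    output = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        cumulative_seq_lengths_per_word=cumulative_per_word_lengths,
        use_cache=True,
    )
    return output.completion_text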