import itertools
from collections.abc import Sequence
from importlib.metadata import PackageNotFoundError, version
from typing import Callable
import torch
import torch.nn as nn
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from transformers import PreTrainedModel
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaRotaryEmbedding,
)
from transformers.utils import ModelOutput
from .config import (
CrossAttentionConfig,
DecoderHATModelConfig,
EncoderHATModelConfig,
HATArchitectureConfig,
TransformerHATModelConfig,
)
from .splitter import HATSplitter
try:
transformers_version = version("transformers")
if transformers_version != "4.46.3":
print(f"Warning: Expecected transformers version 4.46.3, but found {transformers_version}. Outputs might be different.")
except PackageNotFoundError:
print("transformers is not installed")
def sample_argmax(logits: torch.Tensor) -> torch.Tensor:
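    """Greedy sampling: return the argmax token id at the last position of each sequence in the batch."""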
return torch.argmax(logits, dim=-1)[:, -1]
LLAMA_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant. You give engaging, well-structured answers to user inquiries.<|eot_id|><|start_header_id|>user<|end_header_id|>
{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
class HATCache(Cache):
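    """Container cache holding separate `DynamicCache` instances for the byte-level encoder,
    the word-level backbone, and the byte-level decoder."""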
encoder_cache: DynamicCache
backbone_cache: DynamicCache
decoder_cache: DynamicCache
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.encoder_cache = DynamicCache()
self.backbone_cache = DynamicCache()
self.decoder_cache = DynamicCache()
def get_backbone_cache(self) -> DynamicCache:
return self.backbone_cache
def get_decoder_cache(self) -> DynamicCache:
return self.decoder_cache
def get_encoder_cache(self) -> DynamicCache:
return self.encoder_cache
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, q_cos=None, q_sin=None, k_cos=None, k_sin=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
and allows for different sequence lengths.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
q_cos (`torch.Tensor`): The cosine part of the rotary embedding.
q_sin (`torch.Tensor`): The sine part of the rotary embedding.
k_cos (`torch.Tensor`): The cosine part of the rotary embedding.
k_sin (`torch.Tensor`): The sine part of the rotary embedding.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze
cos[position_ids] and sin[position_ids] so that they can be properly
broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape
[batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting
unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids]
broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key
tensors rotated using the Rotary Position Embedding.
"""
q_cos = q_cos.unsqueeze(unsqueeze_dim)
q_sin = q_sin.unsqueeze(unsqueeze_dim)
k_cos = k_cos.unsqueeze(unsqueeze_dim)
k_sin = k_sin.unsqueeze(unsqueeze_dim)
q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
return q_embed, k_embed
class HATBackbone(nn.Module):
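    """Stack of `LlamaDecoderLayer`s that runs causal self-attention over the word-level
    embeddings produced by the encoder connector."""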
def __init__(self, config: TransformerHATModelConfig, *args, **kwargs):
super().__init__(*args, **kwargs)
self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
self.rotary_emb = LlamaRotaryEmbedding(config=config)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor | None = None,
past_key_values: DynamicCache | None = None,
use_cache: bool | None = False,
    ) -> CausalLMOutputWithPast:
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if position_ids is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
position_ids = torch.arange(
past_seen_tokens,
past_seen_tokens + hidden_states.shape[1],
device=hidden_states.device,
).unsqueeze(0)
        # create position embeddings to be shared across the backbone layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for backbone_layer in self.layers:
layer_outputs = backbone_layer(
hidden_states,
position_ids=position_ids,
past_key_value=past_key_values,
use_cache=use_cache,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
return CausalLMOutputWithPast(
hidden_states=hidden_states,
past_key_values=past_key_values if use_cache else None,
)
class HATDecoderConnector(nn.Module):
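    """Shifts the backbone activations one word to the right and inserts a learned embedding
    at the first position, so that each word's bytes are decoded against the backbone output
    for the preceding word."""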
    def __init__(self, backbone_hidden_dim: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.first_word_embedding = torch.nn.Parameter(
torch.empty(
1,
1,
                backbone_hidden_dim,
device="cuda",
dtype=torch.bfloat16,
)
)
def forward(
self,
backbone_activations: torch.Tensor,
):
activations = backbone_activations.clone()
activations[:, -1:, :] = self.first_word_embedding
activations = torch.roll(activations, shifts=1, dims=1)
return activations
class RMSNorm(nn.Module):
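    """Root-mean-square layer normalization with a learned scale, optionally computed in fp32."""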
def __init__(self, dimensions: int, eps: float, device: torch.device, dtype: torch.dtype = torch.bfloat16, norm_in_fp32: bool = False):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(dimensions, dtype=dtype).to(device))
self.norm_in_fp32 = norm_in_fp32
def forward(self, x: torch.Tensor) -> torch.Tensor:
original_dtype = x.dtype
if self.norm_in_fp32:
x = x.float()
out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
if out.dtype != original_dtype:
out = out.to(original_dtype)
return out * self.weight
class HATDecoderBlock(nn.Module):
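    """Sliding-window Llama decoder layer, optionally preceded by a cross-attention step in
    which the byte-level activations (queries) attend to the word-level backbone activations
    (keys/values)."""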
def __init__(
self,
add_cross_attention: bool,
config: DecoderHATModelConfig,
layer_idx: int,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.add_cross_attention = add_cross_attention
self.config = config
self.llama_layer = LlamaDecoderLayer(config, layer_idx)
self.llama_layer.self_attn.sliding_window = config.sliding_window
if add_cross_attention:
self.cross_attention = HATCrossAttention(
hidden_size=config.cross_attention_config.hidden_size,
hidden_size_kv=config.cross_attention_config.hidden_size_kv,
hidden_size_q=config.cross_attention_config.hidden_size_q,
config=config,
cross_attention_config=config.cross_attention_config,
)
self.query_norm = RMSNorm(
config.cross_attention_config.hidden_size_q,
eps=config.rms_norm_eps,
device=torch.device("cuda"),
dtype=torch.bfloat16,
norm_in_fp32=False,
)
self.kv_norm = RMSNorm(
config.cross_attention_config.hidden_size_kv,
eps=config.rms_norm_eps,
device=torch.device("cuda"),
dtype=torch.bfloat16,
norm_in_fp32=False,
)
def apply_norm(self, activations):
return self.query_norm(activations), self.kv_norm(activations)
def forward(
self,
encoder_activations,
backbone_activations,
byte_position_ids,
word_position_ids,
cumulative_seq_lengths_per_word,
position_embeddings,
past_key_values,
use_cache,
):
if self.add_cross_attention:
kv_activations = self.kv_norm(backbone_activations)
q_activations = self.query_norm(encoder_activations)
activations = self.cross_attention.forward(
q_activations=q_activations,
kv_activations=kv_activations,
position_ids_q=byte_position_ids,
position_ids_kv=word_position_ids,
cumulative_seq_q=cumulative_seq_lengths_per_word,
cumulative_seq_kv=torch.arange(0, kv_activations.size(1) + 1, device=encoder_activations.device, dtype=torch.int32),
causal=False,
)
encoder_activations = encoder_activations + activations
return self.llama_layer.forward(
hidden_states=encoder_activations,
position_ids=byte_position_ids,
position_embeddings=position_embeddings,
past_key_value=past_key_values,
use_cache=use_cache,
)[0]
class HATDecoder(nn.Module):
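    """Byte-level decoder: a stack of `HATDecoderBlock`s with cross-attention to the backbone
    either in every layer or only in the first one, depending on ``cross_attn_every_layer``."""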
def __init__(self, config: DecoderHATModelConfig, *args, **kwargs):
super().__init__()
self.decoder_layers = nn.Sequential()
for layer_idx in range(config.num_hidden_layers):
add_cross_attention = config.cross_attn_every_layer or layer_idx == 0
self.decoder_layers.add_module(
str(layer_idx),
HATDecoderBlock(
add_cross_attention,
config,
layer_idx,
),
)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
def forward(
self,
backbone_activations: torch.Tensor,
activations: torch.Tensor,
cumulative_seq_lengths_per_word: torch.Tensor | None = None,
byte_position_ids: torch.Tensor | None = None,
word_position_ids: torch.Tensor | None = None,
past_key_values: DynamicCache | None = None,
use_cache: bool | None = False,
) -> BaseModelOutputWithPast:
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if byte_position_ids is None:
past_seen_bytes = past_key_values.get_seq_length() if past_key_values is not None else 0
byte_position_ids = torch.arange(
past_seen_bytes,
past_seen_bytes + activations.size(1),
device=activations.device,
dtype=torch.int32,
).unsqueeze(0)
if cumulative_seq_lengths_per_word is None:
cumulative_seq_lengths_per_word = torch.tensor([0, byte_position_ids.size(1)], dtype=byte_position_ids.dtype, device=byte_position_ids.device)
if word_position_ids is None:
            raise ValueError("word_position_ids must be provided")  # TODO
position_embeddings = self.rotary_emb(activations, byte_position_ids)
for _, layer in enumerate(self.decoder_layers):
activations = layer(
encoder_activations=activations,
backbone_activations=backbone_activations,
position_embeddings=position_embeddings,
cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
byte_position_ids=byte_position_ids,
word_position_ids=word_position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
)
return BaseModelOutputWithPast(
last_hidden_state=activations,
past_key_values=past_key_values if use_cache else None,
)
class HATCrossAttention(nn.Module):
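    """Multi-head cross-attention between sequences of different lengths and hidden sizes.
    Rotary position embeddings are applied separately to queries and keys, and the attention
    itself is computed by `flash_attn_varlen_func` over cumulative sequence lengths."""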
def __init__(
self,
hidden_size: int,
hidden_size_q: int,
hidden_size_kv: int,
config: EncoderHATModelConfig | DecoderHATModelConfig,
cross_attention_config: CrossAttentionConfig,
dtype: torch.dtype = torch.bfloat16,
):
super().__init__()
self.hidden_size = hidden_size
self.hidden_size_q = hidden_size_q
self.hidden_size_kv = hidden_size_kv
self.num_heads = cross_attention_config.num_attention_heads
self.num_key_value_heads = cross_attention_config.attention_num_kv_heads
self.num_repeat_kv = cross_attention_config.num_attention_heads // cross_attention_config.attention_num_kv_heads
self.head_dim = hidden_size // self.num_heads
self.q_proj = nn.Linear(
in_features=hidden_size_q,
out_features=hidden_size,
dtype=dtype,
bias=False,
)
self.k_proj = nn.Linear(
in_features=hidden_size_kv,
out_features=hidden_size // self.num_repeat_kv,
dtype=dtype,
bias=False,
)
self.v_proj = nn.Linear(
in_features=hidden_size_kv,
out_features=hidden_size // self.num_repeat_kv,
dtype=dtype,
bias=False,
)
self.o_proj = nn.Linear(in_features=hidden_size, out_features=hidden_size_q, dtype=dtype, bias=False)
rope_theta = config.rope_theta
rope_type = config.rope_scaling["rope_type"]
self.rotary_emb = LlamaRotaryEmbedding(dim=self.head_dim, base=rope_theta, rope_type=rope_type)
def forward(
self,
q_activations: torch.Tensor,
kv_activations: torch.Tensor,
position_ids_q: torch.Tensor,
position_ids_kv: torch.Tensor,
cumulative_seq_kv: torch.Tensor,
cumulative_seq_q: torch.Tensor,
causal: bool = True,
use_cache: bool = False,
past_key_value: DynamicCache | None = None,
):
q_len = cumulative_seq_q[-1]
bsz, _, _ = kv_activations.size()
query_states = self.q_proj(q_activations)
key_states = self.k_proj(kv_activations)
value_states = self.v_proj(kv_activations)
# TODO get rid of the double rearrange, this is just for compatibility with scaling
query_states = rearrange(query_states, "bsz seq_len (h d) -> bsz h seq_len d", h=self.num_heads)
key_states = rearrange(
key_states,
"bsz seq_len (h d) -> bsz h seq_len d",
h=self.num_key_value_heads,
)
value_states = rearrange(
value_states,
"bsz seq_len (h d) -> bsz h seq_len d",
h=self.num_key_value_heads,
)
        # WIP: Should word_position_ids respect document boundaries?
q_cos, q_sin = self.rotary_emb(query_states, position_ids_q)
k_cos, k_sin = self.rotary_emb(key_states, position_ids_kv)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, q_cos=q_cos, q_sin=q_sin, k_cos=k_cos, k_sin=k_sin)
query_states = rearrange(query_states, "bsz h seq_len d -> (bsz seq_len) h d")
key_states = rearrange(key_states, "bsz h seq_len d -> (bsz seq_len) h d")
value_states = rearrange(value_states, "bsz h seq_len d -> (bsz seq_len) h d")
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cumulative_seq_q,
cu_seqlens_k=cumulative_seq_kv,
max_seqlen_q=self._get_max_seqlen(cumulative_seq_q),
max_seqlen_k=self._get_max_seqlen(cumulative_seq_kv),
causal=False,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
def _get_max_seqlen(self, cumulative_word_lengths: torch.Tensor):
diffs = cumulative_word_lengths[1:] - cumulative_word_lengths[:-1]
return int(diffs.max().item())
class HATEncoderConnector(nn.Module):
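    """Pools byte-level encoder activations into one embedding per word: a learned latent
    query attends via cross-attention to the bytes of each word, as delimited by
    ``cumulative_seq_lengths_per_word``."""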
def __init__(
self,
config: EncoderHATModelConfig,
backbone_hidden_size: int,
dtype: torch.dtype = torch.bfloat16,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.latent_query = torch.nn.Parameter(
torch.empty(
1,
1,
backbone_hidden_size,
device="cuda",
dtype=dtype,
)
)
self.cross_attention_encoder_connector = HATCrossAttention(
hidden_size=config.cross_attention_config.hidden_size,
hidden_size_q=backbone_hidden_size,
hidden_size_kv=config.hidden_size,
config=config,
cross_attention_config=config.cross_attention_config,
)
def forward(
self,
hidden_states: torch.Tensor,
cumulative_seq_lengths_per_word: torch.Tensor,
word_position_ids: torch.Tensor,
byte_position_ids: torch.Tensor,
):
q_len = cumulative_seq_lengths_per_word.shape[0] - 1
latent_query_repeated = self.latent_query.expand(-1, q_len, -1)
cumulative_seq_lengths_q = torch.arange(
start=0,
end=latent_query_repeated.shape[1] + 1,
step=1,
device=self.latent_query.device,
dtype=torch.int32,
)
word_embeddings = self.cross_attention_encoder_connector.forward(
q_activations=latent_query_repeated,
kv_activations=hidden_states,
position_ids_q=word_position_ids,
position_ids_kv=byte_position_ids,
cumulative_seq_q=cumulative_seq_lengths_q,
cumulative_seq_kv=cumulative_seq_lengths_per_word,
)
return word_embeddings
class HATEncoder(nn.Module):
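    """Byte-level encoder: an embedding table over the byte vocabulary followed by
    `LlamaDecoderLayer`s restricted to a sliding attention window."""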
def __init__(
self,
config: EncoderHATModelConfig,
dtype: torch.dtype = torch.bfloat16,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.embedding_layer = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype)
self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
for layer in self.layers:
layer.self_attn.sliding_window = config.sliding_window
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.word_window_size = config.cross_attention_config.word_window_size
def forward(
self,
input_ids: torch.Tensor,
cumulative_seq_lengths_per_word: torch.Tensor | None = None,
byte_position_ids: torch.Tensor | None = None,
word_position_ids: torch.Tensor | None = None, # TODO: Remove
past_key_values: DynamicCache | None = None,
use_cache: bool | None = False,
):
input_embeds = self.embedding_layer(input_ids)
if cumulative_seq_lengths_per_word is None:
cumulative_seq_lengths_per_word = torch.tensor([0, input_embeds.shape[1]], dtype=torch.int32, device=input_ids.device)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if byte_position_ids is None:
past_seen_bytes = past_key_values.get_seq_length() if past_key_values is not None else 0
byte_position_ids = torch.arange(
past_seen_bytes,
past_seen_bytes + input_embeds.shape[1],
device=input_embeds.device,
).unsqueeze(0)
if word_position_ids is None:
            raise ValueError("word_position_ids must be provided")  # TODO
hidden_states = input_embeds
        # create position embeddings to be shared across the encoder layers
position_embeddings = self.rotary_emb(hidden_states, byte_position_ids)
for layer in self.layers:
layer_outputs = layer(
hidden_states,
position_ids=byte_position_ids,
past_key_value=past_key_values,
use_cache=use_cache,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
return CausalLMOutputWithPast(
hidden_states=hidden_states,
past_key_values=past_key_values if use_cache else None,
)
class HATForCausalLM(PreTrainedModel):
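    """Byte-level causal language model composed of a byte encoder, an encoder connector that
    pools bytes into word embeddings, a word-level backbone, a decoder connector that shifts
    backbone outputs by one word, and a byte-level decoder followed by an LM head."""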
config_class = HATArchitectureConfig
_supports_flash_attn_2 = True
_supports_cache_class = True
def __init__(self, config: HATArchitectureConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.config = config
self.eos_token_id = config.eos_token_id
self.encoder = HATEncoder(config.encoder_config)
self.encoder_connector = HATEncoderConnector(config.encoder_config, config.backbone_config.hidden_size)
self.backbone = HATBackbone(config.backbone_config)
self.decoder_connector = HATDecoderConnector(config.backbone_config.hidden_size)
self.decoder = HATDecoder(config.decoder_config)
self.splitter = HATSplitter(special_token_dict=config.special_token_dict, max_word_size=config.max_word_size)
self.layer_norm = RMSNorm(config.decoder_config.hidden_size, eps=config.decoder_config.rms_norm_eps, device=torch.device("cuda"), dtype=torch.bfloat16, norm_in_fp32=False)
self.lm_head = nn.Linear(
in_features=config.decoder_config.hidden_size,
out_features=config.decoder_config.vocab_size,
dtype=torch.bfloat16,
bias=False,
)
def forward(
self,
input_ids: torch.Tensor,
byte_position_ids: torch.Tensor,
cumulative_seq_lengths_per_word: torch.Tensor | None = None,
word_position_ids: torch.Tensor | None = None,
past_key_values: HATCache | None = None,
use_cache: bool = False,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
if past_key_values is None and use_cache:
past_key_values = HATCache()
encoder_past_key_values = past_key_values.get_encoder_cache() if past_key_values is not None else None
backbone_past_key_values = past_key_values.get_backbone_cache() if past_key_values is not None else None
decoder_past_key_values = past_key_values.get_decoder_cache() if past_key_values is not None else None
        encoder_output: CausalLMOutputWithPast = self.encoder(
input_ids=input_ids,
cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
byte_position_ids=byte_position_ids,
word_position_ids=word_position_ids,
past_key_values=encoder_past_key_values,
use_cache=use_cache,
)
byte_level_activations = encoder_output.hidden_states
encoder_connector_output = self.encoder_connector(
byte_level_activations,
cumulative_seq_lengths_per_word,
word_position_ids,
byte_position_ids,
)
backbone_output: CausalLMOutputWithPast = self.backbone(
hidden_states=encoder_connector_output,
position_ids=word_position_ids,
past_key_values=backbone_past_key_values,
use_cache=use_cache,
)
predictive_word_embeddings = self.decoder_connector.forward(backbone_activations=backbone_output.hidden_states)
decoder_output = self.decoder.forward(
activations=byte_level_activations,
backbone_activations=predictive_word_embeddings,
cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
byte_position_ids=byte_position_ids,
word_position_ids=word_position_ids,
past_key_values=decoder_past_key_values,
use_cache=use_cache,
)
decoder_output = self.layer_norm(decoder_output.last_hidden_state)
logits = self.lm_head(decoder_output)
loss = None
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=past_key_values if use_cache else None,
hidden_states=backbone_output.hidden_states,
attentions=None,
)
def _append_byte(self, words: list[list[int]], token: int) -> list[list[int]]:
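        """Append a raw byte to the last word, then re-split the decoded text; if the splitter
        now returns more than one word, a word boundary has been produced. Bytes that do not
        yet decode to valid UTF-8 stay attached to the current word."""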
extended_last_word = words.pop() + [token]
try:
text = self.splitter.decode(extended_last_word, skip_special_tokens=False)
list_of_bytes = self.splitter.encode(text)
words.extend([list(word_in_bytes) for word_in_bytes in list_of_bytes])
except UnicodeDecodeError:
            # if decoding fails, the byte does not complete a valid UTF-8 sequence,
            # so it cannot start a new word; keep it appended to the current word
words.append(extended_last_word)
return words
def _complete_word(
self,
input_ids: torch.Tensor,
byte_position_ids: torch.Tensor,
backbone_word_prediction: torch.Tensor,
word_position_id: torch.Tensor,
encoder_cache: DynamicCache,
decoder_cache: DynamicCache,
sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
):
"""Generate byte tokens until we hit the first byte of a new word."""
words = [input_ids.squeeze(0).tolist()]
byte_encoder_activations = []
completion_logits = []
while True:
encoder_output = self.encoder.forward(
input_ids,
byte_position_ids=None,
word_position_ids=word_position_id,
past_key_values=encoder_cache,
use_cache=True,
)
byte_encoder_activations.append(encoder_output.hidden_states)
decoder_output = self.decoder.forward(
backbone_word_prediction,
encoder_output.hidden_states,
byte_position_ids=None,
word_position_ids=word_position_id,
past_key_values=decoder_cache,
use_cache=True,
)
decoder_output = self.layer_norm(decoder_output.last_hidden_state)
logits = self.lm_head(decoder_output)
completion_logits.append(logits[0, -1:, :])
next_byte = int(sample_fn(logits).item())
words = self._append_byte(words, next_byte)
if len(words) > 1 or next_byte == self.eos_token_id:
break
input_ids = torch.tensor([[next_byte]], dtype=input_ids.dtype, device=input_ids.device)
byte_encoder_activations = torch.cat(byte_encoder_activations, dim=1)
num_kv = encoder_cache.get_seq_length()
byte_position_ids = torch.arange(num_kv + 1 - byte_encoder_activations.shape[1], num_kv + 1, device=input_ids.device, dtype=torch.long).unsqueeze(0)
completed_word_embedding = self.encoder_connector.forward(
byte_encoder_activations,
cumulative_seq_lengths_per_word=torch.tensor([0, byte_encoder_activations.size(1)], dtype=torch.int32, device=input_ids.device),
word_position_ids=word_position_id,
byte_position_ids=byte_position_ids,
)
completion = sum(words, [])[-len(completion_logits) :]
first_byte_of_next_word = words[1]
return completion, completed_word_embedding, first_byte_of_next_word, byte_position_ids[:, -1].item() + 1, completion_logits
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int,
cumulative_seq_lengths_per_word: torch.Tensor,
byte_position_ids: torch.Tensor | None = None,
word_position_ids: torch.Tensor | None = None,
sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
use_cache: bool = True,
stop_sequences: Sequence[str] | None = None,
):
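        """Generate up to ``max_new_tokens`` bytes from ``input_ids``, with KV caching by default,
        and strip the first matching stop sequence (and its logits) from the returned completion."""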
if use_cache:
completion_text, completion_logits = self._generate_cached(input_ids, max_new_tokens, cumulative_seq_lengths_per_word, byte_position_ids, word_position_ids, sample_fn, stop_sequences=stop_sequences)
else:
completion_text, completion_logits = self._generate_uncached(input_ids, max_new_tokens, cumulative_seq_lengths_per_word, byte_position_ids, word_position_ids, sample_fn, stop_sequences=stop_sequences)
# remove stop sequence if exists
if stop_sequences is not None:
stop_sequences = sorted(stop_sequences, key=lambda i: len(i), reverse=True)
for stop_sequence in stop_sequences:
if stop_sequence in completion_text:
completion_text_left = completion_text.split(stop_sequence)[0]
completion_text_removed = completion_text[len(completion_text_left) :]
completion_logits = completion_logits[: -len(list(bytes(completion_text_removed.encode("UTF-8"))))]
completion_text = completion_text_left
break
return ModelOutput(
completion_text=completion_text,
input_ids=input_ids,
completion_logits=completion_logits,
)
@torch.no_grad()
def _generate_cached(
self,
input_ids: torch.Tensor,
max_new_tokens: int,
cumulative_seq_lengths_per_word: torch.Tensor,
byte_position_ids: torch.Tensor | None = None,
word_position_ids: torch.Tensor | None = None,
sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
stop_sequences: Sequence[str] | None = None,
):
max_total_bytes = max_new_tokens + input_ids.shape[1]
if byte_position_ids is None:
byte_position_ids = torch.arange(0, cumulative_seq_lengths_per_word[-1].item(), device=input_ids.device, dtype=torch.int32).unsqueeze(0)
if word_position_ids is None:
word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)
last_word_start, last_word_end = (
cumulative_seq_lengths_per_word[-2],
cumulative_seq_lengths_per_word[-1],
)
# Populate cache with everything except last word
initial_forward_output = self.forward(
input_ids=input_ids[:, :last_word_start],
cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word[:-1],
byte_position_ids=byte_position_ids[:, :last_word_start],
word_position_ids=word_position_ids[:, :-1],
past_key_values=None,
use_cache=True,
)
completion_bytes = []
completion_logits = []
input_ids = input_ids[:, last_word_start:last_word_end]
next_byte_id = last_word_end
byte_position_ids = byte_position_ids[:, last_word_start:last_word_end]
word_position_id = word_position_ids[:, -1].unsqueeze(-1)
backbone_last_hidden_state = initial_forward_output.hidden_states[:, -1:, :]
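        # Word-by-word generation loop: complete the current word byte by byte, pool the new
        # word's encoder activations into a word embedding, run one backbone step on it, and
        # decode the next word from the updated backbone state.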
while next_byte_id < max_total_bytes:
completion, completed_word_embedding, first_byte_of_next_word, next_byte_id, next_completion_logits = self._complete_word(
input_ids=input_ids,
byte_position_ids=byte_position_ids,
backbone_word_prediction=backbone_last_hidden_state,
word_position_id=word_position_id,
encoder_cache=initial_forward_output.past_key_values.get_encoder_cache(),
decoder_cache=initial_forward_output.past_key_values.get_decoder_cache(),
sample_fn=sample_fn,
)
completion_logits.extend(next_completion_logits)
completion_bytes.extend(completion)
if self.eos_token_id in completion_bytes:
completion_bytes = completion_bytes[: completion_bytes.index(self.eos_token_id)]
break
if stop_sequences is not None:
try:
completion_text_tmp = self.splitter.decode(completion_bytes)
if any(stop_sequence in completion_text_tmp for stop_sequence in stop_sequences):
break
except Exception as e:
print("Cannot compare stop sequence", e)
backbone_output = self.backbone.forward(
hidden_states=completed_word_embedding,
position_ids=None,
past_key_values=initial_forward_output.past_key_values.get_backbone_cache(),
use_cache=True,
)
backbone_last_hidden_state = backbone_output.hidden_states[:, -1, :].unsqueeze(1)
input_ids = torch.tensor([first_byte_of_next_word], dtype=input_ids.dtype, device=input_ids.device)
byte_position_ids = torch.tensor([[next_byte_id]], dtype=input_ids.dtype, device=input_ids.device)
word_position_id = word_position_id + 1
completion_bytes.extend(first_byte_of_next_word)
completion_bytes = completion_bytes[:max_new_tokens]
completion_logits = torch.cat(completion_logits[:max_new_tokens], dim=0)
completion_text = self.splitter.decode(completion_bytes)
return completion_text, completion_logits
@torch.no_grad()
def _generate_uncached(
self,
input_ids: torch.Tensor,
max_new_tokens: int,
cumulative_seq_lengths_per_word: torch.Tensor,
byte_position_ids: torch.Tensor | None = None,
word_position_ids: torch.Tensor | None = None,
sample_fn=sample_argmax,
stop_sequences: Sequence[str] | None = None,
):
if byte_position_ids is None:
byte_position_ids = torch.arange(0, cumulative_seq_lengths_per_word[-1].item(), device=input_ids.device, dtype=torch.int32).unsqueeze(0)
if word_position_ids is None:
word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)
word_list = []
for i in range(1, cumulative_seq_lengths_per_word.shape[0]):
start_idx = cumulative_seq_lengths_per_word[i - 1]
end_idx = cumulative_seq_lengths_per_word[i]
word_list.append(input_ids[:, start_idx:end_idx].squeeze(0).tolist())
completion_bytes = []
for _ in range(max_new_tokens):
output = self.forward(
input_ids=input_ids,
cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
byte_position_ids=byte_position_ids,
word_position_ids=word_position_ids,
past_key_values=None,
)
next_byte = int(sample_fn(output.logits).item())
completion_bytes.append(next_byte)
if next_byte == self.eos_token_id:
break
word_list = self._append_byte(word_list, next_byte)
input_ids = torch.tensor(sum(word_list, []), dtype=torch.long, device=input_ids.device).unsqueeze(0)
cumulative_seq_lengths_per_word = torch.tensor([0] + list(itertools.accumulate(len(word) for word in word_list if len(word) > 0)), dtype=torch.int32, device=input_ids.device)
byte_position_ids = torch.arange(0, input_ids.shape[1], device=input_ids.device, dtype=torch.int32).unsqueeze(0)
word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)
if stop_sequences is not None:
try:
completion_text_tmp = self.splitter.decode(completion_bytes)
if any(completion_text_tmp.endswith(stop_sequence) for stop_sequence in stop_sequences):
break
except Exception as e:
print("Cannot compare stop sequence", e)
completion_text = self.splitter.decode(completion_bytes)
completion_logits = output.logits[0, -len(completion_bytes) :, :]
return completion_text, completion_logits
def _prepare_input(self, input_str: str, add_llama_template: bool = True, device: torch.device | None = None) -> tuple[torch.Tensor, torch.Tensor]:
if add_llama_template:
input_str = LLAMA_TEMPLATE.format(input=input_str)
if device is None:
assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device("cuda")
input_ids_list = []
cumulative_per_word_lengths_list = [0]
words = self.splitter.encode(input_str)
for word in words:
input_ids_list.extend(word)
word_length = len(word)
cumulative_per_word_lengths_list.append(cumulative_per_word_lengths_list[-1] + word_length)
input_ids = torch.tensor(input_ids_list, device=device, dtype=torch.int32).unsqueeze(0)
cumulative_per_word_lengths = torch.tensor(cumulative_per_word_lengths_list, device=device, dtype=torch.int32)
return input_ids, cumulative_per_word_lengths
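# A minimal usage sketch (assumes this file ships as the remote-code `modeling` module of a
# Hugging Face checkpoint that also provides `HATArchitectureConfig` and the splitter
# resources; the repository path below is a placeholder):
#
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "path/to/hat-checkpoint",
#       trust_remote_code=True,
#       torch_dtype=torch.bfloat16,
#   ).to("cuda")
#   input_ids, cumulative_seq_lengths_per_word = model._prepare_input("Why is the sky blue?")
#   output = model.generate(
#       input_ids,
#       max_new_tokens=256,
#       cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
#   )
#   print(output.completion_text)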