|
|
|
|
|
from dataclasses import asdict, dataclass, field |
|
from typing import Dict, List, Optional, Tuple |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
|
@dataclass |
|
class RotaryEmbeddingConfig: |
|
""" |
|
Rotary Positional Embedding configuration |
|
max_seq_len: The number of positions to encode and cache. |
|
dim: Dimension of RoPE. |
|
    theta: Base used to compute the rotary rotation frequencies (e.g. 10000.0).
|
""" |
|
|
|
max_seq_len: int |
|
dim: int |
|
theta: float |
|
|
|
|
|
@dataclass |
|
class PerceiverResamplerConfig: |
|
""" |
|
    Parameters to initialize a PerceiverResampler model.
|
|
|
Args: |
|
emb_layer_norm_before: Whether to use layer norm before the first attention |
|
layer. |
|
attention_heads: Number of attention heads. |
|
key_size: The dimension of the query, key, and values within each attention |
|
            head. If not specified, it is set to embed_dim // attention_heads.
            It can be useful to set a custom key size if we want to impose the size
            of the query, key and value tensors (for example, tensors whose shapes
            are powers of 2 are handled more efficiently on TPUs).
|
Note: Parametrizing the model with a custom key size has been done in : |
|
Brown, Tom, et al. "Language models are few-shot learners." |
|
Advances in neural information processing systems 33 (2020): 1877-1901. |
|
embed_dim: Embedding dimension. |
|
ffn_embed_dim: Feed forward embedding dimension. |
|
num_layers: Number of attention blocks. |
|
ffn_activation_name: Activation function to be used in FFN block. Supported |
|
names are "gelu", "relu", "swish". |
|
use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed |
|
Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg |
|
to True and use swish as ffn_activation_name. |
|
Same principle for a gated-relu. To keep the same number of parameters in |
|
the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU. |
|
See https://arxiv.org/pdf/2002.05202.pdf for more details. |
|
resampled_length: length of the resampled output of the module |
|
use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint |
|
gradients in the forward pass to reduce the computation in the backward). |
|
""" |
|
|
|
|
|
emb_layer_norm_before: bool = False |
|
attention_heads: int = 20 |
|
key_size: Optional[int] = None |
|
embed_dim: int = 1280 |
|
ffn_embed_dim: int = 5120 |
|
num_layers: int = 24 |
|
add_bias_kv: bool = False |
|
add_bias_ffn: bool = True |
|
ffn_activation_name: str = "gelu-no-approx" |
|
use_glu_in_ffn: bool = False |
|
resampled_length: int = 64 |
|
|
|
|
|
use_gradient_checkpointing: bool = False |
|
|
|
def __post_init__(self) -> None: |
|
""" |
|
Checks that the given values are compatible. |
|
""" |
|
|
|
if self.key_size is None: |
|
if not self.embed_dim % self.attention_heads == 0: |
|
raise ValueError( |
|
f"When no key size is provided, the embedding dimension should be " |
|
f"divisible by the number of heads, however provided embedding " |
|
f"dimension is {self.embed_dim} and the number of heads is " |
|
f"{self.attention_heads}." |
|
) |
|
self.key_size = self.embed_dim // self.attention_heads |
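
# Illustrative sketch (not part of the model): with the default values above,
# __post_init__ derives key_size as embed_dim // attention_heads = 1280 // 20 = 64.
def _example_perceiver_resampler_config() -> PerceiverResamplerConfig:
    config = PerceiverResamplerConfig()
    assert config.key_size == 1280 // 20
    return config
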
|
|
|
|
|
@dataclass |
|
class GptConfig: |
|
""" |
|
Parameters to initialize a Gpt model. |
|
|
|
NOTE: the pad token is not defined |
|
|
|
Args: |
|
vocab_size: Token vocabulary. |
|
        eos_token_id: Token ID used to stop sentence generation.
|
embed_dim: Embedding dimension. |
|
ffn_embed_dim: Feed forward embedding dimension. |
|
num_heads: Number of attention heads. |
|
num_kv_heads: Number of key and value heads to support Grouped-Query and |
|
Multi-Query Attention. If None, the number of key and value heads is |
|
equal to the number of attention heads. |
|
        num_layers: Number of decoder layers.
|
rope_config: The configuration for the rotary positional embeddings |
|
add_bias_ffn: Add bias in feed forward network block. |
|
ffn_activation_name: Activation function to be used in FFN block. Supported |
|
names are "gelu", "gelu-no-approx", "relu", "swish". |
|
use_glu_in_ffn: whether to use Gated Linear Unit (GLU) in Feed |
|
Forward Network (FFN) block. |
|
example: To do a swiGLU (gated-swish) put this arg |
|
to True and use swish as ffn_activation_name. |
|
Same principle for a gated-relu. |
|
add_bias_lm_head: whether to use bias in the final LM layer |
|
        norm_type: The type of norm used in the pre-normalization scheme. Can be
            one of ["layer_norm", "RMS_norm"].
|
parallel_attention_ff: Whether to do the attention and the MLP in parallel, |
|
and then sum up the results as it is done in Gpt-NeoX : |
|
Black, Sid, et al. "Gpt-neox-20b: An open-source autoregressive |
|
language model." arXiv preprint arXiv:2204.06745 (2022). |
|
            It is reported to improve training time by 15% when compiling with JAX.
|
use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint |
|
gradients in the forward pass to reduce the computation in the backward). |
|
add_bias_attn: Add bias to the attention mechanism (key, query, value, and |
|
output projections). |
|
""" |
|
|
|
|
|
vocab_size: int |
|
eos_token_id: int |
|
|
|
|
|
embed_dim: int = 16 |
|
ffn_embed_dim: int = 64 |
|
num_heads: int = 2 |
|
num_kv_heads: Optional[int] = None |
|
num_layers: int = 2 |
|
rope_config: RotaryEmbeddingConfig = field( |
|
default_factory=lambda: RotaryEmbeddingConfig( |
|
max_seq_len=512, dim=8, theta=10000.0 |
|
) |
|
) |
|
add_bias_ffn: bool = False |
|
ffn_activation_name: str = "swish" |
|
use_glu_in_ffn: bool = True |
|
add_bias_lm_head: bool = False |
|
norm_type: str = "RMS_norm" |
|
rms_norm_eps: float = 1e-6 |
|
parallel_attention_ff: bool = True |
|
|
|
|
|
use_gradient_checkpointing: bool = False |
|
|
|
|
|
add_bias_attn: bool = False |
|
|
|
def __post_init__(self) -> None: |
|
""" |
|
Checks that the given values are compatible. |
|
""" |
|
if not self.embed_dim % self.num_heads == 0: |
|
raise ValueError( |
|
f"The embedding dimension should be " |
|
f"divisible by the number of heads, however provided embedding " |
|
f"dimension is {self.embed_dim} and the number of heads is " |
|
f"{self.num_heads}." |
|
) |
|
|
|
if not self.embed_dim // self.num_heads > 1: |
|
raise ValueError( |
|
"embed_dim / num_heads must be higher than 2 to apply rotary embeddings" |
|
) |
|
|
|
if not self.embed_dim // self.num_heads >= self.rope_config.dim: |
|
raise ValueError( |
|
"embed_dim // num_heads must be higher than rope_config.dim " |
|
"to apply rotary embeddings" |
|
) |
|
|
|
def to_dict(self): |
|
output = asdict(self) |
|
output["rope_config"] = asdict(self.rope_config) |
|
return output |
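
# Illustrative sketch (not part of the model): the vocab_size and eos_token_id
# below are arbitrary example values. With the defaults, the head dimension
# embed_dim // num_heads = 16 // 2 = 8 satisfies the rotary checks in __post_init__,
# and to_dict() serializes rope_config as a nested dictionary.
def _example_gpt_config() -> dict:
    config = GptConfig(vocab_size=100, eos_token_id=2)
    return config.to_dict()
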
|
|
|
|
|
@dataclass |
|
class NucleotideTransformerConfig: |
|
""" |
|
Parameters to initialize an NT model. |
|
|
|
Args: |
|
alphabet_size: Token vocabulary. |
|
pad_token_id: ID of pad token. |
|
mask_token_id: ID of mask token. |
|
max_positions: Maximum sequence length. |
|
embed_scale: Correction ratio applied to the embeddings to make up for the |
|
norm difference between the input during training and inference. |
|
emb_layer_norm_before: Whether to use layer norm before the first attention |
|
layer. |
|
attention_heads: Number of attention heads. |
|
key_size: The dimension of the query, key, and values within each attention |
|
            head. If not specified, it is set to embed_dim // attention_heads.
            It can be useful to set a custom key size if we want to impose the size
            of the query, key and value tensors (for example, tensors whose shapes
            are powers of 2 are handled more efficiently on TPUs).
|
Note: Parametrizing the model with a custom key size has been done in : |
|
Brown, Tom, et al. "Language models are few-shot learners." |
|
Advances in neural information processing systems 33 (2020): 1877-1901. |
|
embed_dim: Embedding dimension. |
|
ffn_embed_dim: Feed forward embedding dimension. |
|
num_layers: Number of attention blocks. |
|
positional_embedding: Type of positional embedding to use before the first |
|
            attention layer. Options: "learned", "learned_standard", "sinusoidal",
            "alibi_dnabert_2" or None.
|
NOTE: "learned" is the positional embedding of ESM, and "learned_standard" |
|
is a more standard one, used for example in DNAbert. |
|
lm_head: type of language model head. Options: "simple", "roberta" or None. |
|
add_bias_kv: Add bias in attention layer. |
|
add_bias_ffn: Add bias in feed forward network block. |
|
use_rotary_embedding: Whether to use rotary embeddings. Requires: |
|
positional_embeddings = None. |
|
rescaling_factor: Scaling factor to use for rotary embeddings. |
|
ffn_activation_name: Activation function to be used in FFN block. Supported |
|
names are "gelu", "relu", "swish". |
|
use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed |
|
Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg |
|
to True and use swish as ffn_activation_name. |
|
Same principle for a gated-relu. To keep the same number of parameters in |
|
the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU. |
|
See https://arxiv.org/pdf/2002.05202.pdf for more details. |
|
mask_before_attention: Use mask before attention layers. |
|
layer_norm_eps: the eps factor in the different layer norms of the model (refer |
|
to layer norm implementation) |
|
token_dropout: Token dropout. |
|
masking_ratio: Masking ratio (used if token dropout is enabled). |
|
masking_prob: Masking probability (used if token dropout is enabled). |
|
use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint |
|
gradients in the forward pass to reduce the computation in the backward). |
|
""" |
|
|
|
alphabet_size: int |
|
pad_token_id: int |
|
mask_token_id: int |
|
|
|
max_positions: int = 1024 |
|
embed_scale: float = 1.0 |
|
|
|
|
|
emb_layer_norm_before: bool = False |
|
attention_heads: int = 20 |
|
key_size: Optional[int] = None |
|
embed_dim: int = 1280 |
|
ffn_embed_dim: int = 5120 |
|
num_layers: int = 24 |
|
positional_embedding: Optional[str] = "learned" |
|
lm_head: Optional[str] = "simple" |
|
add_bias_kv: bool = False |
|
add_bias_ffn: bool = True |
|
use_rotary_embedding: bool = False |
|
rescaling_factor: Optional[float] = None |
|
ffn_activation_name: str = "gelu-no-approx" |
|
use_glu_in_ffn: bool = False |
|
mask_before_attention: bool = False |
|
layer_norm_eps: float = 1e-5 |
|
pre_layer_norm: bool = True |
|
bias_word_embedding: bool = False |
|
|
|
|
|
token_dropout: bool = False |
|
masking_ratio: float = 0.1 |
|
masking_prob: float = 0.8 |
|
|
|
|
|
use_gradient_checkpointing: bool = False |
|
|
|
|
|
embeddings_layers_to_save: List[int] = field(default_factory=list) |
|
attention_maps_to_save: List[Tuple[int, int]] = field(default_factory=list) |
|
|
|
def __post_init__(self) -> None: |
|
""" |
|
Checks that the given values are compatible. |
|
""" |
|
|
|
if self.key_size is None: |
|
if not self.embed_dim % self.attention_heads == 0: |
|
raise ValueError( |
|
f"When no key size is provided, the embedding dimension should be " |
|
f"divisible by the number of heads, however provided embedding " |
|
f"dimension is {self.embed_dim} and the number of heads is " |
|
f"{self.attention_heads}." |
|
) |
|
self.key_size = self.embed_dim // self.attention_heads |
|
if self.positional_embedding is not None: |
|
            if not isinstance(self.positional_embedding, str):
                raise TypeError("positional_embedding must be a string or None.")
|
|
|
if self.positional_embedding not in [ |
|
"learned", |
|
"sinusoidal", |
|
"learned_standard", |
|
"alibi_dnabert_2", |
|
]: |
|
raise ValueError( |
|
"The positional_embedding argument should either be None," |
|
"`learned`, `sinusoidal`, 'learned_standard' or 'alibi_dnabert_2'." |
|
) |
|
if self.lm_head is not None: |
|
            if not isinstance(self.lm_head, str):
                raise TypeError("lm_head must be a string or None.")
|
|
|
if self.lm_head not in ["simple", "roberta"]: |
|
raise ValueError( |
|
"The lm_head argument should either be None," |
|
"`simple` or `roberta`." |
|
) |
|
|
|
if self.use_rotary_embedding and self.positional_embedding is not None: |
|
raise ValueError( |
|
"When using rotary embedding, positional_embedding must be set to none" |
|
) |
|
|
|
if self.add_bias_kv and self.use_rotary_embedding: |
|
raise ValueError( |
|
"Biases on key and values are not compatible with Rotary embeddings." |
|
) |
|
|
|
if self.positional_embedding == "alibi_dnabert_2": |
|
assert not self.add_bias_kv |
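
# Illustrative sketch (not part of the model): rotary embeddings require
# positional_embedding=None, as enforced in __post_init__ above. The alphabet size
# and token IDs below are arbitrary example values.
def _example_nt_config() -> NucleotideTransformerConfig:
    return NucleotideTransformerConfig(
        alphabet_size=4000,
        pad_token_id=1,
        mask_token_id=2,
        positional_embedding=None,
        use_rotary_embedding=True,
    )
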
|
|
|
|
|
@dataclass |
|
class ChatNTConfig(PretrainedConfig): |
|
model_type = "ChatNT" |
|
|
|
def __init__(self, **kwargs): |
|
self.gpt_config: GptConfig = kwargs.get("gpt_config", GptConfig(32000, 3)) |
|
self.nt_config: NucleotideTransformerConfig = kwargs.get( |
|
"nt_config", NucleotideTransformerConfig(4000, 1, 4) |
|
) |
|
self.perceiver_resampler_config: PerceiverResamplerConfig = kwargs.get( |
|
"perceiver_resampler_config", PerceiverResamplerConfig() |
|
) |
|
self.seq_token_id: int = kwargs.get("seq_token_id", 32000) |
|
self.bio_pad_token_id: int = kwargs.get("bio_pad_token_id", 1) |
|
self.english_pad_token_id: int = kwargs.get("english_pad_token_id", 2) |
|
super().__init__(**kwargs) |
|
|
|
def to_dict(self): |
|
output = super().to_dict() |
|
|
|
def serialize(obj): |
|
return obj.to_dict() if hasattr(obj, "to_dict") else vars(obj) |
|
|
|
output["gpt_config"] = serialize(self.gpt_config) |
|
output["nt_config"] = serialize(self.nt_config) |
|
output["perceiver_resampler_config"] = serialize( |
|
self.perceiver_resampler_config |
|
) |
|
return output |
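
# Illustrative sketch (not part of the model): ChatNTConfig accepts the three
# sub-configs as keyword arguments and to_dict() serializes them back to plain
# dictionaries so the config can be saved with the usual transformers utilities.
# All sizes and token IDs below are arbitrary example values.
def _example_chatnt_config() -> dict:
    config = ChatNTConfig(
        gpt_config=GptConfig(vocab_size=32000, eos_token_id=3),
        nt_config=NucleotideTransformerConfig(
            alphabet_size=4000, pad_token_id=1, mask_token_id=4
        ),
    )
    return config.to_dict()
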
|
|
|
|
|
class TorchBioBrainDecoder(nn.Module): |
|
def __init__( |
|
self, |
|
gpt_config: GptConfig, |
|
seq_token_id: int, |
|
): |
|
""" |
|
Initializes the BioBrain decoder, using a GPT model for text generation with |
|
bio embeddings. |
|
|
|
Args: |
|
gpt_config: Configuration for the GPT model |
|
seq_token_id: Index of the SEQ token |
|
""" |
|
super(TorchBioBrainDecoder, self).__init__() |
|
self.gpt_config = gpt_config |
|
self.seq_token_id = seq_token_id |
|
|
|
|
|
self.gpt_model = TorchGptDecoder(self.gpt_config) |
|
|
|
def forward( |
|
        self,
        english_token_ids: torch.Tensor,
        projected_bio_embeddings: Optional[torch.Tensor],
|
) -> torch.Tensor: |
|
""" |
|
Forward pass through the model. |
|
|
|
Args: |
|
english_token_ids: Tensor of English token IDs with shape |
|
(batch_size, num_english_tokens). |
|
projected_bio_embeddings: Optional tensor of bio embeddings with shape |
|
(batch_size, num_bio_sequences, ?, embed_dim). |
|
|
|
Returns: |
|
torch.Tensor: The logits from the GPT model, |
|
shaped (batch_size, num_english_tokens, vocab_size). |
|
""" |
|
|
|
|
|
tokens_embeddings = self.gpt_model.token_embed(english_token_ids) |
|
|
|
if projected_bio_embeddings is not None: |
|
( |
|
batch_size, |
|
num_bio_sequences, |
|
_, |
|
bio_embed_dim, |
|
) = projected_bio_embeddings.shape |
|
|
|
|
|
processed_tokens_ids = english_token_ids.clone() |
|
for bio_seq_num in range(num_bio_sequences): |
|
tokens_embeddings, processed_tokens_ids = self.insert_embeddings( |
|
processed_tokens_ids, |
|
tokens_embeddings, |
|
projected_bio_embeddings[:, bio_seq_num, :, :], |
|
bio_seq_num=bio_seq_num, |
|
) |
|
|
|
|
|
embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings) |
|
embeddings = self.gpt_model.final_norm(embeddings) |
|
|
|
|
|
logits = self.gpt_model.lm_head(embeddings) |
|
|
|
if projected_bio_embeddings is not None: |
|
|
|
processed_tokens_ids = english_token_ids.clone() |
|
resampled_length = projected_bio_embeddings.shape[-2] |
|
for _ in range(num_bio_sequences): |
|
logits, processed_tokens_ids = self.cleanup_logits( |
|
tokens=processed_tokens_ids, |
|
logits=logits, |
|
resampled_length=resampled_length, |
|
) |
|
|
|
return logits |
|
|
|
def insert_embeddings( |
|
self, |
|
tokens: torch.Tensor, |
|
input_embeddings: torch.Tensor, |
|
resampled_embeddings: torch.Tensor, |
|
bio_seq_num: int, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Inserts resampled embeddings in input_embeddings, starting at the SEQ token |
|
|
|
Args: |
|
tokens (torch.Tensor): Shape (batch_size, num_tokens) |
|
input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim) |
|
            resampled_embeddings (torch.Tensor):
                Shape (batch_size, bio_sequence_length, embed_dim)
            bio_seq_num: Index of the bio sequence currently being inserted.
|
|
|
Returns: |
|
Tuple[torch.Tensor, torch.Tensor]: |
|
- input_embeddings with resampled_embeddings inserted at the SEQ token |
|
- tokens with the SEQ token set to -1 |
|
""" |
|
|
|
def _insert( |
|
tokens_1d: torch.Tensor, |
|
input_embeddings_1d: torch.Tensor, |
|
resampled_embeddings_1d: torch.Tensor, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
            Args:
                tokens_1d (torch.Tensor): Shape (num_tokens,)
                input_embeddings_1d (torch.Tensor): Shape (num_tokens, embed_dim)
                resampled_embeddings_1d (torch.Tensor):
                    Shape (bio_sequence_length, embed_dim)
|
""" |
|
indices = torch.where(tokens_1d == self.seq_token_id)[0] |
|
if indices.numel() > 0: |
|
idx = indices[0].item() |
|
insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num |
|
x = torch.cat( |
|
[ |
|
input_embeddings_1d[:insertion_pos, :], |
|
resampled_embeddings_1d, |
|
input_embeddings_1d[insertion_pos:, :], |
|
], |
|
dim=0, |
|
)[: tokens_1d.shape[0] + 1, :] |
|
x = torch.roll(torch.roll(x, shifts=-idx, dims=0), shifts=idx, dims=0)[ |
|
:-1, : |
|
] |
|
tokens_1d[idx] = -1 |
|
return x, tokens_1d |
|
else: |
|
                return input_embeddings_1d, tokens_1d
|
|
|
tokens_acc = [] |
|
embeddings_acc = [] |
|
|
|
for i in range(tokens.shape[0]): |
|
embeddings_out, tokens_out = _insert( |
|
tokens[i].clone(), |
|
input_embeddings[i].clone(), |
|
resampled_embeddings[i].clone(), |
|
) |
|
tokens_acc.append(tokens_out) |
|
embeddings_acc.append(embeddings_out) |
|
tokens_acc = torch.stack(tokens_acc) |
|
embeddings_acc = torch.stack(embeddings_acc) |
|
|
|
return embeddings_acc, tokens_acc |
|
|
|
def cleanup_logits( |
|
self, tokens: torch.Tensor, logits: torch.Tensor, resampled_length: int |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Removes the logits corresponding to the unused embeddings. |
|
|
|
Args: |
|
            tokens: Input english tokens.
            logits: Input logits.
            resampled_length: Number of bio embeddings that were inserted for each
                bio sequence.

        Returns:
            Cleaned logits; the last resampled_length positions are set to 0.
|
""" |
|
|
|
def _clean( |
|
token: torch.Tensor, logit: torch.Tensor |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
indices = torch.where(token == self.seq_token_id)[0] |
|
if indices.numel() > 0: |
|
idx = indices[0].item() |
|
|
|
mask_idx = ( |
|
torch.arange(logit.shape[0] - resampled_length, device=logit.device) |
|
> idx |
|
) |
|
mask_idx = mask_idx.unsqueeze(1) |
|
|
|
|
|
logit = ( |
|
logit[:-resampled_length] * (~mask_idx) |
|
+ logit[resampled_length:] * mask_idx |
|
) |
|
|
|
|
|
logit = torch.cat( |
|
( |
|
logit, |
|
torch.zeros( |
|
(resampled_length, logit.shape[1]), |
|
dtype=logit.dtype, |
|
device=logit.device, |
|
), |
|
) |
|
) |
|
|
|
|
|
token[idx] = -1 |
|
|
|
return logit, token |
|
|
|
else: |
|
return logit, token |
|
|
|
tokens_acc = [] |
|
logits_acc = [] |
|
|
|
for i in range(tokens.shape[0]): |
|
logits_out, tokens_out = _clean(tokens[i].clone(), logits[i].clone()) |
|
tokens_acc.append(tokens_out) |
|
logits_acc.append(logits_out) |
|
tokens_acc = torch.stack(tokens_acc) |
|
logits_acc = torch.stack(logits_acc) |
|
|
|
return logits_acc, tokens_acc |
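
# Illustrative sketch (not part of the model): the decoder takes English token IDs
# containing a SEQ token plus bio embeddings already projected to the GPT embedding
# space, shaped (batch, num_bio_sequences, resampled_length, embed_dim). All sizes
# and token values below are arbitrary example assumptions.
def _example_biobrain_decoder() -> torch.Tensor:
    config = GptConfig(vocab_size=100, eos_token_id=2)  # embed_dim=16 by default
    decoder = TorchBioBrainDecoder(gpt_config=config, seq_token_id=5)
    english_token_ids = torch.tensor([[1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13]])
    projected_bio_embeddings = torch.randn(1, 1, 4, config.embed_dim)
    logits = decoder(english_token_ids, projected_bio_embeddings)
    return logits  # shape (1, 12, 100)
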
|
|
|
|
|
class TorchMultiOmicsModel(PreTrainedModel): |
|
config_class = ChatNTConfig |
|
|
|
def __init__(self, config: ChatNTConfig) -> None: |
|
if isinstance(config, dict): |
|
|
|
|
|
config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig( |
|
**config["gpt_config"]["rope_config"] |
|
) |
|
config["gpt_config"] = GptConfig(**config["gpt_config"]) |
|
config["nt_config"] = NucleotideTransformerConfig(**config["nt_config"]) |
|
config["perceiver_resampler_config"] = PerceiverResamplerConfig( |
|
**config["perceiver_resampler_config"] |
|
) |
|
config = ChatNTConfig(**config) |
|
|
|
else: |
|
if isinstance(config.gpt_config, dict): |
|
config.gpt_config["rope_config"] = RotaryEmbeddingConfig( |
|
**config.gpt_config["rope_config"] |
|
) |
|
config.gpt_config = GptConfig(**config.gpt_config) |
|
|
|
if isinstance(config.nt_config, dict): |
|
config.nt_config = NucleotideTransformerConfig(**config.nt_config) |
|
|
|
if isinstance(config.perceiver_resampler_config, dict): |
|
config.perceiver_resampler_config = PerceiverResamplerConfig( |
|
**config.perceiver_resampler_config |
|
) |
|
|
|
super().__init__(config=config) |
|
self.gpt_config = config.gpt_config |
|
self.nt_config = config.nt_config |
|
self.perceiver_resampler_config = config.perceiver_resampler_config |
|
self.seq_token_id = config.seq_token_id |
|
self.bio_pad_token_id = config.bio_pad_token_id |
|
self.english_pad_token_id = config.english_pad_token_id |
|
|
|
|
|
self.seq_token_id -= 1 |
|
|
|
self.biobrain_encoder = TorchBioBrainEncoder(nt_config=self.nt_config) |
|
self.biobrain_decoder = TorchBioBrainDecoder( |
|
gpt_config=self.gpt_config, seq_token_id=self.seq_token_id |
|
) |
|
self.projection_model = TorchMultiModalPerceiverResamplerProjection( |
|
perceiver_resampler_config=self.perceiver_resampler_config, |
|
input_embed_dim=self.nt_config.embed_dim, |
|
embed_dim=self.gpt_config.embed_dim, |
|
english_vocab_size=self.gpt_config.vocab_size, |
|
bio_pad_token_id=self.bio_pad_token_id, |
|
english_pad_token_id=self.english_pad_token_id, |
|
) |
|
|
|
def forward( |
|
self, |
|
multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor], |
|
projection_english_tokens_ids: torch.Tensor, |
|
        projected_bio_embeddings: Optional[torch.Tensor] = None,
|
) -> dict[str, torch.Tensor]: |
|
""" |
|
|
|
Args: |
|
multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]): |
|
english_tokens_ids: Represents the prompt tokens (english tokens) |
|
Shape (batch_size, num_english_tokens) |
|
|
|
bio_tokens_ids: Represents the bio sequences tokens |
|
Shape (batch_size, num_bio_sequences, num_bio_tokens) |
|
|
|
projection_english_tokens_ids (torch.Tensor): |
|
Shape (batch_size, num_english_tokens) |
|
|
|
            projected_bio_embeddings (torch.Tensor, optional):
                Shape (batch_size, num_bio_sequences, ?, embed_dim).
|
Defaults to None. |
|
|
|
Returns: |
|
dict[str, torch.Tensor] containing: |
|
- logits: |
|
Shape (batch_size, num_tokens, vocab_size) |
|
|
|
- projected_bio_embeddings: |
|
Shape (batch_size, num_bio_sequences, ?, embed_dim) |
|
""" |
|
english_token_ids, bio_token_ids = multi_omics_tokens_ids |
|
english_token_ids = english_token_ids.clone() |
|
        if bio_token_ids is not None:
            bio_token_ids = bio_token_ids.clone()
|
projection_english_tokens_ids = projection_english_tokens_ids.clone() |
|
if projected_bio_embeddings is not None: |
|
projected_bio_embeddings = projected_bio_embeddings.clone() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vocab_size = self.gpt_config.vocab_size |
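        # The two blocks below remap English token IDs that fall outside the GPT
        # vocabulary: IDs equal to vocab_size - 1 are set to 0, and IDs equal to
        # vocab_size are clipped back into range as vocab_size - 1.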
|
|
|
english_token_ids[english_token_ids == vocab_size - 1] = 0 |
|
projection_english_tokens_ids[ |
|
projection_english_tokens_ids == vocab_size - 1 |
|
] = 0 |
|
english_token_ids[english_token_ids == vocab_size] = vocab_size - 1 |
|
projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = ( |
|
vocab_size - 1 |
|
) |
|
|
|
if bio_token_ids is None: |
|
projected_bio_embeddings = None |
|
else: |
|
num_bio_sequences = bio_token_ids.shape[1] |
|
|
|
if projected_bio_embeddings is None: |
|
|
|
bio_embeddings_list = [ |
|
self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num]) |
|
for bio_seq_num in range(num_bio_sequences) |
|
] |
|
|
|
|
|
projected_bio_embeddings = [ |
|
self.projection_model( |
|
bio_token_ids=bio_token_ids[:, bio_seq_num], |
|
bio_embeddings=bio_embeddings, |
|
english_token_ids=projection_english_tokens_ids, |
|
) |
|
for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list) |
|
] |
|
projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1) |
|
|
|
|
|
logits = self.biobrain_decoder( |
|
english_token_ids=english_token_ids, |
|
projected_bio_embeddings=projected_bio_embeddings, |
|
) |
|
|
|
outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings} |
|
|
|
return outs |
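
# Illustrative sketch (not part of the model): a deliberately tiny ChatNT-style
# model wired from the three sub-configs. All sizes and token IDs are assumptions
# for the example only; released checkpoints use much larger configurations. Note
# that the perceiver resampler embed_dim must match the GPT embed_dim and that the
# encoder config must satisfy the asserts in TorchNucleotideTransformer.
def _example_multi_omics_model() -> TorchMultiOmicsModel:
    nt_config = NucleotideTransformerConfig(
        alphabet_size=64,
        pad_token_id=1,
        mask_token_id=2,
        embed_dim=32,
        ffn_embed_dim=64,
        attention_heads=4,
        num_layers=2,
        positional_embedding=None,
        lm_head="roberta",
        use_rotary_embedding=True,
    )
    gpt_config = GptConfig(vocab_size=128, eos_token_id=3)
    resampler_config = PerceiverResamplerConfig(
        embed_dim=16, ffn_embed_dim=32, attention_heads=2, num_layers=1
    )
    config = ChatNTConfig(
        gpt_config=gpt_config,
        nt_config=nt_config,
        perceiver_resampler_config=resampler_config,
        seq_token_id=6,
    )
    return TorchMultiOmicsModel(config)
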
|
|
|
|
|
class TorchRotaryEmbedding(torch.nn.Module): |
|
def __init__(self, config: RotaryEmbeddingConfig): |
|
super().__init__() |
|
|
|
self.max_seq_len = config.max_seq_len |
|
self.dim = config.dim |
|
self.theta = config.theta |
|
self.sincos_cache = None |
|
|
|
def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor: |
|
""" |
|
Create the sines and cosines for the RoPE. |
|
|
|
Returns: |
|
Sinusoidal positions of shape (self.max_seq_len, self.dim). |
|
""" |
|
|
|
inv_freq = 1.0 / ( |
|
self.theta |
|
** (torch.arange(0, self.dim, 2, device=device).float() / self.dim) |
|
) |
|
|
|
|
|
sinusoid_inp = torch.einsum( |
|
"i,j->ij", torch.arange(self.max_seq_len, device=device).float(), inv_freq |
|
) |
|
|
|
|
|
sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos() |
|
|
|
|
|
sincos = torch.zeros( |
|
(self.max_seq_len, self.dim), dtype=torch.float32, device=device |
|
) |
|
|
|
|
|
sentinel = self.dim // 2 + self.dim % 2 |
|
sincos[:, :sentinel] = sin |
|
sincos[:, sentinel:] = cos |
|
|
|
return sincos |
|
|
|
def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Prepare a tensor to apply the RoPE mechanism. |
|
|
|
Args: |
|
x: Tensor of shape (batch_size, seq_len, num_heads, head_dim), |
|
typically this is the key or query tensor. |
|
|
|
        Returns:
            Tensor of shape (batch_size, seq_len, num_heads, head_dim) in which each
            pair (x_{2i}, x_{2i+1}) along the last dimension is mapped to
            (-x_{2i+1}, x_{2i}).
|
""" |
|
|
|
rotate_half = torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1) |
|
|
|
|
|
rotate_half = rotate_half.view(rotate_half.shape[:-2] + (-1,)) |
|
return rotate_half |
|
|
|
def _apply_rotary_pos_emb( |
|
self, x: torch.Tensor, sincos: torch.Tensor |
|
) -> torch.Tensor: |
|
""" |
|
Applies rotary embeddings to x. |
|
|
|
Args: |
|
x: Tensor of shape (batch_size, seq_len, num_heads, head_dim), |
|
typically this is the key or query tensor. |
|
sincos: Tuple of sine and cosine tensors for position encoding. |
|
|
|
Returns: |
|
RoPE embeddings tensor. |
|
""" |
|
sin_pos, cos_pos = sincos |
|
|
|
|
|
sin_pos = torch.repeat_interleave(sin_pos.unsqueeze(2), repeats=2, dim=-1) |
|
cos_pos = torch.repeat_interleave(cos_pos.unsqueeze(2), repeats=2, dim=-1) |
|
|
|
|
|
return (x * cos_pos) + (self._rotate_every_two(x) * sin_pos) |
|
|
|
def __call__( |
|
self, k: torch.Tensor, q: torch.Tensor, positions: Optional[torch.Tensor] = None |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Applies rotary embeddings to k and q. |
|
|
|
Args: |
|
k: key tensor of shape (batch_size, seq_len, num_heads, head_dim), |
|
            q: query tensor of shape (batch_size, seq_len, num_heads, head_dim),
            positions: optional positions offset, useful when caching,

        Returns:
            RoPE embeddings for the keys and queries.
|
""" |
|
if self.sincos_cache is None: |
|
device = k.device |
|
self.sincos_cache = self._create_sinusoidal_positions(device=device) |
|
|
|
batch_size, seq_len, num_heads, head_dim = k.shape |
|
|
|
|
|
position_ids = ( |
|
torch.arange(seq_len, device=k.device).unsqueeze(0).expand(batch_size, -1) |
|
) |
|
|
|
if positions is not None: |
|
            position_ids = position_ids + positions
|
|
|
|
|
sincos = self.sincos_cache[position_ids] |
|
|
|
|
|
sincos = torch.chunk(sincos, 2, dim=-1) |
|
|
|
|
|
k_rot = self._apply_rotary_pos_emb(k[..., : self.dim], sincos) |
|
k_pass = k[..., self.dim :] |
|
|
|
q_rot = self._apply_rotary_pos_emb(q[..., : self.dim], sincos) |
|
q_pass = q[..., self.dim :] |
|
|
|
|
|
        keys = torch.cat([k_rot, k_pass], dim=-1)
        queries = torch.cat([q_rot, q_pass], dim=-1)

        return keys, queries
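
# Illustrative sketch (not part of the model): the rotary module rotates only the
# first `dim` channels of every head and leaves the remaining channels untouched.
# The shapes below are arbitrary example assumptions.
def _example_rotary_embedding() -> Tuple[torch.Tensor, torch.Tensor]:
    rope = TorchRotaryEmbedding(
        RotaryEmbeddingConfig(max_seq_len=32, dim=8, theta=10000.0)
    )
    k = torch.randn(2, 16, 4, 8)  # (batch_size, seq_len, num_heads, head_dim)
    q = torch.randn(2, 16, 4, 8)
    return rope(k, q)  # rotated keys and queries, same shapes as the inputs
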
|
|
|
|
|
class TorchGptGroupedQueryAttention(nn.Module): |
|
def __init__( |
|
self, |
|
embed_dim: int, |
|
num_heads: int, |
|
rope_config: RotaryEmbeddingConfig, |
|
        num_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
|
add_bias_attn: bool = False, |
|
) -> None: |
|
super().__init__() |
|
self.num_heads = num_heads |
|
self.num_kv_heads = num_kv_heads or num_heads |
|
self.embed_dim = embed_dim |
|
self.head_dim = head_dim or (embed_dim // num_heads) |
|
self.add_bias_attn = add_bias_attn |
|
self.rope = TorchRotaryEmbedding(rope_config) |
|
|
|
self.query_linear = nn.Linear( |
|
embed_dim, self.num_heads * self.head_dim, bias=add_bias_attn |
|
) |
|
self.key_linear = nn.Linear( |
|
embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn |
|
) |
|
self.value_linear = nn.Linear( |
|
embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn |
|
) |
|
self.out_linear = nn.Linear( |
|
self.num_heads * self.head_dim, embed_dim, bias=add_bias_attn |
|
) |
|
|
|
def forward( |
|
self, |
|
query_inputs: torch.Tensor, |
|
key_inputs: torch.Tensor, |
|
value_inputs: torch.Tensor, |
|
        attention_mask: Optional[torch.Tensor] = None,
|
) -> torch.Tensor: |
|
batch_size, seq_len, _ = query_inputs.shape |
|
|
|
queries = self.query_linear(query_inputs).view( |
|
batch_size, seq_len, self.num_heads, self.head_dim |
|
) |
|
keys = self.key_linear(key_inputs).view( |
|
batch_size, seq_len, self.num_kv_heads, self.head_dim |
|
) |
|
values = self.value_linear(value_inputs).view( |
|
batch_size, seq_len, self.num_kv_heads, self.head_dim |
|
) |
|
|
|
keys, queries = self.rope(keys, queries) |
|
|
|
n_rep = self.num_heads // self.num_kv_heads |
|
keys = keys.repeat_interleave(n_rep, dim=2) |
|
values = values.repeat_interleave(n_rep, dim=2) |
|
|
|
attention_logits = torch.einsum("bthd,bThd->bhtT", queries, keys) / ( |
|
self.head_dim**0.5 |
|
) |
|
|
|
if attention_mask is not None: |
|
attention_logits = attention_logits.masked_fill( |
|
attention_mask == 0, float("-inf") |
|
) |
|
|
|
attention_weights = nn.functional.softmax(attention_logits, dim=-1) |
|
|
|
values = torch.einsum("bhtT,bThd->bthd", attention_weights, values) |
|
values = values.contiguous().view(batch_size, seq_len, -1) |
|
|
|
return self.out_linear(values) |
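
# Illustrative sketch (not part of the model): grouped-query attention with two
# query heads sharing a single key/value head; the key/value heads are repeated
# with repeat_interleave before the attention product. Sizes are example values.
def _example_grouped_query_attention() -> torch.Tensor:
    rope_config = RotaryEmbeddingConfig(max_seq_len=32, dim=8, theta=10000.0)
    attention = TorchGptGroupedQueryAttention(
        embed_dim=16, num_heads=2, rope_config=rope_config, num_kv_heads=1
    )
    x = torch.randn(1, 10, 16)
    mask = build_causal_attention_mask(1, 10, device=x.device)
    return attention(x, x, x, attention_mask=mask)  # shape (1, 10, 16)
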
|
|
|
|
|
class TorchGptDecoder(nn.Module): |
|
def __init__(self, config: GptConfig, name: Optional[str] = None): |
|
super().__init__() |
|
self.config = config |
|
|
|
self.token_embed = nn.Embedding(config.vocab_size, config.embed_dim) |
|
|
|
if config.norm_type == "layer_norm": |
|
self.final_norm = nn.LayerNorm(config.embed_dim) |
|
elif config.norm_type == "RMS_norm": |
|
self.final_norm = TorchRMSNorm(config.embed_dim, eps=config.rms_norm_eps) |
|
else: |
|
raise ValueError(f"unrecognized norm_type in config {config.norm_type}") |
|
|
|
self.layers = nn.ModuleList( |
|
[ |
|
TorchGptDecoderLayer( |
|
embed_dim=config.embed_dim, |
|
ffn_embed_dim=config.ffn_embed_dim, |
|
num_heads=config.num_heads, |
|
rope_config=config.rope_config, |
|
norm_type=config.norm_type, |
|
parallel_attention_ff=config.parallel_attention_ff, |
|
add_bias_ffn=config.add_bias_ffn, |
|
ffn_activation_name=config.ffn_activation_name, |
|
use_glu_in_ffn=config.use_glu_in_ffn, |
|
num_kv_heads=config.num_kv_heads, |
|
add_bias_attn=config.add_bias_attn, |
|
rms_norm_eps=config.rms_norm_eps, |
|
) |
|
for _ in range(config.num_layers) |
|
] |
|
) |
|
|
|
self.lm_head = TorchSimpleLMHead( |
|
embed_dim=config.embed_dim, |
|
alphabet_size=config.vocab_size, |
|
add_bias_lm_head=config.add_bias_lm_head, |
|
) |
|
|
|
def apply_transformer_layers( |
|
        self, embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
) -> torch.Tensor: |
|
if attention_mask is None: |
|
attention_mask = build_causal_attention_mask( |
|
1, embeddings.shape[1], device=embeddings.device |
|
) |
|
for layer in self.layers: |
|
embeddings = layer(embeddings, attention_mask) |
|
|
|
return embeddings |
|
|
|
def forward( |
|
        self, token_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
) -> dict[str, torch.Tensor]: |
|
if attention_mask is None: |
|
attention_mask = build_causal_attention_mask( |
|
1, token_ids.shape[1], device=token_ids.device |
|
) |
|
|
|
tokens_embeddings = self.token_embed(token_ids) |
|
|
|
after_transformer_embeddings = self.apply_transformer_layers( |
|
tokens_embeddings, attention_mask=attention_mask |
|
) |
|
|
|
embeddings = self.final_norm(after_transformer_embeddings) |
|
logits = self.lm_head(embeddings) |
|
return {"embeddings": embeddings, "logits": logits} |
|
|
|
|
|
class TorchSimpleLMHead(nn.Module): |
|
def __init__( |
|
self, embed_dim: int, alphabet_size: int, add_bias_lm_head: bool = True |
|
) -> None: |
|
super().__init__() |
|
self.fc = nn.Linear(embed_dim, alphabet_size, bias=add_bias_lm_head) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
return self.fc(x) |
|
|
|
|
|
class TorchGptDecoderLayer(nn.Module): |
|
def __init__( |
|
self, |
|
embed_dim: int, |
|
ffn_embed_dim: int, |
|
num_heads: int, |
|
rope_config: RotaryEmbeddingConfig, |
|
norm_type: str, |
|
parallel_attention_ff: bool, |
|
add_bias_ffn: bool, |
|
ffn_activation_name: str, |
|
use_glu_in_ffn: bool, |
|
num_kv_heads: int, |
|
add_bias_attn: bool, |
|
rms_norm_eps: float = 1e-6, |
|
) -> None: |
|
super().__init__() |
|
self.num_heads = num_heads |
|
self.parallel_attention_ff = parallel_attention_ff |
|
self.use_glu_in_ffn = use_glu_in_ffn |
|
|
|
|
|
self.self_attn = TorchGptGroupedQueryAttention( |
|
embed_dim=embed_dim, |
|
num_heads=num_heads, |
|
num_kv_heads=num_kv_heads, |
|
rope_config=rope_config, |
|
add_bias_attn=add_bias_attn, |
|
) |
|
|
|
|
|
if norm_type == "layer_norm": |
|
self.attn_norm = nn.LayerNorm(embed_dim) |
|
if not self.parallel_attention_ff: |
|
self.ffn_norm = nn.LayerNorm(embed_dim) |
|
elif norm_type == "RMS_norm": |
|
self.attn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps) |
|
if not self.parallel_attention_ff: |
|
self.ffn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps) |
|
else: |
|
raise ValueError(f"unrecognized norm_type: {norm_type}") |
|
|
|
|
|
self.activation = get_activation_fn(ffn_activation_name) |
|
ffn_hidden_dim = ffn_embed_dim * (2 if use_glu_in_ffn else 1) |
|
self.fc1 = nn.Linear(embed_dim, ffn_hidden_dim, bias=add_bias_ffn) |
|
self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_ffn) |
|
|
|
def forward( |
|
self, embeddings: torch.Tensor, attention_mask: torch.Tensor |
|
) -> torch.Tensor: |
|
residuals = embeddings |
|
|
|
if self.parallel_attention_ff: |
|
|
|
embeddings_normed = self.attn_norm(embeddings) |
|
|
|
            attn_output = self.self_attn(
                embeddings_normed,
                embeddings_normed,
                embeddings_normed,
                attention_mask=attention_mask,
            )
|
ffn_output = self.mlp(embeddings_normed) |
|
|
|
return residuals + attn_output + ffn_output |
|
else: |
|
|
|
normed_embeddings = self.attn_norm(embeddings) |
|
|
|
attn_output = embeddings + self.self_attn( |
|
normed_embeddings, |
|
normed_embeddings, |
|
normed_embeddings, |
|
attention_mask=attention_mask, |
|
) |
|
|
|
normed_embeddings2 = self.ffn_norm(attn_output) |
|
ffn_output = self.mlp(normed_embeddings2) |
|
return attn_output + ffn_output |
|
|
|
def mlp(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Applies the feedforward network (MLP) with optional GLU.""" |
|
ffn_output = self.fc1(x) |
|
|
|
if self.use_glu_in_ffn: |
|
ffn_output1, ffn_output2 = ffn_output.chunk(2, dim=-1) |
|
ffn_output = self.activation(ffn_output1) * ffn_output2 |
|
else: |
|
ffn_output = self.activation(ffn_output) |
|
|
|
return self.fc2(ffn_output) |
|
|
|
|
|
class TorchRMSNorm(nn.Module): |
|
def __init__(self, dim: int, eps: float = 1e-6) -> None: |
|
super().__init__() |
|
self.eps = eps |
|
self.scale = nn.Parameter(torch.ones(dim)) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
return ( |
|
x |
|
* self.scale |
|
/ torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps) |
|
) |
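
# Illustrative sketch (not part of the model): RMSNorm rescales each feature vector
# by its root-mean-square, x * scale / sqrt(mean(x**2) + eps), without the mean
# subtraction performed by LayerNorm.
def _example_rms_norm() -> torch.Tensor:
    norm = TorchRMSNorm(dim=16)
    return norm(torch.randn(2, 4, 16))  # same shape as the input
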
|
|
|
|
|
def get_activation_fn(activation_name: str): |
|
activations = { |
|
"gelu": nn.functional.gelu, |
|
"relu": nn.functional.relu, |
|
"swish": nn.functional.silu, |
|
"silu": nn.functional.silu, |
|
} |
|
return activations.get(activation_name, nn.functional.relu) |
|
|
|
|
|
def build_causal_attention_mask( |
|
batch_size: int, seq_len: int, device: torch.device |
|
) -> torch.Tensor: |
|
""" |
|
Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed |
|
to an attention layer. |
|
|
|
Args: |
|
        batch_size: Batch size.
        seq_len: Length of the sequences.
        device: Device on which to create the mask.
|
|
|
Returns: |
|
Batch of causal masks. |
|
""" |
|
mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device) |
|
causal_mask = torch.tril(mask) |
|
return causal_mask |
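
# Illustrative sketch (not part of the model): the causal mask is lower-triangular,
# so position t can only attend to positions <= t.
def _example_causal_mask() -> torch.Tensor:
    # Returns a (1, 1, 4, 4) lower-triangular mask of ones.
    return build_causal_attention_mask(1, 4, device=torch.device("cpu"))
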
|
|
|
|
|
@dataclass |
|
class RotaryEmbeddingConfigBis: |
|
""" |
|
Parameters to initialize the RotaryEmbedding layer. The rescaling factor allows |
|
to adapt the rotary embeddings to larger lengths than what was used for training. |
|
    One such strategy is presented in the YaRN paper: https://arxiv.org/pdf/2309.00071.pdf. # noqa
    Args:
        rescaling_factor: Optional rescaling factor applied to adapt the rotary
            embeddings to sequence lengths longer than those seen during training.
    """
|
|
|
rescaling_factor: Optional[float] |
|
|
|
|
|
class RotaryEmbeddingBis(torch.nn.Module): |
|
""" |
|
Rotary position embeddings based on those in |
|
[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). |
|
Query and keys are transformed by rotation |
|
matrices which depend on their relative positions. |
|
""" |
|
|
|
def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfigBis): |
|
super().__init__() |
|
|
|
|
|
self.rescaling_factor = rotary_embedding_config.rescaling_factor |
|
self.upper_freq = 10000 |
|
self.dim = dim |
|
|
|
self._seq_len_cached = None |
|
self._cos_cached = None |
|
self._sin_cached = None |
|
|
|
def _apply_rotary_pos_emb( |
|
self, |
|
heads: torch.Tensor, |
|
cos: torch.Tensor, |
|
sin: torch.Tensor, |
|
) -> torch.Tensor: |
|
""" """ |
|
x_first, x_second = ( |
|
heads[..., : heads.shape[-1] // 2], |
|
heads[..., heads.shape[-1] // 2 :], |
|
) |
|
|
|
first_part = x_first * cos - x_second * sin |
|
second_part = x_second * cos + x_first * sin |
|
|
|
return torch.cat((first_part, second_part), dim=-1) |
|
|
|
def _compute_cos_sin_tables( |
|
self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2 |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
seq_len = x.shape[seq_dimension] |
|
|
|
|
|
self._seq_len_cached = seq_len |
|
        t = torch.arange(seq_len, device=x.device).type_as(inv_freq)
|
|
|
freqs = torch.einsum("i, j -> ij", t, inv_freq) |
|
|
|
self._cos_cached = torch.cos(freqs)[None, :, None, :] |
|
self._sin_cached = torch.sin(freqs)[None, :, None, :] |
|
|
|
|
|
|
|
|
|
|
|
return self._cos_cached, self._sin_cached |
|
|
|
def forward( |
|
self, q: torch.Tensor, k: torch.Tensor |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
if self.rescaling_factor is None: |
|
inv_freq = 1.0 / ( |
|
self.upper_freq |
|
** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim) |
|
) |
|
else: |
|
updated_base = self.upper_freq * ( |
|
self.rescaling_factor ** (self.dim / (self.dim - 2)) |
|
) |
|
inv_freq = 1.0 / ( |
|
updated_base |
|
** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim) |
|
) |
|
|
|
self._cos_cached, self._sin_cached = self._compute_cos_sin_tables( |
|
q, |
|
inv_freq, |
|
seq_dimension=-3, |
|
) |
|
|
|
return ( |
|
self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), |
|
self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached), |
|
) |
|
|
|
|
|
class MultiHeadAttention(nn.Module): |
|
def __init__( |
|
self, |
|
num_heads: int, |
|
key_size: int, |
|
rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None, |
|
add_bias_kv: bool = False, |
|
value_size: Optional[int] = None, |
|
model_size: Optional[int] = None, |
|
name: Optional[str] = None, |
|
): |
|
super().__init__() |
|
if not model_size: |
|
model_size = key_size * num_heads |
|
if not value_size: |
|
value_size = key_size |
|
self.model_size = model_size |
|
self.key_size = key_size |
|
self.value_size = value_size |
|
self.add_bias_kv = add_bias_kv |
|
self.name = name |
|
self.num_heads = num_heads |
|
self._rotary_embedding_config = rotary_embedding_config |
|
|
|
self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size) |
|
self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size) |
|
self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size) |
|
self.output = nn.Linear(self.num_heads * self.value_size, self.model_size) |
|
if self._rotary_embedding_config: |
|
self._rotary_embedding = RotaryEmbeddingBis( |
|
self.key_size, self._rotary_embedding_config |
|
) |
|
|
|
def apply_rotary_embeddings( |
|
self, |
|
query: torch.Tensor, |
|
key: torch.Tensor, |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
""" """ |
|
query, key = self._rotary_embedding(query, key) |
|
return query, key |
|
|
|
def forward( |
|
self, |
|
query: torch.Tensor, |
|
key: torch.Tensor, |
|
value: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
attention_weight_bias: Optional[torch.Tensor] = None, |
|
) -> dict[str, torch.Tensor]: |
|
""" |
|
Returns: |
|
dictionary containing attention weights |
|
and outputs. |
|
""" |
|
key_heads = self.w_k(key).reshape( |
|
(*key.shape[:-1], self.num_heads, self.key_size) |
|
) |
|
query_heads = self.w_q(query).reshape( |
|
(*query.shape[:-1], self.num_heads, self.key_size) |
|
) |
|
value_heads = self.w_v(value).reshape( |
|
(*value.shape[:-1], self.num_heads, self.value_size) |
|
) |
|
if self._rotary_embedding_config: |
|
query_heads, key_heads = self.apply_rotary_embeddings( |
|
query_heads, key_heads |
|
) |
|
attention_weights = torch.einsum( |
|
"...thd, ...Thd -> ...htT", query_heads, key_heads |
|
) |
|
sqrt_key_size = np.sqrt(self.key_size) |
|
attention_weights = attention_weights / sqrt_key_size |
|
if attention_mask is not None: |
|
attention_weights = torch.where(attention_mask, attention_weights, -1e30) |
|
if attention_weight_bias is not None: |
|
attention_weights = F.softmax( |
|
attention_weights + attention_weight_bias, dim=-1 |
|
) |
|
else: |
|
attention_weights = F.softmax(attention_weights, dim=-1) |
|
value_out = torch.einsum( |
|
"...htT, ...Thd->...thd", attention_weights, value_heads |
|
) |
|
value_out = value_out.reshape((*value_out.shape[:-2], -1)) |
|
embeddings = self.output(value_out) |
|
|
|
return {"attention_weights": attention_weights, "embeddings": embeddings} |
|
|
|
|
|
class SelfAttentionBlock(nn.Module): |
|
def __init__( |
|
self, |
|
num_heads: int, |
|
embed_dim: int, |
|
ffn_embed_dim: int, |
|
key_size: Optional[int] = None, |
|
add_bias_kv: bool = False, |
|
add_bias_fnn: bool = True, |
|
ffn_activation_name: str = "gelu-no-approx", |
|
use_glu_in_ffn: bool = False, |
|
layer_norm_eps: float = 1e-5, |
|
pre_layer_norm: bool = True, |
|
name: Optional[str] = None, |
|
rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None, |
|
): |
|
super().__init__() |
|
if key_size is None: |
|
if embed_dim % num_heads != 0: |
|
raise ValueError( |
|
f"The embedding dimension should be divisible by the number of " |
|
f"heads, however provided embedding dimension is {embed_dim} and " |
|
f"the number of heads is {num_heads}." |
|
) |
|
else: |
|
key_size = embed_dim // num_heads |
|
|
|
|
|
self._pre_layer_norm = pre_layer_norm |
|
self._use_glu_in_fnn = use_glu_in_ffn |
|
|
|
if use_glu_in_ffn: |
|
|
|
|
|
|
|
|
|
self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn) |
|
else: |
|
self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn) |
|
|
|
self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn) |
|
|
|
        self.layer_norm_self_attention = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        if ffn_activation_name == "swish":
            self._ffn_activation_fn = nn.SiLU()
        elif ffn_activation_name == "gelu-no-approx":
            self._ffn_activation_fn = nn.GELU(approximate="none")
        else:
            self._ffn_activation_fn = getattr(nn.functional, ffn_activation_name)
|
|
|
self.mha = MultiHeadAttention( |
|
num_heads=num_heads, |
|
key_size=key_size, |
|
add_bias_kv=add_bias_kv, |
|
model_size=embed_dim, |
|
name="self_attention", |
|
rotary_embedding_config=rotary_embedding_config, |
|
) |
|
|
|
def mlp(self, embed: torch.Tensor) -> torch.Tensor: |
|
|
|
if self._pre_layer_norm: |
|
x = self.layer_norm_mlp(embed) |
|
else: |
|
x = embed |
|
|
|
if self._use_glu_in_fnn: |
|
x = self.fc1(x) |
|
x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1) |
|
x = self._ffn_activation_fn(x1) * x2 |
|
else: |
|
x = self._ffn_activation_fn(self.fc1(x)) |
|
x = self.fc2(x) |
|
|
|
if not self._pre_layer_norm: |
|
x = self.layer_norm_mlp(x + embed) |
|
return x |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
attention_weight_bias: Optional[torch.Tensor] = None, |
|
) -> dict[str, torch.Tensor]: |
|
|
|
res = x |
|
if self._pre_layer_norm: |
|
x = self.layer_norm_self_attention(x) |
|
|
|
output: dict[str, torch.Tensor] = self.mha( |
|
x, |
|
x, |
|
x, |
|
attention_mask=attention_mask, |
|
attention_weight_bias=attention_weight_bias, |
|
) |
|
|
|
if not self._pre_layer_norm: |
|
output["embeddings"] = self.layer_norm_self_attention( |
|
output["embeddings"] + res |
|
) |
|
|
|
x = output["embeddings"] |
|
else: |
|
x = output["embeddings"] |
|
x = res + x |
|
|
|
|
|
if not self._pre_layer_norm: |
|
x = self.mlp(x) |
|
else: |
|
x = x + self.mlp(x) |
|
|
|
output["embeddings"] = x |
|
return output |
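
# Illustrative sketch (not part of the model): a pre-layer-norm self-attention block
# with rotary embeddings, as used by the nucleotide transformer encoder below.
# All sizes are arbitrary example assumptions.
def _example_self_attention_block() -> dict:
    block = SelfAttentionBlock(
        num_heads=2,
        embed_dim=16,
        ffn_embed_dim=32,
        key_size=8,
        rotary_embedding_config=RotaryEmbeddingConfigBis(rescaling_factor=None),
    )
    return block(torch.randn(1, 10, 16))  # keys: "attention_weights", "embeddings"
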
|
|
|
|
|
class RobertaLMHead(nn.Module): |
|
""" |
|
Roberta Language Model head. Transforms final attention layer output into a |
|
distribution over tokens at each position. |
|
""" |
|
|
|
def __init__(self, embed_dim: int, alphabet_size: int): |
|
""" |
|
Args: |
|
embed_dim: Embedding dimension. |
|
alphabet_size: Number of tokens in the alphabet. |
|
""" |
|
super().__init__() |
|
self.embed_dim = embed_dim |
|
self.alphabet_size = alphabet_size |
|
|
|
|
|
self._first_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True) |
|
self._fc1 = nn.Linear(embed_dim, embed_dim) |
|
self._second_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True) |
|
self._final_fc = nn.Linear(embed_dim, alphabet_size) |
|
|
|
def forward(self, x: torch.Tensor) -> dict: |
|
x = self._first_layer_norm(x) |
|
embeddings = x |
|
x = self._fc1(x) |
|
x = nn.functional.gelu(x) |
|
x = self._second_layer_norm(x) |
|
logits = self._final_fc(x) |
|
return {"embeddings": embeddings, "logits": logits} |
|
|
|
|
|
class TorchNucleotideTransformer(nn.Module): |
|
def __init__( |
|
self, |
|
nt_config: NucleotideTransformerConfig, |
|
): |
|
super(TorchNucleotideTransformer, self).__init__() |
|
self.nt_config = nt_config |
|
|
|
|
|
assert nt_config.positional_embedding is None |
|
assert nt_config.lm_head == "roberta" |
|
assert nt_config.use_rotary_embedding is True |
|
assert nt_config.token_dropout is False |
|
assert nt_config.emb_layer_norm_before is False |
|
assert nt_config.mask_before_attention is False |
|
assert nt_config.bias_word_embedding is False |
|
assert nt_config.use_gradient_checkpointing is False |
|
|
|
self.embed_layer = nn.Embedding(nt_config.alphabet_size, nt_config.embed_dim) |
|
|
|
self.lm_head = RobertaLMHead( |
|
embed_dim=nt_config.embed_dim, |
|
alphabet_size=nt_config.alphabet_size, |
|
) |
|
|
|
self.rotary_embedding_config = RotaryEmbeddingConfigBis( |
|
rescaling_factor=nt_config.rescaling_factor |
|
) |
|
|
|
self.attention_blocks = nn.ModuleList( |
|
[ |
|
SelfAttentionBlock( |
|
num_heads=nt_config.attention_heads, |
|
embed_dim=nt_config.embed_dim, |
|
key_size=nt_config.key_size, |
|
ffn_embed_dim=nt_config.ffn_embed_dim, |
|
add_bias_kv=nt_config.add_bias_kv, |
|
add_bias_fnn=nt_config.add_bias_ffn, |
|
ffn_activation_name=nt_config.ffn_activation_name, |
|
use_glu_in_ffn=nt_config.use_glu_in_ffn, |
|
rotary_embedding_config=self.rotary_embedding_config, |
|
layer_norm_eps=nt_config.layer_norm_eps, |
|
pre_layer_norm=nt_config.pre_layer_norm, |
|
) |
|
for _ in range(nt_config.num_layers) |
|
] |
|
) |
|
|
|
def forward( |
|
        self, tokens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
) -> torch.Tensor: |
|
""" |
|
Computes the embeddings based on the input tokens. |
|
|
|
Args: |
|
tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len). |
|
            attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
                If no mask is provided, a default mask is computed that equals 1 over
                all non-pad tokens and 0 over pad tokens.

        Returns:
            The final embeddings after the Roberta LM head, of shape
            (batch_size, seq_len, embed_dim).
|
""" |
|
x = self.embed_layer(tokens) |
|
|
|
|
|
x = self.nt_config.embed_scale * x |
|
|
|
if attention_mask is None: |
|
attention_mask = build_padding_attention_mask( |
|
tokens=tokens, pad_token_id=self.nt_config.pad_token_id |
|
) |
|
|
|
for layer in self.attention_blocks: |
|
x = layer(x, attention_mask)["embeddings"] |
|
|
|
assert self.nt_config.lm_head == "roberta" |
|
x = self.lm_head(x)["embeddings"] |
|
|
|
return x |
|
|
|
|
|
def build_padding_attention_mask( |
|
tokens: torch.Tensor, pad_token_id: int |
|
) -> torch.Tensor: |
|
""" |
|
Builds a padding mask from a sequence of tokens by masking <pad> in the attention. |
|
|
|
Args: |
|
tokens: Batch of sequences of shape (batch_size, seq_len). |
|
pad_token_id: Int corresponding to the <pad> token to mask. |
|
|
|
Returns: |
|
Batch of attention masks, masking out <pad> tokens. |
|
""" |
|
padding_mask = tokens != pad_token_id |
|
padding_mask = padding_mask.unsqueeze(1) |
|
padding_mask = torch.einsum("bhT, bht -> bhtT", padding_mask, padding_mask) |
|
return padding_mask |
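
# Illustrative sketch (not part of the model): pad token ID 1 below is an example
# assumption; the resulting (1, 1, 5, 5) mask is True only where neither the query
# nor the key position is a pad token.
def _example_padding_mask() -> torch.Tensor:
    tokens = torch.tensor([[5, 6, 7, 1, 1]])
    return build_padding_attention_mask(tokens, pad_token_id=1)
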
|
|
|
|
|
class TorchBioBrainEncoder(nn.Module): |
|
def __init__( |
|
self, |
|
nt_config: NucleotideTransformerConfig, |
|
): |
|
super(TorchBioBrainEncoder, self).__init__() |
|
self.nt_config = nt_config |
|
self.nt_model = TorchNucleotideTransformer(self.nt_config) |
|
|
|
def forward( |
|
self, |
|
bio_token_ids: torch.Tensor, |
|
) -> torch.Tensor: |
|
""" |
|
Args: |
|
bio_token_ids (torch.Tensor): |
|
Shape (batch_size, num_bio_tokens) |
|
|
|
Returns: |
|
torch.Tensor: |
|
Shape (batch_size, num_bio_tokens, embed_dim) |
|
""" |
|
bio_embeddings = self.nt_model(tokens=bio_token_ids) |
|
|
|
return bio_embeddings |
|
|
|
|
|
class TorchMultiModalPerceiverResamplerBlock(nn.Module): |
|
def __init__( |
|
self, |
|
num_heads: int, |
|
embed_dim: int, |
|
ffn_embed_dim: int, |
|
key_size: Optional[int] = None, |
|
add_bias_kv: bool = False, |
|
add_bias_ffn: bool = True, |
|
ffn_activation_name: str = "gelu", |
|
use_glu_in_ffn: bool = False, |
|
): |
|
super().__init__() |
|
|
|
if key_size is None: |
|
if embed_dim % num_heads != 0: |
|
raise ValueError( |
|
f"Embedding dimension {embed_dim} should be divisible by " |
|
f"num_heads {num_heads}." |
|
) |
|
key_size = embed_dim // num_heads |
|
|
|
self.num_heads = num_heads |
|
self.embed_dim = embed_dim |
|
self.ffn_embed_dim = ffn_embed_dim * 2 if use_glu_in_ffn else ffn_embed_dim |
|
self.use_glu_in_ffn = use_glu_in_ffn |
|
|
|
self.cross_attention_1 = MultiHeadAttention( |
|
num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv |
|
) |
|
self.cross_attention_2 = MultiHeadAttention( |
|
num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv |
|
) |
|
|
|
self.norm_cross_attention_1 = nn.LayerNorm(embed_dim) |
|
self.norm_cross_attention_2 = nn.LayerNorm(embed_dim) |
|
self.norm_mlp = nn.LayerNorm(embed_dim) |
|
|
|
self.fc1 = nn.Linear(embed_dim, self.ffn_embed_dim, bias=add_bias_ffn) |
|
self.fc2 = nn.Linear(self.ffn_embed_dim, embed_dim, bias=add_bias_ffn) |
|
|
|
self.activation_fn = getattr( |
|
nn.functional, ffn_activation_name, nn.functional.gelu |
|
) |
|
|
|
def mlp(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.norm_mlp(x) |
|
if self.use_glu_in_ffn: |
|
x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1) |
|
x = self.activation_fn(x1) * x2 |
|
else: |
|
x = self.activation_fn(self.fc1(x)) |
|
return self.fc2(x) |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
cross_attention_embeddings_1: torch.Tensor, |
|
cross_attention_embeddings_2: torch.Tensor, |
|
attention_mask_1: Optional[torch.Tensor] = None, |
|
attention_mask_2: Optional[torch.Tensor] = None, |
|
) -> Dict[str, torch.Tensor]: |
|
res = x |
|
x = self.norm_cross_attention_1(x) |
|
|
|
attn_output = self.cross_attention_1( |
|
query=x, |
|
key=cross_attention_embeddings_1, |
|
value=cross_attention_embeddings_1, |
|
attention_mask=attention_mask_1, |
|
)["embeddings"] |
|
x = res + attn_output |
|
|
|
res = x |
|
x = self.norm_cross_attention_2(x) |
|
attn_output = self.cross_attention_2( |
|
query=x, |
|
key=cross_attention_embeddings_2, |
|
value=cross_attention_embeddings_2, |
|
attention_mask=attention_mask_2, |
|
)["embeddings"] |
|
x = res + attn_output |
|
|
|
x = x + self.mlp(x) |
|
|
|
return {"embeddings": x} |
|
|
|
|
|
class TorchMultiModalPerceiverResampler(nn.Module): |
|
""" |
|
Perceiver Resampler model, made of successive PerceiverResamplerBlocks. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
config: PerceiverResamplerConfig, |
|
name: Optional[str] = None, |
|
): |
|
""" |
|
Initialize a Perceiver Resampler model. |
|
|
|
Args: |
|
config: Dataclass containing model hyperparameters. |
|
name: Name for module (custom will break weight loading). |
|
""" |
|
super().__init__() |
|
self.config = config |
|
self.name = name |
|
self.layers = nn.ModuleList( |
|
[ |
|
TorchMultiModalPerceiverResamplerBlock( |
|
num_heads=self.config.attention_heads, |
|
embed_dim=self.config.embed_dim, |
|
key_size=self.config.key_size, |
|
ffn_embed_dim=self.config.ffn_embed_dim, |
|
add_bias_kv=self.config.add_bias_kv, |
|
add_bias_ffn=self.config.add_bias_ffn, |
|
ffn_activation_name=self.config.ffn_activation_name, |
|
use_glu_in_ffn=self.config.use_glu_in_ffn, |
|
) |
|
for _ in range(self.config.num_layers) |
|
] |
|
) |
|
|
|
self.latent_queries = torch.nn.Parameter( |
|
torch.randn(self.config.resampled_length, self.config.embed_dim) |
|
* ( |
|
1.0 |
|
/ torch.sqrt(torch.tensor(self.config.embed_dim, dtype=torch.float32)) |
|
) |
|
) |
|
|
|
def apply_attention_blocks( |
|
self, |
|
x: torch.Tensor, |
|
xf_1: torch.Tensor, |
|
xf_2: torch.Tensor, |
|
outs: Dict[str, torch.Tensor], |
|
attention_mask_1: Optional[torch.Tensor] = None, |
|
attention_mask_2: Optional[torch.Tensor] = None, |
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: |
|
""" |
|
        Applies the successive attention blocks to the latent queries.
|
""" |
|
for layer in self.layers: |
|
concat_input_1 = torch.cat([xf_1, x], dim=1) |
|
concat_input_2 = torch.cat([xf_2, x], dim=1) |
|
|
|
output = layer( |
|
x=x, |
|
cross_attention_embeddings_1=concat_input_1, |
|
cross_attention_embeddings_2=concat_input_2, |
|
attention_mask_1=attention_mask_1, |
|
attention_mask_2=attention_mask_2, |
|
) |
|
x = output["embeddings"] |
|
|
|
return x, outs |
|
|
|
def forward( |
|
self, |
|
input_embeddings_1: torch.Tensor, |
|
input_embeddings_2: torch.Tensor, |
|
attention_mask_1: Optional[torch.Tensor] = None, |
|
attention_mask_2: Optional[torch.Tensor] = None, |
|
) -> Dict[str, torch.Tensor]: |
|
""" |
|
Computes the embeddings based on the input tokens. |
|
""" |
|
assert ( |
|
input_embeddings_1.shape[-1] == self.config.embed_dim |
|
), "The input embedding dim should match the model embed dim" |
|
assert ( |
|
input_embeddings_2.shape[-1] == self.config.embed_dim |
|
), "The input embedding dim should match the model embed dim" |
|
|
|
batch_size = input_embeddings_1.shape[0] |
|
|
|
latent_queries = self.latent_queries.unsqueeze(0).repeat(batch_size, 1, 1) |
|
|
|
outs: Dict[str, torch.Tensor] = {} |
|
x = latent_queries |
|
|
|
x, outs = self.apply_attention_blocks( |
|
x=x, |
|
xf_1=input_embeddings_1, |
|
xf_2=input_embeddings_2, |
|
outs=outs, |
|
attention_mask_1=attention_mask_1, |
|
attention_mask_2=attention_mask_2, |
|
) |
|
|
|
outs["embeddings"] = x |
|
|
|
return outs |
|
|
|
|
|
class TorchMultiModalPerceiverResamplerProjection(nn.Module): |
|
def __init__( |
|
self, |
|
perceiver_resampler_config: PerceiverResamplerConfig, |
|
input_embed_dim: int, |
|
embed_dim: int, |
|
bio_pad_token_id: int, |
|
english_pad_token_id: int, |
|
english_vocab_size: int, |
|
): |
|
super().__init__() |
|
self.config = perceiver_resampler_config |
|
self.input_embed_dim = input_embed_dim |
|
self.embed_dim = embed_dim |
|
self.bio_pad_token_id = bio_pad_token_id |
|
self.english_pad_token_id = english_pad_token_id |
|
self.english_vocab_size = english_vocab_size |
|
|
|
self.bio_projection = nn.Linear(input_embed_dim, embed_dim) |
|
self.token_embedding = nn.Embedding(english_vocab_size, embed_dim) |
|
self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config) |
|
|
|
def forward( |
|
self, |
|
bio_token_ids: torch.Tensor, |
|
bio_embeddings: torch.Tensor, |
|
english_token_ids: torch.Tensor, |
|
) -> torch.Tensor: |
|
""" |
|
Args: |
|
bio_token_ids (torch.Tensor): |
|
Shape (batch_size, num_bio_tokens) |
|
|
|
bio_embeddings (torch.Tensor): |
|
Shape (batch_size, num_bio_tokens, embed_dim) |
|
|
|
english_token_ids (torch.Tensor): |
|
Shape (batch_size, num_english_tokens) |
|
""" |
|
projected_bio_embeddings = self.bio_projection(bio_embeddings) |
|
english_embeddings = self.token_embedding(english_token_ids) |
|
|
|
bio_attention_mask = build_perceiver_padding_attention_mask( |
|
bio_token_ids, self.config.resampled_length, self.bio_pad_token_id |
|
) |
|
english_attention_mask = build_perceiver_padding_attention_mask( |
|
english_token_ids, self.config.resampled_length, self.english_pad_token_id |
|
) |
|
|
|
projected_embeddings = self.perceiver_resampler( |
|
input_embeddings_1=projected_bio_embeddings, |
|
attention_mask_1=bio_attention_mask, |
|
input_embeddings_2=english_embeddings, |
|
attention_mask_2=english_attention_mask, |
|
)["embeddings"] |
|
|
|
return projected_embeddings |
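
# Illustrative sketch (not part of the model): the projection maps encoder
# embeddings of width input_embed_dim down to a fixed resampled_length tokens of
# width embed_dim. All sizes, vocabulary sizes and pad token IDs below are
# arbitrary example assumptions.
def _example_resampler_projection() -> torch.Tensor:
    config = PerceiverResamplerConfig(
        embed_dim=16,
        ffn_embed_dim=32,
        attention_heads=2,
        num_layers=1,
        resampled_length=8,
    )
    projection = TorchMultiModalPerceiverResamplerProjection(
        perceiver_resampler_config=config,
        input_embed_dim=32,
        embed_dim=16,
        bio_pad_token_id=1,
        english_pad_token_id=2,
        english_vocab_size=128,
    )
    bio_token_ids = torch.randint(3, 64, (1, 20))
    bio_embeddings = torch.randn(1, 20, 32)
    english_token_ids = torch.randint(3, 128, (1, 12))
    # Output shape: (1, 8, 16) = (batch, resampled_length, embed_dim).
    return projection(bio_token_ids, bio_embeddings, english_token_ids)
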
|
|
|
|
|
def build_perceiver_padding_attention_mask( |
|
tokens: torch.Tensor, resampled_length: int, pad_token_id: int |
|
) -> torch.Tensor: |
|
batch_size, seq_len = tokens.shape |
|
padding_mask = tokens != pad_token_id |
|
|
|
padding_mask = torch.cat( |
|
[ |
|
padding_mask, |
|
torch.ones( |
|
(batch_size, resampled_length), dtype=torch.bool, device=tokens.device |
|
), |
|
], |
|
dim=1, |
|
) |
|
|
|
padding_mask = padding_mask[:, None, None, :] |
|
padding_mask = padding_mask.repeat(1, 1, resampled_length, 1) |
|
return padding_mask |
|
|