from dataclasses import dataclass
import torch.nn as nn
from transformers.configuration_utils import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig


@dataclass
class TransformerHATModelConfig(LlamaConfig):
    """Configuration for a single HAT transformer stack.

    A thin wrapper around ``LlamaConfig`` that additionally stores an optional
    ``sliding_window`` and serializes only the fields the HAT model uses.
    """

    def __init__(
        self,
        hidden_size: int,
        num_hidden_layers: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        intermediate_size: int,
        max_position_embeddings: int,
        rope_scaling: dict,
        rope_theta: float,
        mlp_bias: bool,
        use_cache: bool = True,
        sliding_window: int | None = None,
        vocab_size: int = 0,
        hidden_act: str = "silu",
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            rms_norm_eps=rms_norm_eps,
            intermediate_size=intermediate_size,
            max_position_embeddings=max_position_embeddings,
            rope_scaling=rope_scaling,
            rope_theta=rope_theta,
            mlp_bias=mlp_bias,
            use_cache=use_cache,
            **kwargs,
        )
        self.sliding_window = sliding_window

    def to_dict(self):
        config_dict = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "num_hidden_layers": self.num_hidden_layers,
            "num_attention_heads": self.num_attention_heads,
            "num_key_value_heads": self.num_key_value_heads,
            "rms_norm_eps": self.rms_norm_eps,
            "intermediate_size": self.intermediate_size,
            "max_position_embeddings": self.max_position_embeddings,
            "rope_scaling": self.rope_scaling,
            "rope_theta": self.rope_theta,
            "mlp_bias": self.mlp_bias,
            "use_cache": self.use_cache,
            "sliding_window": self.sliding_window,
            "transformers_version": self.transformers_version,
        }
        return config_dict


@dataclass
class CrossAttentionConfig:
    """Configuration for the cross-attention modules used by the HAT encoder and decoder."""

    def __init__(
        self,
        hidden_size: int,
        hidden_size_q: int,
        hidden_size_kv: int,
        num_attention_heads: int,
        attention_num_kv_heads: int,
        word_window_size: int,
    ):
        self.hidden_size = hidden_size
        self.hidden_size_q = hidden_size_q
        self.hidden_size_kv = hidden_size_kv
        self.num_attention_heads = num_attention_heads
        self.attention_num_kv_heads = attention_num_kv_heads
        self.word_window_size = word_window_size

    def to_dict(self):
        return {
            "hidden_size_q": self.hidden_size_q,
            "hidden_size_kv": self.hidden_size_kv,
            "hidden_size": self.hidden_size,
            "num_attention_heads": self.num_attention_heads,
            "attention_num_kv_heads": self.attention_num_kv_heads,
            "word_window_size": self.word_window_size,
        }
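

# Illustrative sketch (not part of the module): building a CrossAttentionConfig by
# hand. All sizes below are made-up placeholders, not values from any shipped HAT
# checkpoint; field meanings follow the parameter names above.
#
#   cross_attn = CrossAttentionConfig(
#       hidden_size=512,
#       hidden_size_q=512,
#       hidden_size_kv=1024,
#       num_attention_heads=8,
#       attention_num_kv_heads=2,
#       word_window_size=1,
#   )
#   assert cross_attn.to_dict()["hidden_size_kv"] == 1024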


@dataclass
class DecoderHATModelConfig(TransformerHATModelConfig):
    """Transformer config for the HAT decoder: adds cross-attention settings and a
    flag that controls whether cross-attention is applied in every layer."""

    def __init__(
        self,
        num_attention_heads: int,
        num_key_value_heads: int,
        sliding_window: int,
        cross_attention_config: CrossAttentionConfig,
        cross_attn_every_layer: bool,
        **kwargs,
    ):
        super().__init__(
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            sliding_window=sliding_window,
            **kwargs,
        )
        self.cross_attn_every_layer = cross_attn_every_layer
        self.cross_attention_config = cross_attention_config

    def to_dict(self):
        config_dict = super().to_dict()
        config_dict["cross_attn_every_layer"] = self.cross_attn_every_layer
        config_dict["cross_attention_config"] = self.cross_attention_config.to_dict()
        return config_dict

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        config_dict = config_dict.copy()  # Avoid modifying the original dict
        config_dict.update(kwargs)  # Apply overrides
        # Rebuild the nested CrossAttentionConfig from its serialized dict
        dict_config = config_dict.pop("cross_attention_config", {})
        cross_attention_config = CrossAttentionConfig(**dict_config)
        config_dict["cross_attention_config"] = cross_attention_config
        return cls(**config_dict)
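

# Illustrative sketch (not part of the module): DecoderHATModelConfig.from_dict expects
# the nested "cross_attention_config" entry as a plain dict and rebuilds the
# CrossAttentionConfig object itself. All numbers below are placeholders.
#
#   decoder_cfg = DecoderHATModelConfig.from_dict({
#       "hidden_size": 512,
#       "num_hidden_layers": 4,
#       "num_attention_heads": 8,
#       "num_key_value_heads": 2,
#       "rms_norm_eps": 1e-5,
#       "intermediate_size": 2048,
#       "max_position_embeddings": 8192,
#       "rope_scaling": None,
#       "rope_theta": 10000.0,
#       "mlp_bias": False,
#       "sliding_window": 768,
#       "cross_attn_every_layer": True,
#       "cross_attention_config": {
#           "hidden_size": 512,
#           "hidden_size_q": 512,
#           "hidden_size_kv": 1024,
#           "num_attention_heads": 8,
#           "attention_num_kv_heads": 2,
#           "word_window_size": 1,
#       },
#   })
#   assert decoder_cfg.cross_attention_config.word_window_size == 1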


@dataclass
class EncoderHATModelConfig(TransformerHATModelConfig):
    """Transformer config for the HAT encoder: adds cross-attention settings."""

    def __init__(
        self,
        cross_attention_config: CrossAttentionConfig,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.cross_attention_config = cross_attention_config

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        config_dict = config_dict.copy()  # Avoid modifying the original dict
        config_dict.update(kwargs)  # Apply overrides
        # Rebuild the nested CrossAttentionConfig from its serialized dict
        dict_config = config_dict.pop("cross_attention_config", {})
        cross_attention_config = CrossAttentionConfig(**dict_config)
        config_dict["cross_attention_config"] = cross_attention_config
        return cls(**config_dict)

    def to_dict(self):
        config_dict = super().to_dict()
        if self.cross_attention_config:
            config_dict["cross_attention_config"] = self.cross_attention_config.to_dict()
        return config_dict


@dataclass
class HATArchitectureConfig(PretrainedConfig):
    """Top-level configuration that ties together the encoder, backbone and decoder
    configs of the hierarchical autoregressive transformer (HAT)."""

    model_type: str

    def __init__(
        self,
        special_token_dict: dict | None = None,
        encoder_config: EncoderHATModelConfig | None = None,
        backbone_config: TransformerHATModelConfig | None = None,
        decoder_config: DecoderHATModelConfig | None = None,
        model_type: str = "hierarchical_autoregressive_transformer",
        eos_token_id: int = 192,
        max_word_size: int = 100,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder_config = encoder_config
        self.backbone_config = backbone_config
        self.decoder_config = decoder_config
        self.model_type = model_type
        self.eos_token_id = eos_token_id
        self.max_word_size = max_word_size
        self.special_token_dict = special_token_dict
        self.transformers_version = "4.46.3"

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """
        Instantiates a HATArchitectureConfig from a Python dictionary of parameters.

        Overrides the base `from_dict` to correctly handle nested config objects.
        Returns a `(config, unused_kwargs)` tuple; the second element is always an
        empty dict here.
        """
        config_dict = config_dict.copy()  # Avoid modifying the original dict
        config_dict.update(kwargs)  # Apply overrides
        # Pop and instantiate nested config dictionaries
        encoder_dict = config_dict.pop("encoder_config", {})
        backbone_dict = config_dict.pop("backbone_config", {})
        decoder_dict = config_dict.pop("decoder_config", {})
        # Instantiate nested configs
        encoder_config = EncoderHATModelConfig.from_dict(encoder_dict) if encoder_dict else None
        backbone_config = TransformerHATModelConfig.from_dict(backbone_dict) if backbone_dict else None
        decoder_config = DecoderHATModelConfig.from_dict(decoder_dict) if decoder_dict else None
        special_token_dict = config_dict.pop("special_token_dict", {"<|eot_id|>": 192})
        max_word_size = config_dict.pop("max_word_size", 100)
        return cls(
            encoder_config=encoder_config,
            backbone_config=backbone_config,
            decoder_config=decoder_config,
            special_token_dict=special_token_dict,
            max_word_size=max_word_size,
            **config_dict,
        ), {}

    def to_dict(self):
        config_dict = {}
        if self.encoder_config:
            config_dict["encoder_config"] = self.encoder_config.to_dict()
        if self.backbone_config:
            config_dict["backbone_config"] = self.backbone_config.to_dict()
        if self.decoder_config:
            config_dict["decoder_config"] = self.decoder_config.to_dict()
        config_dict["model_type"] = self.model_type
        config_dict["transformers_version"] = self.transformers_version
        config_dict["auto_map"] = {
            "AutoConfig": "config.HATArchitectureConfig",
            "AutoModelForCausalLM": "model.HATForCausalLM",
        }
        config_dict["special_token_dict"] = self.special_token_dict
        return config_dict
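

# Illustrative sketch (not part of the module): a HATArchitectureConfig can be
# round-tripped through to_dict()/from_dict(); note that from_dict returns a
# (config, unused_kwargs) tuple. All sizes are placeholders, not the values of the
# uploaded checkpoint.
#
#   stack_kwargs = dict(
#       hidden_size=512, num_hidden_layers=4, num_attention_heads=8,
#       num_key_value_heads=2, rms_norm_eps=1e-5, intermediate_size=2048,
#       max_position_embeddings=8192, rope_scaling=None, rope_theta=10000.0,
#       mlp_bias=False,
#   )
#   cross_attn = CrossAttentionConfig(
#       hidden_size=512, hidden_size_q=512, hidden_size_kv=1024,
#       num_attention_heads=8, attention_num_kv_heads=2, word_window_size=1,
#   )
#   arch = HATArchitectureConfig(
#       encoder_config=EncoderHATModelConfig(cross_attention_config=cross_attn, **stack_kwargs),
#       backbone_config=TransformerHATModelConfig(**stack_kwargs),
#       decoder_config=DecoderHATModelConfig(
#           cross_attention_config=cross_attn,
#           cross_attn_every_layer=True,
#           sliding_window=768,
#           **stack_kwargs,
#       ),
#   )
#   restored, _ = HATArchitectureConfig.from_dict(arch.to_dict())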


class EncoderHATModel(nn.Module):
    def __init__(self, config: HATArchitectureConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config