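"""Configuration classes for a hierarchical autoregressive transformer (HAT).

Defines Llama-based configs for the encoder, backbone, and decoder stacks, a
small config for the cross-attention blocks that connect them, and the
top-level ``HATArchitectureConfig`` that nests all of them and is exposed to
``AutoConfig``/``AutoModelForCausalLM`` through ``auto_map``.
"""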
from dataclasses import dataclass

import torch.nn as nn
from transformers.configuration_utils import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig


@dataclass
class TransformerHATModelConfig(LlamaConfig):
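    """Llama-style configuration shared by the HAT encoder, backbone, and decoder stacks.

    A thin wrapper around ``LlamaConfig`` that additionally records an optional
    ``sliding_window`` and restricts ``to_dict`` to the fields used here.
    """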
    def __init__(
        self,
        hidden_size: int,
        num_hidden_layers: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        intermediate_size: int,
        max_position_embeddings: int,
        rope_scaling: dict,
        rope_theta: float,
        mlp_bias: bool,
        use_cache: bool = True,
        sliding_window: int | None = None,
        vocab_size: int = 0,
        hidden_act: str = "silu",
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            rms_norm_eps=rms_norm_eps,
            intermediate_size=intermediate_size,
            max_position_embeddings=max_position_embeddings,
            rope_scaling=rope_scaling,
            rope_theta=rope_theta,
            mlp_bias=mlp_bias,
            use_cache=use_cache,
            **kwargs,
        )
        self.sliding_window = sliding_window

    def to_dict(self):
        config_dict = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "num_hidden_layers": self.num_hidden_layers,
            "num_attention_heads": self.num_attention_heads,
            "num_key_value_heads": self.num_key_value_heads,
            "rms_norm_eps": self.rms_norm_eps,
            "intermediate_size": self.intermediate_size,
            "max_position_embeddings": self.max_position_embeddings,
            "rope_scaling": self.rope_scaling,
            "rope_theta": self.rope_theta,
            "mlp_bias": self.mlp_bias,
            "use_cache": self.use_cache,
            "sliding_window": self.sliding_window,
            "transformers_version": self.transformers_version,
        }
        return config_dict


@dataclass
class CrossAttentionConfig:
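    """Hyperparameters of a cross-attention block.

    Stores separate hidden sizes for the query side (``hidden_size_q``) and the
    key/value side (``hidden_size_kv``), the overall ``hidden_size``, the query
    and key/value head counts, and the ``word_window_size`` of the block.
    """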
    def __init__(
        self,
        hidden_size: int,
        hidden_size_q: int,
        hidden_size_kv: int,
        num_attention_heads: int,
        attention_num_kv_heads: int,
        word_window_size: int,
    ):
        self.hidden_size = hidden_size
        self.hidden_size_q = hidden_size_q
        self.hidden_size_kv = hidden_size_kv
        self.num_attention_heads = num_attention_heads
        self.attention_num_kv_heads = attention_num_kv_heads
        self.word_window_size = word_window_size

    def to_dict(self):
        return {
            "hidden_size_q": self.hidden_size_q,
            "hidden_size_kv": self.hidden_size_kv,
            "hidden_size": self.hidden_size,
            "num_attention_heads": self.num_attention_heads,
            "attention_num_kv_heads": self.attention_num_kv_heads,
            "word_window_size": self.word_window_size,
        }


@dataclass
class DecoderHATModelConfig(TransformerHATModelConfig):
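    """Configuration of the HAT decoder stack.

    Extends :class:`TransformerHATModelConfig` with a nested
    :class:`CrossAttentionConfig` and the ``cross_attn_every_layer`` flag, which
    controls whether cross-attention is applied in every decoder layer.
    """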
    def __init__(
        self,
        num_attention_heads: int,
        num_key_value_heads: int,
        sliding_window: int,
        cross_attention_config: CrossAttentionConfig,
        cross_attn_every_layer: bool,
        **kwargs,
    ):
        super().__init__(
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            sliding_window=sliding_window,
            **kwargs,
        )
        self.cross_attn_every_layer = cross_attn_every_layer
        self.cross_attention_config = cross_attention_config

    def to_dict(self):
        config_dict = super().to_dict()
        config_dict["cross_attn_every_layer"] = self.cross_attn_every_layer
        config_dict["cross_attention_config"] = self.cross_attention_config.to_dict()
        return config_dict

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        config_dict = config_dict.copy()
        config_dict.update(kwargs)
        # Rebuild the nested CrossAttentionConfig from its serialized dict.
        dict_config = config_dict.pop("cross_attention_config", {})
        cross_attention_config = CrossAttentionConfig(**dict_config)
        config_dict["cross_attention_config"] = cross_attention_config
        return cls(**config_dict)


@dataclass
class EncoderHATModelConfig(TransformerHATModelConfig):
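    """Configuration of the HAT encoder stack.

    Extends :class:`TransformerHATModelConfig` with the nested
    :class:`CrossAttentionConfig` of the encoder's cross-attention block.
    """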
    def __init__(
        self,
        cross_attention_config: CrossAttentionConfig,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.cross_attention_config = cross_attention_config

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        config_dict = config_dict.copy()
        config_dict.update(kwargs)
        dict_config = config_dict.pop("cross_attention_config", {})
        cross_attention_config = CrossAttentionConfig(**dict_config)
        config_dict["cross_attention_config"] = cross_attention_config
        return cls(**config_dict)

    def to_dict(self):
        config_dict = super().to_dict()
        if self.cross_attention_config:
            config_dict["cross_attention_config"] = self.cross_attention_config.to_dict()
        return config_dict


@dataclass
class HATArchitectureConfig(PretrainedConfig):
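    """Top-level configuration tying the HAT sub-configs together.

    Nests an :class:`EncoderHATModelConfig`, a backbone
    :class:`TransformerHATModelConfig`, and a :class:`DecoderHATModelConfig`,
    along with tokenization-related settings (``special_token_dict``,
    ``eos_token_id``, ``max_word_size``). ``from_dict`` and ``to_dict`` are
    overridden so that the nested configs survive (de)serialization, and
    ``to_dict`` emits the ``auto_map`` entries used for remote-code loading.

    Example (a minimal sketch with illustrative hyperparameters, not values
    from any released checkpoint):

        >>> backbone = TransformerHATModelConfig(
        ...     hidden_size=1024, num_hidden_layers=8, num_attention_heads=16,
        ...     num_key_value_heads=4, rms_norm_eps=1e-5, intermediate_size=4096,
        ...     max_position_embeddings=8192, rope_scaling=None,
        ...     rope_theta=500000.0, mlp_bias=False,
        ... )
        >>> config = HATArchitectureConfig(backbone_config=backbone)
        >>> restored, _ = HATArchitectureConfig.from_dict(config.to_dict())
    """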
    model_type: str

    def __init__(
        self,
        special_token_dict: dict | None = None,
        encoder_config: EncoderHATModelConfig | None = None,
        backbone_config: TransformerHATModelConfig | None = None,
        decoder_config: DecoderHATModelConfig | None = None,
        model_type: str = "hierarchical_autoregressive_transformer",
        eos_token_id: int = 192,
        max_word_size: int = 100,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder_config = encoder_config
        self.backbone_config = backbone_config
        self.decoder_config = decoder_config
        self.model_type = model_type
        self.eos_token_id = eos_token_id
        self.max_word_size = max_word_size
        self.special_token_dict = special_token_dict
        self.transformers_version = "4.46.3"

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """
        Instantiates a HATArchitectureConfig from a Python dictionary of parameters.

        Overrides the base `from_dict` to correctly handle the nested config objects.
        """
        config_dict = config_dict.copy()
        config_dict.update(kwargs)

        # Rebuild each nested sub-config from its serialized dict, if present.
        encoder_dict = config_dict.pop("encoder_config", {})
        backbone_dict = config_dict.pop("backbone_config", {})
        decoder_dict = config_dict.pop("decoder_config", {})

        encoder_config = EncoderHATModelConfig.from_dict(encoder_dict) if encoder_dict else None
        backbone_config = TransformerHATModelConfig.from_dict(backbone_dict) if backbone_dict else None
        decoder_config = DecoderHATModelConfig.from_dict(decoder_dict) if decoder_dict else None
        special_token_dict = config_dict.pop("special_token_dict", {"<|eot_id|>": 192})
        max_word_size = config_dict.pop("max_word_size", 100)
        # The second element of the returned tuple is an empty dict of unused kwargs.
        return cls(
            encoder_config=encoder_config,
            backbone_config=backbone_config,
            decoder_config=decoder_config,
            special_token_dict=special_token_dict,
            max_word_size=max_word_size,
            **config_dict,
        ), {}

    def to_dict(self):
        config_dict = {}
        if self.encoder_config:
            config_dict["encoder_config"] = self.encoder_config.to_dict()
        if self.backbone_config:
            config_dict["backbone_config"] = self.backbone_config.to_dict()
        if self.decoder_config:
            config_dict["decoder_config"] = self.decoder_config.to_dict()
        config_dict["model_type"] = self.model_type
        config_dict["transformers_version"] = self.transformers_version
        config_dict["auto_map"] = {
            "AutoConfig": "config.HATArchitectureConfig",
            "AutoModelForCausalLM": "model.HATForCausalLM",
        }
        config_dict["special_token_dict"] = self.special_token_dict
        return config_dict


class EncoderHATModel(nn.Module):
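    """Encoder sub-model of the HAT architecture, parameterized by a :class:`HATArchitectureConfig`."""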
    def __init__(self, config: HATArchitectureConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config