from dataclasses import dataclass

import torch.nn as nn
from transformers.configuration_utils import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig


@dataclass
class TransformerHATModelConfig(LlamaConfig):
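    """Base configuration for the transformer stacks of a HAT (hierarchical
    autoregressive transformer) model.

    A thin wrapper around `LlamaConfig` that additionally stores an optional
    `sliding_window` size and serializes only the fields listed in `to_dict`.
    """
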
    def __init__(
        self,
        hidden_size: int,
        num_hidden_layers: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        intermediate_size: int,
        max_position_embeddings: int,
        rope_scaling: dict,
        rope_theta: float,
        mlp_bias: bool,
        use_cache: bool = True,
        sliding_window: int | None = None,
        vocab_size: int = 0,
        hidden_act: str = "silu",
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            rms_norm_eps=rms_norm_eps,
            intermediate_size=intermediate_size,
            max_position_embeddings=max_position_embeddings,
            rope_scaling=rope_scaling,
            rope_theta=rope_theta,
            mlp_bias=mlp_bias,
            use_cache=use_cache,
            **kwargs,
        )

        self.sliding_window = sliding_window

    def to_dict(self):
        config_dict = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "num_hidden_layers": self.num_hidden_layers,
            "num_attention_heads": self.num_attention_heads,
            "num_key_value_heads": self.num_key_value_heads,
            "rms_norm_eps": self.rms_norm_eps,
            "intermediate_size": self.intermediate_size,
            "max_position_embeddings": self.max_position_embeddings,
            "rope_scaling": self.rope_scaling,
            "rope_theta": self.rope_theta,
            "mlp_bias": self.mlp_bias,
            "use_cache": self.use_cache,
            "sliding_window": self.sliding_window,
            "transformers_version": self.transformers_version,
        }
        return config_dict


@dataclass
class CrossAttentionConfig:
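    """Hyperparameters of the cross-attention blocks used by the HAT encoder and
    decoder: query and key/value hidden sizes, attention head counts, and the
    word window size."""
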
    def __init__(
        self,
        hidden_size: int,
        hidden_size_q: int,
        hidden_size_kv: int,
        num_attention_heads: int,
        attention_num_kv_heads: int,
        word_window_size: int,
    ):
        self.hidden_size = hidden_size
        self.hidden_size_q = hidden_size_q
        self.hidden_size_kv = hidden_size_kv
        self.num_attention_heads = num_attention_heads
        self.attention_num_kv_heads = attention_num_kv_heads
        self.word_window_size = word_window_size

    def to_dict(self):
        return {
            "hidden_size_q": self.hidden_size_q,
            "hidden_size_kv": self.hidden_size_kv,
            "hidden_size": self.hidden_size,
            "num_attention_heads": self.num_attention_heads,
            "attention_num_kv_heads": self.attention_num_kv_heads,
            "word_window_size": self.word_window_size,
        }


@dataclass
class DecoderHATModelConfig(TransformerHATModelConfig):
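    """Configuration of the HAT decoder: a `TransformerHATModelConfig` extended with
    a `CrossAttentionConfig` and a flag that controls whether cross-attention is
    applied in every decoder layer."""
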
    def __init__(
        self,
        num_attention_heads: int,
        num_key_value_heads: int,
        sliding_window: int,
        cross_attention_config: CrossAttentionConfig,
        cross_attn_every_layer: bool,
        **kwargs,
    ):
        super().__init__(
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            sliding_window=sliding_window,
            **kwargs,
        )
        self.cross_attn_every_layer = cross_attn_every_layer
        self.cross_attention_config = cross_attention_config

    def to_dict(self):
        config_dict = super().to_dict()
        config_dict["cross_attn_every_layer"] = self.cross_attn_every_layer
        config_dict["cross_attention_config"] = self.cross_attention_config.to_dict()
        return config_dict

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        config_dict = config_dict.copy()  # Avoid modifying the original dict
        config_dict.update(kwargs)  # Apply overrides
        dict_config = config_dict.pop("cross_attention_config", {})
        cross_attention_config = CrossAttentionConfig(**dict_config)
        config_dict["cross_attention_config"] = cross_attention_config
        return cls(**config_dict)


@dataclass
class EncoderHATModelConfig(TransformerHATModelConfig):
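    """Configuration of the HAT encoder: a `TransformerHATModelConfig` extended with
    a `CrossAttentionConfig`."""
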
    def __init__(
        self,
        cross_attention_config: CrossAttentionConfig,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.cross_attention_config = cross_attention_config

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        config_dict = config_dict.copy()  # Avoid modifying the original dict
        config_dict.update(kwargs)  # Apply overrides
        dict_config = config_dict.pop("cross_attention_config", {})
        cross_attention_config = CrossAttentionConfig(**dict_config)
        config_dict["cross_attention_config"] = cross_attention_config

        return cls(**config_dict)

    def to_dict(self):
        config_dict = super().to_dict()
        if self.cross_attention_config:
            config_dict["cross_attention_config"] = self.cross_attention_config.to_dict()
        return config_dict


@dataclass
class HATArchitectureConfig(PretrainedConfig):
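    """Top-level configuration of a hierarchical autoregressive transformer (HAT):
    bundles the encoder, backbone, and decoder configs together with the special
    token mapping, the EOS token id, and the maximum word size."""
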
    model_type: str

    def __init__(
        self,
        special_token_dict: dict | None = None,
        encoder_config: EncoderHATModelConfig | None = None,
        backbone_config: TransformerHATModelConfig | None = None,
        decoder_config: DecoderHATModelConfig | None = None,
        model_type: str = "hierarchical_autoregressive_transformer",
        eos_token_id: int = 192,
        max_word_size: int = 100,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder_config = encoder_config
        self.backbone_config = backbone_config
        self.decoder_config = decoder_config
        self.model_type = model_type
        self.eos_token_id = eos_token_id
        self.max_word_size = max_word_size
        self.special_token_dict = special_token_dict
        self.transformers_version = "4.46.3"

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """
        Instantiates a HATArchitectureConfig from a Python dictionary of parameters.

        Overrides the base `from_dict` to correctly handle nested config objects.
        """
        config_dict = config_dict.copy()  # Avoid modifying the original dict
        config_dict.update(kwargs)  # Apply overrides

        # Pop and instantiate nested config dictionaries
        encoder_dict = config_dict.pop("encoder_config", {})
        backbone_dict = config_dict.pop("backbone_config", {})
        decoder_dict = config_dict.pop("decoder_config", {})

        # Instantiate nested configs
        encoder_config = EncoderHATModelConfig.from_dict(encoder_dict) if encoder_dict else None
        backbone_config = TransformerHATModelConfig.from_dict(backbone_dict) if backbone_dict else None
        decoder_config = DecoderHATModelConfig.from_dict(decoder_dict) if decoder_dict else None
        special_token_dict = config_dict.pop("special_token_dict", {"<|eot_id|>": 192})
        max_word_size = config_dict.pop("max_word_size", 100)
        return cls(
            encoder_config=encoder_config,
            backbone_config=backbone_config,
            decoder_config=decoder_config,
            special_token_dict=special_token_dict,
            max_word_size=max_word_size,
            **config_dict,
        ), {}

    def to_dict(self):
        config_dict = {}
        if self.encoder_config:
            config_dict["encoder_config"] = self.encoder_config.to_dict()
        if self.backbone_config:
            config_dict["backbone_config"] = self.backbone_config.to_dict()
        if self.decoder_config:
            config_dict["decoder_config"] = self.decoder_config.to_dict()
        config_dict["model_type"] = self.model_type
        config_dict["transformers_version"] = self.transformers_version
        config_dict["auto_map"] = {"AutoConfig": "config.HATArchitectureConfig", "AutoModelForCausalLM": "model.HATForCausalLM"}
        config_dict["special_token_dict"] = self.special_token_dict
        return config_dict

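# Illustrative round-trip sketch (not part of this module's API; `existing_cfg` stands
# for a hypothetical, already-constructed HATArchitectureConfig):
#
#     rebuilt_cfg, _ = HATArchitectureConfig.from_dict(existing_cfg.to_dict())
#     assert rebuilt_cfg.model_type == existing_cfg.model_type
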

class EncoderHATModel(nn.Module):
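    """Encoder module of the HAT model, parameterized by a `HATArchitectureConfig`."""
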
    def __init__(self, config: HATArchitectureConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config