lambda-technologies-limited committed on
Commit
4d4f910
·
verified ·
1 Parent(s): a2a80f4

Uploading optimized model files

Browse files
Files changed (1) hide show
  1. block_config.py +118 -0
block_config.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import json
3
+ import warnings
4
+ from dataclasses import dataclass, MISSING
5
+ from functools import partial
6
+ from typing import Optional, Any
7
+
8
+
9
+ @partial(dataclass, frozen=True, kw_only=True)
10
+ class JsonComparable:
11
+ def to_json(self) -> str:
12
+ return json.dumps(dataclasses.asdict(self))
13
+
14
+ def __eq__(self, other: "JsonComparable") -> bool:
15
+ return self.to_json() == other.to_json()
16
+
17
+ def __hash__(self) -> int:
18
+ return hash(self.to_json())
19
+
20
+ def __lt__(self, other: "JsonComparable") -> bool:
21
+ return self.to_json() < other.to_json()
22
+
23
+
24
@partial(dataclass, frozen=True, kw_only=True, eq=False)
class SubblockConfig(JsonComparable):
    """Common options shared by the sub-block configs of a block.

    ``no_op`` and ``replace_with_linear`` are mutually exclusive; this is
    checked on construction.

    ``eq=False`` is deliberate: with the default ``eq=True`` the dataclass
    decorator would generate field-based ``__eq__``/``__hash__`` on this
    class, replacing the JSON-based ones inherited from ``JsonComparable``
    (and the generated hash fails when ``sparsify`` holds a list).
    """

    no_op: bool = False
    replace_with_linear: bool = False
    sparsify: Optional[list[str]] = None

    def __post_init__(self):
        """Reject configs that set both ``no_op`` and ``replace_with_linear``.

        Raises:
            AssertionError: if both flags are true.
        """
        # Raised explicitly (not via ``assert``) so the check still runs
        # under ``python -O``; the exception type is unchanged for callers.
        if self.no_op and self.replace_with_linear:
            raise AssertionError(
                "no_op and replace_with_linear are mutually exclusive"
            )

    def _force_setattr(self, name: str, value: Any) -> None:
        """
        Set an attribute even in frozen dataclasses.
        Use only inside __post_init__!
        """
        # Bypasses the frozen dataclass __setattr__ guard.
        object.__setattr__(self, name, value)
39
+
40
+
41
@partial(dataclass, frozen=True, kw_only=True, eq=False)
class AttentionConfig(SubblockConfig):
    """Configuration of the attention sub-block.

    When the sub-block is a no-op or replaced by a linear layer,
    ``n_heads_in_group``, ``window_length`` and ``num_sink_tokens`` are reset
    to ``None``; otherwise ``n_heads_in_group`` is required.

    ``eq=False`` keeps the JSON-based ``__eq__``/``__hash__`` inherited from
    ``JsonComparable`` instead of letting ``@dataclass`` generate field-based
    replacements on this class.
    """

    n_heads_in_group: Optional[int] = None
    window_length: Optional[int] = None
    num_sink_tokens: Optional[int] = None
    use_prefill_window_in_sink_attention: bool = False
    unshifted_sink: bool = False

    def __post_init__(self):
        """Normalize irrelevant fields and validate the remaining ones.

        Raises:
            AssertionError: if the flags inherited from ``SubblockConfig``
                conflict, if ``n_heads_in_group`` is missing for an active
                attention block, or if the sink options are inconsistent.
        """
        # The parent already enforces that no_op and replace_with_linear are
        # not both set, so that check is not repeated here.
        super().__post_init__()

        if self.no_op or self.replace_with_linear:
            # These settings have no meaning for a disabled/linear block.
            for irrelevant_att in ("n_heads_in_group", "window_length", "num_sink_tokens"):
                self._force_setattr(irrelevant_att, None)
        elif self.n_heads_in_group is None:
            # Explicit raise so validation survives ``python -O``.
            raise AssertionError(
                "n_heads_in_group must be set unless no_op or replace_with_linear is True"
            )

        # Evaluated after the reset above, so a disabled block is never a sink.
        if self.is_sink:
            if self.unshifted_sink and self.use_prefill_window_in_sink_attention:
                raise AssertionError(
                    "Unshifted sink uses its own kind of explicit masking, not standard window. "
                    "Set use_prefill_window_in_sink_attention to False."
                )
            if self.num_sink_tokens == 0 and not self.unshifted_sink:
                raise AssertionError(
                    "Fake sink attention with 0 sink tokens is only supported with unshifted_sink=True"
                )

    @property
    def prefill_sliding_window(self) -> Optional[int]:
        """Return ``window_length`` when it applies as a prefill window, else ``None``.

        The window applies when ``window_length`` is set and the block is
        either not a sink block or has
        ``use_prefill_window_in_sink_attention`` enabled.
        """
        if self.window_length is not None:
            if not self.is_sink or self.use_prefill_window_in_sink_attention:
                return self.window_length
        return None

    @property
    def is_sliding(self) -> bool:
        """True when ``prefill_sliding_window`` is not ``None``."""
        return self.prefill_sliding_window is not None

    @property
    def is_sink(self) -> bool:
        """True when both ``window_length`` and ``num_sink_tokens`` are set."""
        return (
            (self.window_length is not None)
            and
            (self.num_sink_tokens is not None)
        )
84
+
85
+
86
@partial(dataclass, frozen=True, kw_only=True, eq=False)
class FFNConfig(SubblockConfig):
    """Configuration of the feed-forward sub-block.

    ``ffn_mult`` is cleared for a no-op or linear replacement and is
    otherwise required and stored rounded to 6 decimal places.

    ``eq=False`` keeps the JSON-based ``__eq__``/``__hash__`` inherited from
    ``JsonComparable`` instead of letting ``@dataclass`` generate field-based
    replacements on this class.
    NOTE(review): under JSON comparison ``ffn_mult=2`` and ``ffn_mult=2.0``
    serialize differently and so compare unequal - confirm callers always
    pass floats.
    """

    ffn_mult: Optional[float] = None

    def __post_init__(self):
        """Normalize ``ffn_mult`` according to the block mode.

        Raises:
            AssertionError: if the inherited flags conflict, or if
                ``ffn_mult`` is missing for an active FFN block.
        """
        super().__post_init__()
        if self.no_op or self.replace_with_linear:
            # The multiplier is meaningless for a disabled/linear block.
            self._force_setattr("ffn_mult", None)
            return
        # Explicit raise so the check survives ``python -O`` (a stripped
        # assert would let round(None, 6) fail with an unrelated TypeError).
        if self.ffn_mult is None:
            raise AssertionError(
                "ffn_mult must be set unless no_op or replace_with_linear is True"
            )
        # Rounding makes nearly identical multipliers serialize identically.
        self._force_setattr("ffn_mult", round(self.ffn_mult, 6))
97
+
98
+
99
@partial(dataclass, frozen=True, kw_only=True, eq=False)
class BlockConfig(JsonComparable):
    """Configuration of one block: an attention sub-block and an FFN sub-block.

    Both fields are required keyword arguments (``MISSING`` is the
    dataclasses sentinel for "no default"). Each may be given either as an
    instance of its config class or as a plain dict of that class's fields.

    ``eq=False`` keeps the JSON-based ``__eq__``/``__hash__`` inherited from
    ``JsonComparable`` instead of letting ``@dataclass`` generate field-based
    replacements on this class.
    """

    attention: AttentionConfig = MISSING
    ffn: FFNConfig = MISSING

    def __post_init__(self):
        """
        Init subblock dataclasses from dicts

        Keys that are not fields of the target config class are dropped
        with a warning; values that are not dicts are left untouched.
        """
        for subblock_field in dataclasses.fields(self):
            subblock_config = getattr(self, subblock_field.name)
            if not isinstance(subblock_config, dict):
                continue

            # Relies on the annotation being the config class itself (it
            # would be a string under ``from __future__ import annotations``).
            subblock_cls = subblock_field.type
            supported_fields = {field.name for field in dataclasses.fields(subblock_cls)}
            unsupported_fields = [field_name for field_name in subblock_config
                                  if field_name not in supported_fields]
            if unsupported_fields:
                warnings.warn(f"Removed unsupported fields {unsupported_fields} from {subblock_cls.__name__}")
                subblock_config = {k: v for k, v in subblock_config.items()
                                   if k in supported_fields}

            # object.__setattr__ to overcome frozen=True
            object.__setattr__(self, subblock_field.name, subblock_cls(**subblock_config))