lambda-technologies-limited's picture
Uploading optimized model files
4d4f910 verified
raw
history blame contribute delete
4.35 kB
import dataclasses
import json
import warnings
from dataclasses import dataclass, MISSING
from functools import partial
from typing import Optional, Any
@partial(dataclass, frozen=True, kw_only=True)
class JsonComparable:
def to_json(self) -> str:
return json.dumps(dataclasses.asdict(self))
def __eq__(self, other: "JsonComparable") -> bool:
return self.to_json() == other.to_json()
def __hash__(self) -> int:
return hash(self.to_json())
def __lt__(self, other: "JsonComparable") -> bool:
return self.to_json() < other.to_json()
@partial(dataclass, frozen=True, kw_only=True)
class SubblockConfig(JsonComparable):
    """
    Fields and validation shared by the sub-block configurations.

    `no_op` and `replace_with_linear` are mutually exclusive; this is checked
    right after construction. `sparsify` is stored as given and is not
    interpreted by this class.
    """
    no_op: bool = False
    replace_with_linear: bool = False
    sparsify: Optional[list[str]] = None

    def __post_init__(self):
        # At most one of the two flags may be set.
        assert not self.no_op or not self.replace_with_linear

    def _force_setattr(self, name: str, value: Any) -> None:
        """
        Assign `value` to attribute `name`, bypassing the frozen-dataclass
        guard by going through object.__setattr__ directly.
        Intended to be called only from __post_init__!
        """
        object.__setattr__(self, name, value)
@partial(dataclass, frozen=True, kw_only=True)
class AttentionConfig(SubblockConfig):
    """
    Configuration of an attention sub-block.

    When `no_op` or `replace_with_linear` is set, `n_heads_in_group`,
    `window_length` and `num_sink_tokens` are reset to None after
    construction; otherwise `n_heads_in_group` must be provided.
    """
    n_heads_in_group: Optional[int] = None
    window_length: Optional[int] = None
    num_sink_tokens: Optional[int] = None
    use_prefill_window_in_sink_attention: bool = False
    unshifted_sink: bool = False

    def __post_init__(self):
        super().__post_init__()
        # Same mutual-exclusion check the parent performs, repeated here.
        assert not self.no_op or not self.replace_with_linear

        if not (self.no_op or self.replace_with_linear):
            assert self.n_heads_in_group is not None
        else:
            # These settings carry no meaning for a replaced/skipped block.
            for attr_name in ("n_heads_in_group", "window_length", "num_sink_tokens"):
                self._force_setattr(attr_name, None)

        if not self.is_sink:
            return

        # Sink-specific consistency checks.
        assert not (self.unshifted_sink and self.use_prefill_window_in_sink_attention), \
            ("Unshifted sink uses its own kind of explicit masking, not standard window. "
             "Set use_prefill_window_in_sink_attention to False.")
        assert self.num_sink_tokens != 0 or self.unshifted_sink, \
            "Fake sink attention with 0 sink tokens is only supported with unshifted_sink=True"

    @property
    def prefill_sliding_window(self) -> Optional[int]:
        """
        `window_length` when a window is set and either the block is not a
        sink block or `use_prefill_window_in_sink_attention` is enabled;
        None in every other case.
        """
        if self.window_length is None:
            return None
        if self.is_sink and not self.use_prefill_window_in_sink_attention:
            return None
        return self.window_length

    @property
    def is_sliding(self) -> bool:
        """
        True when `prefill_sliding_window` yields a window length.
        """
        window = self.prefill_sliding_window
        return window is not None

    @property
    def is_sink(self) -> bool:
        """
        True when both `window_length` and `num_sink_tokens` are set.
        """
        return not (self.window_length is None or self.num_sink_tokens is None)
@partial(dataclass, frozen=True, kw_only=True)
class FFNConfig(SubblockConfig):
    """
    Configuration of a feed-forward sub-block.

    `ffn_mult` is required unless `no_op` or `replace_with_linear` is set,
    in which case it is reset to None. A provided value is rounded to six
    decimal places after construction.
    """
    ffn_mult: Optional[float] = None

    def __post_init__(self):
        super().__post_init__()
        if not (self.no_op or self.replace_with_linear):
            assert self.ffn_mult is not None
            normalized = round(self.ffn_mult, 6)
        else:
            # The multiplier is irrelevant for a replaced/skipped block.
            normalized = None
        self._force_setattr("ffn_mult", normalized)
@partial(dataclass, frozen=True, kw_only=True)
class BlockConfig(JsonComparable):
    """
    Configuration of one block: an attention sub-block plus an FFN sub-block.

    Both fields are required. Each may be passed either as an instance of its
    config class or as a plain dict of that class's keyword arguments.
    """
    attention: AttentionConfig = MISSING
    ffn: FFNConfig = MISSING

    def __post_init__(self):
        """
        Replace sub-block values given as dicts with instances of the
        field's declared dataclass type, dropping (with a warning) any
        dict keys that type does not declare.
        """
        for field_spec in dataclasses.fields(self):
            raw_value = getattr(self, field_spec.name)
            if not isinstance(raw_value, dict):
                continue

            config_cls = field_spec.type
            supported = [f.name for f in dataclasses.fields(config_cls)]
            rejected = [key for key in raw_value if key not in supported]
            if rejected:
                warnings.warn(f"Removed unsupported fields {rejected} from {config_cls.__name__}")

            accepted = {key: val for key, val in raw_value.items() if key in supported}
            # object.__setattr__ is needed because the dataclass is frozen.
            object.__setattr__(self, field_spec.name, config_cls(**accepted))