Text Generation
Transformers
Safetensors
PyTorch
English
nemotron-nas
llama-3
llama
nvidia-nemotron
nemotron-ultra
fine-tuned
conversational-ai
large-language-model
huggingface
open-source-llm
generative-ai
nvidia
meta-llama
instruct-tuning
chat-model
llm
artificial-intelligence
deep-learning
tensorrt-llm
gpu-optimized
multilingual
instruction-following
conversational
custom_code
import dataclasses
import json
import warnings
from dataclasses import dataclass, MISSING
from functools import partial
from typing import Optional, Any
class JsonComparable:
    """Mixin giving dataclass subclasses equality, hashing, and ordering via their JSON dump."""

    def to_json(self) -> str:
        return json.dumps(dataclasses.asdict(self))

    def __eq__(self, other: "JsonComparable") -> bool:
        return self.to_json() == other.to_json()

    def __hash__(self) -> int:
        return hash(self.to_json())

    def __lt__(self, other: "JsonComparable") -> bool:
        return self.to_json() < other.to_json()
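# Usage sketch (illustrative, not part of the shipped module): any dataclass that mixes in
# JsonComparable is compared, hashed, and sorted through its JSON dump. The `_Probe` class
# below is a hypothetical toy type used only to demonstrate that behaviour; eq=False keeps
# the dataclass from generating its own __eq__, so the mixin's methods stay in effect.
@dataclass(frozen=True, eq=False)
class _Probe(JsonComparable):
    name: str = "x"
    value: int = 0

assert _Probe("attn", 1) == _Probe("attn", 1)              # equal JSON dumps compare equal
assert hash(_Probe("attn", 1)) == hash(_Probe("attn", 1))  # hashable, so usable as dict keys
assert sorted([_Probe("b"), _Probe("a")])[0].name == "a"   # __lt__ orders by the JSON string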
@dataclass(frozen=True)
class SubblockConfig(JsonComparable):
    no_op: bool = False
    replace_with_linear: bool = False
    sparsify: Optional[list[str]] = None

    def __post_init__(self):
        # A subblock cannot be both skipped and replaced with a linear layer.
        assert not (self.no_op and self.replace_with_linear)

    def _force_setattr(self, name: str, value: Any) -> None:
        """
        Set an attribute even in frozen dataclasses.
        Use only inside __post_init__!
        """
        object.__setattr__(self, name, value)
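# Usage sketch (illustrative): no_op and replace_with_linear are mutually exclusive,
# so requesting both trips the __post_init__ assertion.
SubblockConfig(no_op=True)                # fine: the subblock is dropped entirely
SubblockConfig(replace_with_linear=True)  # fine: the subblock becomes a plain linear layer
try:
    SubblockConfig(no_op=True, replace_with_linear=True)
except AssertionError:
    pass  # expected: a subblock cannot be both dropped and replaced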
@dataclass(frozen=True)
class AttentionConfig(SubblockConfig):
    n_heads_in_group: Optional[int] = None
    window_length: Optional[int] = None
    num_sink_tokens: Optional[int] = None
    use_prefill_window_in_sink_attention: bool = False
    unshifted_sink: bool = False

    def __post_init__(self):
        super().__post_init__()
        assert not (self.no_op and self.replace_with_linear)

        if self.no_op or self.replace_with_linear:
            # Attention-specific fields are meaningless for skipped / linear subblocks.
            for irrelevant_att in ["n_heads_in_group", "window_length", "num_sink_tokens"]:
                self._force_setattr(irrelevant_att, None)
        else:
            assert self.n_heads_in_group is not None

            if self.is_sink:
                assert not (self.unshifted_sink and self.use_prefill_window_in_sink_attention), \
                    ("Unshifted sink uses its own kind of explicit masking, not standard window. "
                     "Set use_prefill_window_in_sink_attention to False.")
                assert not (self.num_sink_tokens == 0 and not self.unshifted_sink), \
                    "Fake sink attention with 0 sink tokens is only supported with unshifted_sink=True"

    @property
    def prefill_sliding_window(self) -> Optional[int]:
        if self.window_length is not None:
            if not self.is_sink or self.use_prefill_window_in_sink_attention:
                return self.window_length
        return None

    @property
    def is_sliding(self) -> bool:
        return self.prefill_sliding_window is not None

    @property
    def is_sink(self) -> bool:
        return (
            (self.window_length is not None)
            and
            (self.num_sink_tokens is not None)
        )
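# Usage sketch (illustrative; the values below are made up, not taken from any released
# checkpoint): a regular grouped-query attention subblock must specify n_heads_in_group,
# while a no_op subblock gets its attention-specific fields forced back to None.
gqa_attention = AttentionConfig(n_heads_in_group=8)
assert not gqa_attention.is_sink and not gqa_attention.is_sliding

skipped_attention = AttentionConfig(no_op=True, n_heads_in_group=8, window_length=4096)
assert skipped_attention.n_heads_in_group is None and skipped_attention.window_length is None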
@dataclass(frozen=True)
class FFNConfig(SubblockConfig):
    ffn_mult: Optional[float] = None

    def __post_init__(self):
        super().__post_init__()
        if self.no_op or self.replace_with_linear:
            self._force_setattr("ffn_mult", None)
        else:
            assert self.ffn_mult is not None
            self._force_setattr("ffn_mult", round(self.ffn_mult, 6))
@dataclass(frozen=True)
class BlockConfig(JsonComparable):
    attention: AttentionConfig = MISSING  # MISSING default makes the field required
    ffn: FFNConfig = MISSING

    def __post_init__(self):
        """
        Init subblock dataclasses from dicts
        """
        for subblock_name in dataclasses.fields(self):
            subblock_config = getattr(self, subblock_name.name)
            if isinstance(subblock_config, dict):
                subblock_fields = [field.name for field in dataclasses.fields(subblock_name.type)]
                unsupported_fields = [field_name for field_name in subblock_config.keys()
                                      if field_name not in subblock_fields]
                if len(unsupported_fields) > 0:
                    warnings.warn(f"Removed unsupported fields {unsupported_fields} from {subblock_name.type.__name__}")
                    subblock_config = {k: v for k, v in subblock_config.items() if k not in unsupported_fields}

                object.__setattr__(self, subblock_name.name,
                                   subblock_name.type(**subblock_config))  # __setattr__ to overcome frozen=True
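# Usage sketch (illustrative, made-up values): BlockConfig accepts nested dicts, such as the
# per-layer entries of a block_configs-style list in the model's config.json (an assumption
# about the surrounding config layout, not something this file defines), converts them into
# the typed sub-configs above, and drops unknown keys with a warning instead of raising.
block = BlockConfig(
    attention={"n_heads_in_group": 8, "no_op": False, "replace_with_linear": False},
    ffn={"ffn_mult": 2.625, "no_op": False, "replace_with_linear": False},
)
assert isinstance(block.attention, AttentionConfig) and isinstance(block.ffn, FFNConfig)

noop_block = BlockConfig(
    attention={"no_op": True, "some_legacy_field": 123},  # "some_legacy_field" is hypothetical
    ffn={"no_op": True},
)
assert noop_block.attention.no_op and noop_block.ffn.no_op  # unknown key dropped with a warning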