ModularStarEncoder / config.py
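"""Configuration for ModularStarEncoder.

Defines `ModularStarEncoderConfig`, a `PretrainedConfig` subclass that holds the
Starcoder2-style encoder hyperparameters (hidden size, attention heads, RoPE theta,
attention backend, etc.) together with the layer-wise matryoshka loss settings.
"""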
from transformers import PretrainedConfig
from typing import List
#STARCODER2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class ModularStarEncoderConfig(PretrainedConfig):
    model_type = "ModularStarEncoder"
    keys_to_ignore_at_inference = ["past_key_values"]
    def __init__(
        self,
        # dropout probabilities
        attention_dropout=0.1,
        residual_dropout=0.1,
        embedding_dropout=0.1,
        # special token ids
        bos_token_id=0,
        eos_token_id=0,
        hidden_act="gelu_pytorch_tanh",
        _attn_implementation="flash_attention_2",
        hidden_size=1024,
        conditional_size=4,
        initializer_range=0.018042,
        intermediate_size=12288,
        max_position_embeddings=2048,
        mlp_type="default",
        model_type="starcoder2",
        torch_dtype="bfloat16",
        # layer-wise matryoshka loss and the hidden layers it is applied to
        layer_matryoshka_loss=True,
        matryoshka_layers=[4, 9, 18, 27, 36],
        norm_epsilon=1e-05,
        layer_norm_eps=1e-05,
        norm_type="layer_norm",
        num_attention_heads=16,
        num_hidden_layers=36,
        num_key_value_heads=4,
        rope_theta=999999.4420358813,
        sliding_window=None,
        transformers_version="4.39.3",
        use_bias=True,
        use_cache=False,
        vocab_size=49156,
        pad_token_id=0,
        **kwargs,
    ):
        if _attn_implementation not in ["flash_attention_2", "sdpa"]:
            raise ValueError(
                f"`_attn_implementation` must be 'flash_attention_2' or 'sdpa', got {_attn_implementation}."
            )
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.embedding_dropout = embedding_dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.hidden_act = hidden_act
        self._attn_implementation = _attn_implementation
        self.hidden_size = hidden_size
        self.conditional_size = conditional_size
        self.initializer_range = initializer_range
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.mlp_type = mlp_type
        self.model_type = model_type
        self.torch_dtype = torch_dtype
        self.layer_matryoshka_loss = layer_matryoshka_loss
        self.matryoshka_layers = matryoshka_layers
        self.norm_epsilon = norm_epsilon
        self.layer_norm_eps = layer_norm_eps
        self.norm_type = norm_type
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window
        self.transformers_version = transformers_version
        self.use_bias = use_bias
        self.use_cache = use_cache
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
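

# --- Usage sketch ---
# A minimal example of instantiating this config; the Hub repository id in the
# commented AutoConfig call below is a placeholder, not a confirmed path.
if __name__ == "__main__":
    # Build a config with the shipped defaults, overriding the attention backend
    # for environments without flash-attention.
    config = ModularStarEncoderConfig(_attn_implementation="sdpa")
    print(config.hidden_size, config.num_hidden_layers, config.matryoshka_layers)

    # To load the config published alongside the checkpoint instead, something
    # like the following is typical; trust_remote_code=True is required because
    # this class is defined inside the model repository:
    # from transformers import AutoConfig
    # config = AutoConfig.from_pretrained("<hub-repo-id>", trust_remote_code=True)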