|
from transformers import Starcoder2Model |
|
import sys |
|
from .config import ModularStarEncoderConfig |
|
import os |
|
from dataclasses import dataclass |
|
from typing import Optional, Tuple, Union |
|
import sys |
|
import torch |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from torch.nn import CrossEntropyLoss |
|
from transformers.activations import ACT2FN |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import ( |
|
ModelOutput, |
|
logging, |
|
|
|
) |
|
|
|
# Module-level logger obtained from transformers' logging utilities (imported from `transformers.utils` above).
logger = logging.get_logger(__name__)
|
|
|
class StarEncoder2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ModularStarEncoderConfig
    base_model_prefix = "ModularStarEncoder"
    model_type = "ModularStarEncoder"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place.

        - ``nn.Linear``: weight ~ N(0, config.initializer_range); bias (if any) set to zero.
        - ``nn.Embedding``: weight ~ N(0, config.initializer_range); the ``padding_idx`` row (if any) set to zero.
        - ``nn.LayerNorm``: weight set to one, bias set to zero. A LayerNorm built with ``bias=False`` or
          ``elementwise_affine=False`` has ``None`` for the missing parameter; those are skipped instead of
          crashing with an AttributeError.

        Any other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
|
|
|
class StarEncoder2Pooler(nn.Module):
    """Pool a sequence into one vector by passing its final position through a dense layer and a tanh.

    ``config`` only needs a ``hidden_size`` attribute; the projection is square (hidden_size -> hidden_size).
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Take the entry at the last index of dimension 1 (the last token when the input is
        # (batch, sequence, hidden)); no attention mask is consulted here.
        final_position = hidden_states.select(1, -1)
        return self.activation(self.dense(final_position))
|
|
|
@dataclass
class ModularStarEncoderOutput(ModelOutput):
    """
    Output type of [`ModularStarEncoder`].

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states produced by the last layer of the encoder.
        hidden_states (`tuple(torch.FloatTensor)`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        loss (*optional*, returned when both `labels` and `next_sentence_label` are provided, scalar `torch.FloatTensor`):
            Sum of the masked language modeling loss and the in-context (sequence relationship) classification loss.
            When the layer-wise matryoshka loss is enabled, this is a weighted sum of those per-layer losses over the
            matryoshka layers.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, or a list of them):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). When the
            matryoshka loss is enabled, a list holding one such tensor per matryoshka layer.
        seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`, or a list of them):
            Prediction scores of the in-context classification head (scores of True/False continuation before
            SoftMax). When the matryoshka loss is enabled, a list holding one such tensor per matryoshka layer.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Field order is part of the interface: ModelOutput also exposes its non-None fields positionally, in this order.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    loss: Optional[torch.FloatTensor] = None
    # Either a single tensor or (matryoshka mode) a list with one tensor per matryoshka layer.
    prediction_logits: Optional[Union[torch.FloatTensor, list]] = None
    seq_relationship_logits: Optional[Union[torch.FloatTensor, list]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
|
|
|
|
|
|
|
|
class StarEncoder2PredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the vocabulary projection.

    The transform is square. Its width is ``config.hidden_size``, or ``config.hidden_size + config.conditional_size``
    when ``config.layer_matryoshka_loss`` is set. ``config.hidden_act`` may be an activation name (looked up in
    ``ACT2FN``) or a callable used as-is.
    """

    def __init__(self, config):
        super().__init__()
        self.is_matryoshka = config.layer_matryoshka_loss

        width = (
            config.hidden_size + config.conditional_size
            if self.is_matryoshka
            else config.hidden_size
        )
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)

        activation = config.hidden_act
        self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        activated = self.transform_act_fn(projected)
        return self.LayerNorm(activated)
|
|
|
|
|
|
|
class StarEncoder2LMPredictionHead(nn.Module):
    """Masked-language-modeling head: a dense/activation/LayerNorm transform followed by a vocabulary projection."""

    def __init__(self, config):
        super().__init__()
        # NOTE(review): this loop MUTATES the caller's ``config`` in place. Every attribute whose value is a
        # non-empty list or tuple is replaced by its first element. ModularStarEncoder.__init__ reads
        # ``config.matryoshka_layers`` and builds its Starcoder2Model only after this head has been constructed,
        # so those consumers see the rewritten values. Presumably this unwraps attributes the config stores
        # wrapped in 1-tuples; verify against the config class before changing it.
        for attr_name in dir(config):
            attr_value = getattr(config, attr_name)
            if isinstance(attr_value, (tuple, list)) and len(attr_value) > 0:
                setattr(config, attr_name, attr_value[0])

        self.transform = StarEncoder2PredictionHeadTransform(config)
        self.is_matryoshka = config.layer_matryoshka_loss

        # With the matryoshka loss the incoming features are widened by ``conditional_size``.
        in_features = config.hidden_size
        if self.is_matryoshka:
            in_features = config.hidden_size + config.conditional_size
        self.decoder = nn.Linear(in_features, config.vocab_size, bias=False)

        # The output bias is a standalone parameter, then attached to the decoder, so that
        # ``self.bias`` and ``self.decoder.bias`` are the same Parameter object.
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        transformed = self.transform(hidden_states)
        return self.decoder(transformed)
|
|
|
class StarEncoder2PreTrainingHeads(nn.Module):
    """Pre-training heads: a masked-LM head over every token and a 2-way in-context classifier over the pooled vector.

    When ``config.layer_matryoshka_loss`` is set, a learned conditional embedding is concatenated to the last
    dimension of both head inputs. There is one such embedding per entry of ``config.matryoshka_layers``, selected by
    ``idx_layer``.
    """

    def __init__(self, config):
        super().__init__()
        # Built first: the LM prediction head's constructor rewrites list/tuple attributes of ``config`` in place,
        # and the attributes read below (e.g. ``matryoshka_layers``) are read after that rewrite.
        self.predictions = StarEncoder2LMPredictionHead(config)
        self.is_matryoshka = config.layer_matryoshka_loss
        if self.is_matryoshka:
            self.seq_relationship = nn.Linear(config.hidden_size + config.conditional_size, 2)
            self.conditional_embeddings = nn.Embedding(len(config.matryoshka_layers), config.conditional_size)
        else:
            self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output, idx_layer: Optional[Union[int, torch.Tensor]] = None):
        """Apply both heads.

        Args:
            sequence_output: per-token hidden states; indexed as 3-D ``(batch, seq, hidden)`` in matryoshka mode.
            pooled_output: pooled hidden states; indexed as 2-D ``(batch, hidden)`` in matryoshka mode.
            idx_layer: position (int or integer tensor) of the current layer within ``matryoshka_layers``.
                Required in matryoshka mode, ignored otherwise.

        Returns:
            ``(prediction_scores, seq_relationship_score)``.

        Raises:
            ValueError: if matryoshka mode is on and ``idx_layer`` is None.
        """
        if not self.is_matryoshka:
            return self.predictions(sequence_output), self.seq_relationship(pooled_output)

        if idx_layer is None:
            raise ValueError("idx_layer is required when the matryoshka loss is enabled.")

        # Look the conditional embedding up once, on the same device as the activations
        # (works for CPU, CUDA and any other backend, unlike get_device()).
        layer_index = torch.as_tensor(idx_layer, device=sequence_output.device).long()
        condition = self.conditional_embeddings(layer_index)

        batch_size, seq_len = sequence_output.shape[0], sequence_output.shape[1]
        token_condition = condition.expand(batch_size, seq_len, -1)
        pooled_condition = condition.expand(pooled_output.shape[0], -1)

        prediction_scores = self.predictions(torch.cat([sequence_output, token_condition], dim=-1))
        seq_relationship_score = self.seq_relationship(torch.cat([pooled_output, pooled_condition], dim=-1))
        return prediction_scores, seq_relationship_score
|
|
|
|
|
|
|
|
|
|
|
class ModularStarEncoder(StarEncoder2PreTrainedModel):
    """
    StarCoder2 backbone whose self-attention modules are flagged non-causal, topped with pre-training heads. The heads
    are a masked-language-modeling head and a two-way in-context (sequence relationship) classifier applied to a
    last-token pooled vector.

    With `config.layer_matryoshka_loss` enabled, both heads are applied to the hidden states of every layer listed in
    `config.matryoshka_layers`, each conditioned on that layer's position in the list. The loss is then a weighted sum
    over those layers.
    """

    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

    config_class = ModularStarEncoderConfig

    def __init__(self, config):
        super().__init__(config)
        self.model_type = "ModularStarEncoder"

        # Built first on purpose: the LM prediction head created inside the pre-training heads rewrites every
        # non-empty list/tuple attribute of `config` in place. Everything below reads the rewritten config.
        self.cls = StarEncoder2PreTrainingHeads(config)

        self.layer_matryoshka_loss = config.layer_matryoshka_loss
        self.matryoshka_layers = config.matryoshka_layers

        if self.layer_matryoshka_loss:
            config.sliding_window = None
            logger.warning_once(
                "The matryoshka loss is implemented without sliding_window, if you want to use the sliding window set sliding_window to True"
            )
            # The deepest matryoshka layer must be the model's final layer. Raise instead of calling sys.exit(),
            # which would end the whole process (with exit status 0) from inside a model constructor.
            if self.matryoshka_layers[-1] != config.num_hidden_layers:
                raise ValueError(
                    f"The last entry of matryoshka_layers ({self.matryoshka_layers[-1]}) must equal the overall "
                    f"number of hidden layers ({config.num_hidden_layers})."
                )

        self.starEncoder2 = Starcoder2Model(config)
        self.pooler = StarEncoder2Pooler(config)

        # Flag every self-attention module as non-causal so the decoder backbone is used as an encoder.
        # NOTE(review): whether this flag alone disables causal masking depends on the attention implementation
        # and transformers version (the backbone may still build a causal mask itself); confirm.
        for layer in self.starEncoder2.layers:
            layer.self_attn.is_causal = False

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        next_sentence_label: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], ModularStarEncoderOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            This label is assigned to the in context loss:

            - 0 indicates sequence B belongs to the same repository of A,
            - 1 indicates sequence B is a random repository.

            The loss is only computed when both `labels` and `next_sentence_label` are given.
        head_mask (*optional*):
            Unused; accepted for API compatibility and not forwarded to the backbone.
        output_hidden_states (`bool`, *optional*):
            Ignored: hidden states of every layer are always computed (the matryoshka heads need them) and returned.
        return_dict (`bool`, *optional*):
            When false, returns `(loss?, prediction_scores, seq_relationship_score, hidden_states, attentions?)`,
            with `loss` and `attentions` omitted when they are `None`. Otherwise returns a
            [`ModularStarEncoderOutput`].
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always ask the backbone for a ModelOutput: the code below reads `.hidden_states` / `.last_hidden_state`,
        # which a plain tuple (what the backbone returns for return_dict=False) does not have.
        outputs = self.starEncoder2(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )

        if self.layer_matryoshka_loss:
            prediction_scores = []
            seq_relationship_score = []
            for position, idx_layer in enumerate(self.matryoshka_layers):
                sequence_output = outputs.hidden_states[idx_layer]
                pooled_output = self.pooler(sequence_output)
                # `position` selects the conditional embedding for this layer inside the heads.
                layer_prediction, layer_relationship = self.cls(sequence_output, pooled_output, position)
                prediction_scores.append(layer_prediction)
                seq_relationship_score.append(layer_relationship)
        else:
            sequence_output = outputs.last_hidden_state
            pooled_output = self.pooler(sequence_output)
            prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            total_loss = self._pretraining_loss(prediction_scores, seq_relationship_score, labels, next_sentence_label)

        if not return_dict:
            extra = tuple(t for t in (outputs.hidden_states, outputs.attentions) if t is not None)
            output = (prediction_scores, seq_relationship_score) + extra
            return ((total_loss,) + output) if total_loss is not None else output

        return ModularStarEncoderOutput(
            last_hidden_state=outputs.hidden_states[-1],
            hidden_states=outputs.hidden_states,
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            attentions=outputs.attentions,
        )

    def _pretraining_loss(self, prediction_scores, seq_relationship_score, labels, next_sentence_label):
        """Return the masked-LM plus in-context classification loss.

        Without the matryoshka loss, the scores are single tensors and the result is the plain sum of the two
        cross-entropy losses. With it, the scores are per-layer lists. Entry `i` (0-based) of `n` contributes its
        summed loss weighted by `(i + 1) / n`, so later entries of `matryoshka_layers` weigh more.
        """
        loss_fct = CrossEntropyLoss()
        flat_labels = labels.view(-1)
        flat_next_sentence = next_sentence_label.view(-1)

        def layer_loss(lm_logits, relationship_logits):
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), flat_labels)
            next_sentence_loss = loss_fct(relationship_logits.view(-1, 2), flat_next_sentence)
            return masked_lm_loss + next_sentence_loss

        if not self.layer_matryoshka_loss:
            return layer_loss(prediction_scores, seq_relationship_score)

        num_layers = len(prediction_scores)
        total_loss = None
        for index, (lm_logits, relationship_logits) in enumerate(zip(prediction_scores, seq_relationship_score)):
            weighted = layer_loss(lm_logits, relationship_logits) * ((index + 1) / num_layers)
            total_loss = weighted if total_loss is None else total_loss + weighted
        return total_loss
|
|
|
|
|
|
|
|
|
|
|
|