import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from transformers.generation import GenerationMixin

from .configuration_dakitari_instruct import DakitariInstructConfig


class SimpleAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.in_proj = nn.Linear(config.n_embd, config.n_embd)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x, attention_mask=None):
        B, L, D = x.size()
        # Queries, keys, and values share a single projection in this simplified attention.
        q = k = v = self.in_proj(x)

        # Scaled dot-product attention scores of shape (B, L, L).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D)

        if attention_mask is not None:
            # Broadcast the (B, L) padding mask over the query dimension and
            # block attention to padded key positions.
            attention_mask = attention_mask.bool().view(B, 1, L)
            scores = scores.masked_fill(~attention_mask, float("-inf"))

        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)
        return self.out_proj(context)


class CustomTransformerLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = SimpleAttention(config)
        self.linear1 = nn.Linear(config.n_embd, config.n_embd)
        self.linear2 = nn.Linear(config.n_embd, config.n_embd)
        self.norm1 = nn.LayerNorm(config.n_embd)
        self.norm2 = nn.LayerNorm(config.n_embd)
        self.dropout = nn.Dropout(config.resid_pdrop)
        self.activation = nn.GELU()

        # Bottleneck adapter: down-project, then up-project back to the model width.
        self.adapter_down = nn.Linear(config.n_embd, config.adapter_bottleneck)
        self.adapter_up = nn.Linear(config.adapter_bottleneck, config.n_embd)
        self.norm_adapter = nn.LayerNorm(config.n_embd)

    def forward(self, x, attention_mask=None):
        # Pre-norm self-attention block with a residual connection.
        residual = x
        x_norm = self.norm1(x)
        x_attn = self.attention(x_norm, attention_mask)
        x = residual + self.dropout(x_attn)

        # Pre-norm feed-forward block with a residual connection.
        residual = x
        x_norm = self.norm2(x)
        x_ff = self.linear2(self.dropout(self.activation(self.linear1(x_norm))))
        x = residual + self.dropout(x_ff)

        # Adapter applied as an additional residual branch.
        adapter_input = self.norm_adapter(x)
        adapter_out = self.adapter_up(self.adapter_down(adapter_input))
        x = x + adapter_out

        return x


class DakitariInstructModel(PreTrainedModel, GenerationMixin):
    config_class = DakitariInstructConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token and learned absolute position embeddings.
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

        self.layers = nn.ModuleList([
            CustomTransformerLayer(config)
            for _ in range(config.n_layer)
        ])

        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.apply(self._init_weights)

        # Keep adapter parameters trainable; all parameters are trainable by default,
        # so this only matters if the base weights are frozen elsewhere.
        for name, param in self.named_parameters():
            if "adapter" in name:
                param.requires_grad = True

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, input_ids=None, attention_mask=None, position_ids=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        # Clamp out-of-range token ids into the embedding table's valid range.
        input_ids = torch.clamp(input_ids, 0, self.config.vocab_size - 1)

        batch_size, seq_length = input_ids.shape

        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = self.drop(inputs_embeds + position_embeds)

        if attention_mask is not None:
            attention_mask = attention_mask.bool()

        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        # Causal language-modeling loss: predict token t+1 from the hidden state at position t.
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # Pass the attention mask through so padded prompts are respected during generation.
        return {"input_ids": input_ids, "attention_mask": attention_mask}