import math
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from transformers.generation import GenerationMixin # NEW: Import GenerationMixin
from .configuration_dakitari_instruct import DakitariInstructConfig


class SimpleAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.in_proj = nn.Linear(config.n_embd, config.n_embd)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x, attention_mask=None):
        B, L, D = x.size()  # batch, length, dimension
        # q, k and v share a single projection in this simplified attention
        q = k = v = self.in_proj(x)
        # Compute scaled dot-product attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D)
        if attention_mask is not None:
            # Expand the (B, L) boolean key mask so it broadcasts over the query dimension
            attention_mask = attention_mask.view(B, 1, L)
            scores = scores.masked_fill(~attention_mask, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)
        return self.out_proj(context)


class CustomTransformerLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = SimpleAttention(config)
        self.linear1 = nn.Linear(config.n_embd, config.n_embd)  # Keep dimensions consistent
        self.linear2 = nn.Linear(config.n_embd, config.n_embd)  # Keep dimensions consistent
        self.norm1 = nn.LayerNorm(config.n_embd)
        self.norm2 = nn.LayerNorm(config.n_embd)
        self.dropout = nn.Dropout(config.resid_pdrop)
        self.activation = nn.GELU()
        # New: Adapter layers for domain-specific finetuning
        self.adapter_down = nn.Linear(config.n_embd, config.adapter_bottleneck)
        self.adapter_up = nn.Linear(config.adapter_bottleneck, config.n_embd)
        self.norm_adapter = nn.LayerNorm(config.n_embd)

    def forward(self, x, attention_mask=None):
        # Pre-norm attention block with residual connection
        residual = x
        x_norm = self.norm1(x)
        x_attn = self.attention(x_norm, attention_mask)
        x = residual + self.dropout(x_attn)
        # Pre-norm feed-forward block with residual connection
        residual = x
        x_norm = self.norm2(x)
        x_ff = self.linear2(self.dropout(self.activation(self.linear1(x_norm))))
        x = residual + self.dropout(x_ff)
        # New: Adapter branch (only the adapter layers are trained)
        adapter_input = self.norm_adapter(x)
        adapter_out = self.adapter_up(self.adapter_down(adapter_input))
        x = x + adapter_out
        return x


class DakitariInstructModel(PreTrainedModel, GenerationMixin):  # Updated: Inherit from GenerationMixin
    config_class = DakitariInstructConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Token and position embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        # Build the transformer stack from n_layer identical layers
        self.layers = nn.ModuleList([
            CustomTransformerLayer(config)
            for _ in range(config.n_layer)
        ])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        # New: LM head for generation
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.apply(self._init_weights)
        # Explicitly keep the adapter parameters trainable
        for name, param in self.named_parameters():
            if "adapter" in name:
                param.requires_grad = True

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, input_ids=None, attention_mask=None, position_ids=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        # Clamp token ids into the embedding range to avoid out-of-bounds lookups
        input_ids = torch.clamp(input_ids, 0, self.config.vocab_size - 1)
        batch_size, seq_length = input_ids.shape
        # Build default position ids when none are given
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)
        # Token + position embeddings
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = self.drop(inputs_embeds + position_embeds)
        # The attention layers expect a boolean mask (True = attend to this position)
        if attention_mask is not None:
            attention_mask = attention_mask.bool()
        # Process through the transformer layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)
        # Standard causal LM loss: predict token t+1 from the hidden state at position t
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        # NEW: Return a CausalLMOutput with a logits attribute so generate() works correctly
        return CausalLMOutput(loss=loss, logits=logits)

    # NEW: Override generation input preparation
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache here, so re-feed the full sequence (and its mask) at every step
        return {"input_ids": input_ids, "attention_mask": attention_mask}