import math

import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from transformers.generation import GenerationMixin  # NEW: Import GenerationMixin

from .configuration_dakitari_instruct import DakitariInstructConfig


class SimpleAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.in_proj = nn.Linear(config.n_embd, config.n_embd)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x, attention_mask=None):
        B, L, D = x.size()  # batch, length, dimension
        # A single shared projection is used for queries, keys, and values.
        q = k = v = self.in_proj(x)

        # Compute scaled dot-product attention scores.
        # Note: no causal mask is applied, so every position can attend to every other position.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D)

        if attention_mask is not None:
            # Expand the boolean padding mask to (B, 1, L) so it broadcasts over query positions.
            attention_mask = attention_mask.view(B, 1, L)
            scores = scores.masked_fill(~attention_mask, float("-inf"))

        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)
        return self.out_proj(context)


class CustomTransformerLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = SimpleAttention(config)
        self.linear1 = nn.Linear(config.n_embd, config.n_embd)  # Keep dimensions consistent
        self.linear2 = nn.Linear(config.n_embd, config.n_embd)  # Keep dimensions consistent
        self.norm1 = nn.LayerNorm(config.n_embd)
        self.norm2 = nn.LayerNorm(config.n_embd)
        self.dropout = nn.Dropout(config.resid_pdrop)
        self.activation = nn.GELU()

        # New: Adapter layers for domain-specific finetuning
        self.adapter_down = nn.Linear(config.n_embd, config.adapter_bottleneck)
        self.adapter_up = nn.Linear(config.adapter_bottleneck, config.n_embd)
        self.norm_adapter = nn.LayerNorm(config.n_embd)

    def forward(self, x, attention_mask=None):
        # Pre-norm self-attention block with residual connection.
        residual = x
        x_norm = self.norm1(x)
        x_attn = self.attention(x_norm, attention_mask)
        x = residual + self.dropout(x_attn)

        # Pre-norm feed-forward block with residual connection.
        residual = x
        x_norm = self.norm2(x)
        x_ff = self.linear2(self.dropout(self.activation(self.linear1(x_norm))))
        x = residual + self.dropout(x_ff)

        # New: Adapter branch (only train adapter layers)
        adapter_input = self.norm_adapter(x)
        adapter_out = self.adapter_up(self.adapter_down(adapter_input))
        x = x + adapter_out
        return x


class DakitariInstructModel(PreTrainedModel, GenerationMixin):  # Updated: Inherit from GenerationMixin
    config_class = DakitariInstructConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token and position embeddings.
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

        # Build transformer layers from n_layer and the embedding dimensions.
        self.layers = nn.ModuleList([
            CustomTransformerLayer(config) for _ in range(config.n_layer)
        ])

        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # New: LM head for generation
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.apply(self._init_weights)

        # Ensure adapter layers stay trainable (explicit, even though no parameters are frozen here).
        for name, param in self.named_parameters():
            if "adapter" in name:
                param.requires_grad = True

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, input_ids=None, attention_mask=None, position_ids=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        # Clamp token ids into the valid vocabulary range.
        input_ids = torch.clamp(input_ids, 0, self.config.vocab_size - 1)

        batch_size, seq_length = input_ids.shape

        # Default position ids: 0..seq_length-1 for every sequence in the batch.
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)

        # Token + position embeddings.
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = self.drop(inputs_embeds + position_embeds)

        # The attention layers expect a boolean mask (True = attend, False = padding).
        if attention_mask is not None:
            attention_mask = attention_mask.bool()

        # Process through the transformer layers.
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        # Compute a standard next-token prediction loss when labels are provided.
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        # NEW: Return a CausalLMOutput with a logits attribute so generate() works correctly.
        return CausalLMOutput(loss=loss, logits=logits)

    # NEW: Override generation input preparation.
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # Forward the attention mask as well so padded prompts are respected during generation.
        return {"input_ids": input_ids, "attention_mask": attention_mask}
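

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model definition). It assumes that
# DakitariInstructConfig accepts the fields referenced above (vocab_size,
# n_embd, n_positions, n_layer, adapter_bottleneck, the dropout probabilities,
# layer_norm_epsilon, initializer_range) as keyword arguments; the concrete
# values below are placeholders. Because of the relative import at the top,
# run this as a module inside its package, e.g.
#   python -m <package>.modeling_dakitari_instruct
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = DakitariInstructConfig(
        vocab_size=32000,
        n_embd=256,
        n_positions=512,
        n_layer=4,
        adapter_bottleneck=64,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
    )
    model = DakitariInstructModel(config)
    model.eval()

    # Dummy batch: 2 sequences of 16 random token ids, no padding.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    # Forward pass with labels to check the logits shape and the scalar loss.
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    print("logits:", tuple(out.logits.shape), "loss:", out.loss.item())

    # Greedy decoding through GenerationMixin.generate().
    generated = model.generate(
        input_ids, attention_mask=attention_mask, max_new_tokens=8, do_sample=False
    )
    print("generated:", tuple(generated.shape))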