import math
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from transformers.generation import GenerationMixin  # NEW: Import GenerationMixin
from .configuration_dakitari_instruct import DakitariInstructConfig

class SimpleAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.in_proj = nn.Linear(config.n_embd, config.n_embd)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)
        
    def forward(self, x, attention_mask=None):
        B, L, D = x.size()  # batch, length, embedding dimension
        # A single projection is shared for queries, keys and values (deliberately simple)
        q = k = v = self.in_proj(x)

        # Scaled dot-product attention scores: (B, L, L)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D)

        if attention_mask is not None:
            # Broadcast the padding mask over the query dimension: (B, L) -> (B, 1, L)
            attention_mask = attention_mask.view(B, 1, L).bool()
            scores = scores.masked_fill(~attention_mask, float('-inf'))

        # Note: no causal mask is applied here; masking covers padding positions only
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)
        return self.out_proj(context)

class CustomTransformerLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = SimpleAttention(config)
        self.linear1 = nn.Linear(config.n_embd, config.n_embd)  # Keep dimensions consistent
        self.linear2 = nn.Linear(config.n_embd, config.n_embd)  # Keep dimensions consistent
        self.norm1 = nn.LayerNorm(config.n_embd)
        self.norm2 = nn.LayerNorm(config.n_embd)
        self.dropout = nn.Dropout(config.resid_pdrop)
        self.activation = nn.GELU()
        # New: Adapter layers for domain-specific finetuning
        self.adapter_down = nn.Linear(config.n_embd, config.adapter_bottleneck)
        self.adapter_up = nn.Linear(config.adapter_bottleneck, config.n_embd)
        self.norm_adapter = nn.LayerNorm(config.n_embd)

    def forward(self, x, attention_mask=None):
        residual = x
        x_norm = self.norm1(x)
        x_attn = self.attention(x_norm, attention_mask)
        x = residual + self.dropout(x_attn)
        
        residual = x
        x_norm = self.norm2(x)
        x_ff = self.linear2(self.dropout(self.activation(self.linear1(x_norm))))
        x = residual + self.dropout(x_ff)
        
        # Adapter branch: bottleneck residual intended to be the only trainable
        # part during domain-specific fine-tuning
        adapter_input = self.norm_adapter(x)
        adapter_out = self.adapter_up(self.adapter_down(adapter_input))
        x = x + adapter_out

        return x

class DakitariInstructModel(PreTrainedModel, GenerationMixin):  # Updated: Inherit from GenerationMixin
    config_class = DakitariInstructConfig
    
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        
        # Update embeddings with new dimensions
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        
        # Build transformer layers based on new n_layer value and dimensions
        self.layers = nn.ModuleList([
            CustomTransformerLayer(config)
            for _ in range(config.n_layer)
        ])
        
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        # New: LM head for generation
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        self.apply(self._init_weights)
        # Keep adapter parameters trainable; parameters require grad by default,
        # so this loop only makes the intent explicit for later fine-tuning setups
        for name, param in self.named_parameters():
            if "adapter" in name:
                param.requires_grad = True
        
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        # Clamp token ids defensively so out-of-range ids cannot break the embedding lookup
        input_ids = torch.clamp(input_ids, 0, self.config.vocab_size - 1)

        batch_size, seq_length = input_ids.shape

        # Default position ids: 0 .. seq_length - 1 for every sequence in the batch
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)

        # Token + position embeddings
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = self.drop(inputs_embeds + position_embeds)

        # The attention layers expect a boolean padding mask
        if attention_mask is not None:
            attention_mask = attention_mask.bool()

        # Run the transformer stack
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        # Standard causal LM loss: positions < t predict the token at position t
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Return a CausalLMOutput so generate() can read .logits
        return CausalLMOutput(loss=loss, logits=logits)
    
    # Minimal generation hook: pass the ids (and padding mask, if provided) on to forward
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        return {"input_ids": input_ids, "attention_mask": attention_mask}