nozomuteruyo14 committed on
Commit 9b5d8a8 · verified · 1 Parent(s): 243299b

Create model.py

Files changed (1)
diff_lora/model.py +83 -0
diff_lora/model.py ADDED
@@ -0,0 +1,83 @@
+ import re
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class DiffLoRALinear(nn.Module):
+     """
+     Fused DiffLoRALinear implements a differential low-rank adapter:
+         Δy = (α/r) * x @ [A_pos @ B_pos - τ * (A_neg @ B_neg)]
+     The fused version computes:
+         update = x_dropped @ concat(A_pos, A_neg) @ concat(B_pos, -τ * B_neg)
+     This version explicitly moves τ to the same device as the input.
+     """
+     def __init__(self, in_features: int, out_features: int, r: int = 8,
+                  lora_alpha: float = 16.0, dropout: float = 0.0,
+                  merge_weights: bool = False, init_method: str = "kaiming"):
+         super().__init__()
+         # Base linear layer (frozen)
+         self.linear = nn.Linear(in_features, out_features, bias=False)
+         self.linear.weight.requires_grad = False
+         self.in_features = in_features
+         self.out_features = out_features
+         self.r = r
+         self.scaling = lora_alpha / r
+         self.merge_weights = merge_weights
+         self.merged = False
+         self.lora_dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
+         # Low-rank parameters for positive and negative components
+         self.A_pos = nn.Parameter(torch.zeros(in_features, r))
+         self.B_pos = nn.Parameter(torch.zeros(r, out_features))
+         self.A_neg = nn.Parameter(torch.zeros(in_features, r))
+         self.B_neg = nn.Parameter(torch.zeros(r, out_features))
+         self.tau = nn.Parameter(torch.tensor(1.0))  # Scalar parameter
+         self.reset_parameters(init_method)
+
+     def reset_parameters(self, init_method: str = "kaiming"):
+         if init_method == "kaiming":
+             nn.init.kaiming_uniform_(self.A_pos, a=math.sqrt(5))
+             nn.init.zeros_(self.B_pos)
+             nn.init.kaiming_uniform_(self.A_neg, a=math.sqrt(5))
+             nn.init.zeros_(self.B_neg)
+         elif init_method == "xavier":
+             nn.init.xavier_uniform_(self.A_pos)
+             nn.init.zeros_(self.B_pos)
+             nn.init.xavier_uniform_(self.A_neg)
+             nn.init.zeros_(self.B_neg)
+         else:
+             raise ValueError(f"Unknown init_method: {init_method}")
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.merge_weights and self.merged:
+             return self.linear(x)
+         base_out = self.linear(x)
+         x_dropped = self.lora_dropout(x)
+         # Ensure tau is on the same device as x
+         tau = self.tau.to(x.device)
+         # Concatenate positive and negative parameters along the rank dimension
+         combined_A = torch.cat([self.A_pos, self.A_neg], dim=1)  # (in_features, 2*r)
+         combined_B = torch.cat([self.B_pos, -tau * self.B_neg], dim=0)  # (2*r, out_features)
+         update = x_dropped @ combined_A @ combined_B
+         delta = self.scaling * update
+         return base_out + delta
+
+ def replace_linear_with_diff_lora(module: nn.Module, target_regex: str, r: int):
+     """
+     Recursively replace nn.Linear modules whose names match target_regex
+     with DiffLoRALinear modules using rank r.
+     """
+     for name, child in module.named_children():
+         if isinstance(child, nn.Linear) and re.search(target_regex, name, re.IGNORECASE):
+             new_layer = DiffLoRALinear(
+                 in_features=child.in_features,
+                 out_features=child.out_features,
+                 r=r,
+                 lora_alpha=16.0,
+                 dropout=0.1,
+                 merge_weights=False,
+             )
+             new_layer.linear.weight.data.copy_(child.weight.data)
+             setattr(module, name, new_layer)
+         else:
+             replace_linear_with_diff_lora(child, target_regex, r)
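
As a sanity check on the fused formulation in the docstring, the sketch below compares the fused forward pass against the unfused form Δy = (α/r) * (x @ A_pos @ B_pos - τ * x @ A_neg @ B_neg). This is not part of the commit; the layer sizes, the manual seed, and the random re-initialization of B_pos / B_neg (which the module initializes to zero) are assumptions made only so the update is non-trivial.

import torch
import torch.nn as nn
from diff_lora.model import DiffLoRALinear

torch.manual_seed(0)
layer = DiffLoRALinear(in_features=16, out_features=8, r=4, dropout=0.0)
# B_pos / B_neg start at zero, so give them random values for a non-trivial check.
nn.init.normal_(layer.B_pos)
nn.init.normal_(layer.B_neg)

x = torch.randn(2, 16)
fused = layer(x)  # base output + fused low-rank update

# Unfused form: base output + (alpha/r) * (x A_pos B_pos - tau * x A_neg B_neg)
unfused = layer.linear(x) + layer.scaling * (
    x @ layer.A_pos @ layer.B_pos - layer.tau * (x @ layer.A_neg @ layer.B_neg)
)
print(torch.allclose(fused, unfused, atol=1e-5))  # expected: True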
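
And a minimal usage sketch for replace_linear_with_diff_lora, again not part of the commit: the toy nn.Sequential, the layer names ("fc1", "fc2"), the sizes, and the target pattern "fc" are illustrative assumptions. Note that the adapter's base layer is created with bias=False and only the weight is copied over, so the toy layers are kept bias-free here; a model should also be moved to its target device after replacement, since the freshly created adapters live on the default device.

import torch
import torch.nn as nn
from diff_lora.model import replace_linear_with_diff_lora

# Toy model; names and sizes are illustrative. bias=False matches the adapter's
# bias-free base layer (replace_linear_with_diff_lora copies only the weight).
model = nn.Sequential()
model.add_module("fc1", nn.Linear(32, 64, bias=False))
model.add_module("fc2", nn.Linear(64, 10, bias=False))

# Wrap every nn.Linear whose name matches "fc" with a rank-4 DiffLoRA adapter.
replace_linear_with_diff_lora(model, target_regex="fc", r=4)

# Only the adapter parameters (A_pos, B_pos, A_neg, B_neg, tau) remain trainable.
print([name for name, p in model.named_parameters() if p.requires_grad])

x = torch.randn(2, 32)
print(model(x).shape)  # torch.Size([2, 10])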