import math
import re

import torch
import torch.nn as nn
import torch.nn.functional as F


class DiffLoRALinear(nn.Module):
    """
    Fused DiffLoRALinear implements a differential low-rank adapter:

        Δy = (α/r) * x @ [A_pos @ B_pos - τ * (A_neg @ B_neg)]

    The fused version folds the subtraction into the concatenated B factor and
    computes the same update with a single pair of matmuls:

        update = x_dropped @ concat(A_pos, A_neg) @ concat(B_pos, -τ * B_neg)

    The learnable scalar τ is explicitly moved to the same device as the
    input in forward().
    """

    def __init__(self, in_features: int, out_features: int, r: int = 8,
                 lora_alpha: float = 16.0, dropout: float = 0.0,
                 merge_weights: bool = False, init_method: str = "kaiming"):
        super().__init__()
        if r <= 0:
            raise ValueError(f"r must be a positive integer, got {r}")

        # Frozen base projection: only the LoRA factors and τ are trained.
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.linear.weight.requires_grad = False
        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.scaling = lora_alpha / r
        self.merge_weights = merge_weights
        self.merged = False
        self.lora_dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

        # Positive and negative low-rank factors plus the learnable mixing scalar τ.
        self.A_pos = nn.Parameter(torch.zeros(in_features, r))
        self.B_pos = nn.Parameter(torch.zeros(r, out_features))
        self.A_neg = nn.Parameter(torch.zeros(in_features, r))
        self.B_neg = nn.Parameter(torch.zeros(r, out_features))
        self.tau = nn.Parameter(torch.tensor(1.0))
        self.reset_parameters(init_method)

    def reset_parameters(self, init_method: str = "kaiming"):
        # B factors start at zero so the adapter contributes nothing at initialization.
        if init_method == "kaiming":
            nn.init.kaiming_uniform_(self.A_pos, a=math.sqrt(5))
            nn.init.zeros_(self.B_pos)
            nn.init.kaiming_uniform_(self.A_neg, a=math.sqrt(5))
            nn.init.zeros_(self.B_neg)
        elif init_method == "xavier":
            nn.init.xavier_uniform_(self.A_pos)
            nn.init.zeros_(self.B_pos)
            nn.init.xavier_uniform_(self.A_neg)
            nn.init.zeros_(self.B_neg)
        else:
            raise ValueError(f"Unknown init_method: {init_method}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # If the adapter has already been folded into the frozen weight,
        # the base projection alone is the full output.
        if self.merge_weights and self.merged:
            return self.linear(x)

        base_out = self.linear(x)
        x_dropped = self.lora_dropout(x)

        # Keep τ on the same device as the activations.
        tau = self.tau.to(x.device)

        # Fused update: concatenating the factors turns the two low-rank
        # branches into a single pair of matmuls, with the subtraction
        # folded into -τ * B_neg.
        combined_A = torch.cat([self.A_pos, self.A_neg], dim=1)
        combined_B = torch.cat([self.B_pos, -tau * self.B_neg], dim=0)
        update = x_dropped @ combined_A @ combined_B
        delta = self.scaling * update
        return base_out + delta
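

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the adapter itself): a minimal
# sanity check, assuming only the DiffLoRALinear class defined above. It
# verifies that the adapter is a no-op at initialization (B_pos and B_neg are
# zero-initialized) and that the fused concat-matmul equals the unfused form
# (α/r) * (x @ A_pos @ B_pos - τ * (x @ A_neg @ B_neg)). The helper name
# _sanity_check_fused_update is made up for this example.
# ---------------------------------------------------------------------------
def _sanity_check_fused_update() -> None:
    torch.manual_seed(0)
    layer = DiffLoRALinear(in_features=32, out_features=16, r=4, dropout=0.0)
    x = torch.randn(8, 32)

    # Zero-initialized B factors: the output equals the frozen base projection.
    assert torch.allclose(layer(x), layer.linear(x))

    # Give the B factors non-trivial values and compare fused vs. unfused.
    with torch.no_grad():
        layer.B_pos.normal_()
        layer.B_neg.normal_()
    unfused = layer.linear(x) + layer.scaling * (
        x @ layer.A_pos @ layer.B_pos - layer.tau * (x @ layer.A_neg @ layer.B_neg)
    )
    assert torch.allclose(layer(x), unfused, atol=1e-5)
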
def replace_linear_with_diff_lora(module: nn.Module, target_regex: str, r: int) -> None:
    """
    Recursively replace nn.Linear modules whose names match target_regex
    with DiffLoRALinear modules of rank r, copying over the frozen weight.

    Note that DiffLoRALinear uses a bias-free base projection, so any bias
    on a replaced nn.Linear is dropped.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and re.search(target_regex, name, re.IGNORECASE):
            new_layer = DiffLoRALinear(
                in_features=child.in_features,
                out_features=child.out_features,
                r=r,
                lora_alpha=16.0,
                dropout=0.1,
                merge_weights=False,
            )
            # Match the original layer's device and dtype, then copy the
            # pretrained weight into the frozen base projection.
            new_layer = new_layer.to(device=child.weight.device, dtype=child.weight.dtype)
            new_layer.linear.weight.data.copy_(child.weight.data)
            setattr(module, name, new_layer)
        else:
            replace_linear_with_diff_lora(child, target_regex, r)
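

# ---------------------------------------------------------------------------
# Example (illustrative only): a minimal sketch of how the replacement helper
# might be used on a toy model. The module names ("q_proj", "out_proj") and
# the regex are invented for this example and do not refer to any particular
# architecture.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    _sanity_check_fused_update()

    toy = nn.Sequential()
    toy.add_module("q_proj", nn.Linear(32, 32, bias=False))
    toy.add_module("out_proj", nn.Linear(32, 10, bias=False))

    # Only the module whose name matches the (case-insensitive) pattern is wrapped.
    replace_linear_with_diff_lora(toy, target_regex=r"q_proj", r=4)

    trainable = sum(p.numel() for p in toy.parameters() if p.requires_grad)
    total = sum(p.numel() for p in toy.parameters())
    print(toy)
    print(f"trainable parameters: {trainable} / {total}")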