# Diff_LoRA: diff_lora/model.py
import re
import math

import torch
import torch.nn as nn


class DiffLoRALinear(nn.Module):
"""
Fused DiffLoRALinear implements a differential low-rank adapter:
Δy = (α/r) * [A_pos @ B_pos - τ * (A_neg @ B_neg)]
The fused version computes:
update = x_dropped @ concat(A_pos, A_neg) @ concat(B_pos, -τ * B_neg)
This version explicitly moves τ to the same device as the input.
"""
    def __init__(self, in_features: int, out_features: int, r: int = 8,
                 lora_alpha: float = 16.0, dropout: float = 0.0,
                 merge_weights: bool = False, init_method: str = "kaiming",
                 bias: bool = False):
        super().__init__()
        # Base linear layer (frozen); keep a bias only if the wrapped layer has one
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.linear.weight.requires_grad = False
        if self.linear.bias is not None:
            self.linear.bias.requires_grad = False
self.in_features = in_features
self.out_features = out_features
self.r = r
self.scaling = lora_alpha / r
self.merge_weights = merge_weights
self.merged = False
self.lora_dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
# Low-rank parameters for positive and negative components
self.A_pos = nn.Parameter(torch.zeros(in_features, r))
self.B_pos = nn.Parameter(torch.zeros(r, out_features))
self.A_neg = nn.Parameter(torch.zeros(in_features, r))
self.B_neg = nn.Parameter(torch.zeros(r, out_features))
        self.tau = nn.Parameter(torch.tensor(1.0))  # learnable scalar weighting the negative branch
self.reset_parameters(init_method)

    def reset_parameters(self, init_method: str = "kaiming"):
        # B matrices start at zero so the adapter contributes nothing at initialization.
        if init_method == "kaiming":
            nn.init.kaiming_uniform_(self.A_pos, a=math.sqrt(5))
            nn.init.zeros_(self.B_pos)
            nn.init.kaiming_uniform_(self.A_neg, a=math.sqrt(5))
            nn.init.zeros_(self.B_neg)
elif init_method == "xavier":
nn.init.xavier_uniform_(self.A_pos)
nn.init.zeros_(self.B_pos)
nn.init.xavier_uniform_(self.A_neg)
nn.init.zeros_(self.B_neg)
else:
raise ValueError(f"Unknown init_method: {init_method}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.merge_weights and self.merged:
return self.linear(x)
base_out = self.linear(x)
x_dropped = self.lora_dropout(x)
# Ensure tau is on the same device as x
tau = self.tau.to(x.device)
# Concatenate positive and negative parameters along the rank dimension
combined_A = torch.cat([self.A_pos, self.A_neg], dim=1) # (in_features, 2*r)
combined_B = torch.cat([self.B_pos, -tau * self.B_neg], dim=0) # (2*r, out_features)
update = x_dropped @ combined_A @ combined_B
delta = self.scaling * update
return base_out + delta
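
    # NOTE: the original file defines merge_weights / self.merged but no routine
    # that actually merges the adapter. The two methods below are a minimal
    # sketch (not part of the original) of how the low-rank update could be
    # folded into the frozen weight, following the fused form used in forward():
    # ΔW_eff = scaling * (A_pos @ B_pos - τ * A_neg @ B_neg), applied transposed
    # because nn.Linear stores weight as (out_features, in_features). Intended
    # for use with merge_weights=True so forward() skips the low-rank path.
    @torch.no_grad()
    def merge(self) -> None:
        if self.merged:
            return
        delta_w = self.scaling * (
            self.A_pos @ self.B_pos - self.tau * (self.A_neg @ self.B_neg)
        )  # (in_features, out_features)
        self.linear.weight.data += delta_w.t()
        self.merged = True

    @torch.no_grad()
    def unmerge(self) -> None:
        if not self.merged:
            return
        delta_w = self.scaling * (
            self.A_pos @ self.B_pos - self.tau * (self.A_neg @ self.B_neg)
        )
        self.linear.weight.data -= delta_w.t()
        self.merged = False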


def replace_linear_with_diff_lora(module: nn.Module, target_regex: str, r: int) -> None:
    """
    Recursively replace nn.Linear submodules whose attribute names match
    target_regex with DiffLoRALinear modules of rank r, copying the pretrained
    weight (and bias, if present) into the frozen base layer. The module is
    modified in place.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and re.search(target_regex, name, re.IGNORECASE):
            new_layer = DiffLoRALinear(
                in_features=child.in_features,
                out_features=child.out_features,
                r=r,
                lora_alpha=16.0,
                dropout=0.1,
                merge_weights=False,
                bias=child.bias is not None,
            )
            # Copy the pretrained weights (and bias) into the frozen base layer.
            new_layer.linear.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                new_layer.linear.bias.data.copy_(child.bias.data)
            setattr(module, name, new_layer)
        else:
            replace_linear_with_diff_lora(child, target_regex, r)
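

# Usage sketch (illustrative only, not part of the original file). The toy model
# and the regex "q_proj" are made-up stand-ins; in a real transformer the regex
# would target projection layer names such as "q_proj" or "v_proj".
if __name__ == "__main__":
    toy = nn.Sequential()
    toy.add_module("q_proj", nn.Linear(32, 32, bias=True))
    toy.add_module("mlp", nn.Linear(32, 16, bias=False))

    replace_linear_with_diff_lora(toy, target_regex=r"q_proj", r=4)
    print(toy)  # "q_proj" is now a DiffLoRALinear; "mlp" is left untouched

    x = torch.randn(2, 32)
    y = toy(x)
    print(y.shape)  # -> torch.Size([2, 16])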