import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_roberta_zinc_compression_encoder import RZCompressionConfig


def pairwise_cosine(x: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine-similarity matrix between all rows of `x`."""
    x = F.normalize(x, p=2, dim=-1)
    return x @ x.t()


def drop_diag(M: torch.Tensor) -> torch.Tensor:
    """Remove the diagonal of a square matrix, returning shape (n, n - 1)."""
    n = M.size(0)
    return M.masked_select(~torch.eye(n, dtype=torch.bool, device=M.device)).view(n, n - 1)


def rowwise_pearson(ref: torch.Tensor, comp: torch.Tensor, rm_diag: bool = True) -> torch.Tensor:
    """Return 1 minus the mean row-wise Pearson correlation between `ref` and `comp`
    (a loss that is 0 when every row is perfectly correlated)."""
    if rm_diag:
        ref = drop_diag(ref)
        comp = drop_diag(comp)
    ref_z = F.normalize(ref - ref.mean(dim=1, keepdim=True), p=2, dim=1)
    cmp_z = F.normalize(comp - comp.mean(dim=1, keepdim=True), p=2, dim=1)
    return 1 - (ref_z * cmp_z).sum(dim=1).mean()


def compute_losses(
    embedding: torch.Tensor,
    compressed: Dict[int, torch.Tensor],
    recon_stack: Optional[torch.Tensor],
    cfg,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Return (total_loss, terms_dict), where terms_dict holds detached per-term values."""
    device = embedding.device
    loss_total = torch.zeros((), device=device)
    terms: Dict[str, torch.Tensor] = {}

    # Reference similarities and neighbor ranks from the uncompressed embeddings.
    with torch.no_grad():
        base_sims = pairwise_cosine(embedding)
        ranks = base_sims.argsort(-1, descending=True)

    for size, z in compressed.items():
        tag = f"cmp{size}"
        comp_sims = pairwise_cosine(z)

        # MSE between the full (off-diagonal) similarity matrices.
        if cfg.mse_loss_weight:
            mse = F.mse_loss(drop_diag(base_sims), drop_diag(comp_sims))
            loss_total += cfg.mse_loss_weight * mse
            terms[f"{tag}_mse"] = mse.detach()

        # Top-k MSE over each row's k nearest reference neighbors.
        if cfg.topk_mse_loss_weight and cfg.topk_values:
            tk_vals = []
            for k in cfg.topk_values:
                idx = ranks[:, 1 : k + 1]  # skip rank 0 (self-similarity)
                ref_k = torch.gather(base_sims, 1, idx)
                cmp_k = torch.gather(comp_sims, 1, idx)
                tk_mse = F.mse_loss(ref_k, cmp_k)
                tk_vals.append(tk_mse)
                terms[f"{tag}_top{k}"] = tk_mse.detach()
            tk_agg = torch.stack(tk_vals).mean()
            loss_total += cfg.topk_mse_loss_weight * tk_agg
            terms[f"{tag}_topk_mean"] = tk_agg.detach()

        # Row-wise Pearson correlation between the two similarity matrices.
        if cfg.pearson_loss_weight:
            pr = rowwise_pearson(base_sims, comp_sims)
            loss_total += cfg.pearson_loss_weight * pr
            terms[f"{tag}_pearson"] = pr.detach()

        # Row-wise Pearson restricted to each row's top-k reference neighbors.
        if cfg.pearson_loss_weight and cfg.topk_values:
            pr_vals = []
            for k in cfg.topk_values:
                idx = ranks[:, 1 : k + 1]
                ref_k = torch.gather(base_sims, 1, idx)
                cmp_k = torch.gather(comp_sims, 1, idx)
                pr_k = rowwise_pearson(ref_k, cmp_k, rm_diag=False)
                pr_vals.append(pr_k)
                terms[f"{tag}_pearson_top{k}"] = pr_k.detach()
            pr_agg = torch.stack(pr_vals).sum()
            loss_total += cfg.pearson_loss_weight * pr_agg

    # Decoder reconstruction term: cosine distance between each reconstruction and the input.
    if recon_stack is not None:
        if cfg.decoder_cosine_weight:
            cos_loss = 1 - F.cosine_similarity(
                recon_stack,
                embedding.unsqueeze(1).expand_as(recon_stack),
                dim=-1,
            ).mean()
            loss_total += cfg.decoder_cosine_weight * cos_loss
            terms["dec_cosine"] = cos_loss.detach()

    return loss_total, terms


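# Illustrative sketch (not part of the module): `compute_losses` only needs a config
# exposing the fields referenced above (mse_loss_weight, topk_mse_loss_weight,
# pearson_loss_weight, topk_values, decoder_cosine_weight). A hypothetical call:
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(
#       mse_loss_weight=1.0, topk_mse_loss_weight=1.0, pearson_loss_weight=1.0,
#       topk_values=[5, 10], decoder_cosine_weight=1.0,
#   )
#   emb = torch.randn(32, 768)                # reference embeddings
#   comp = {64: torch.randn(32, 64)}          # compressed embeddings keyed by size
#   loss, terms = compute_losses(emb, comp, None, cfg)

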
class FeedForward(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: fc1 output is split into a SiLU-gated half and a value half."""

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_out * 2)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x):
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=-1)
        return self.fc2(F.silu(x1) * x2)


class FeedForwardLayer(nn.Module):
    """FeedForward block with input dropout, a linear (or identity) skip connection, and optional LayerNorm."""

    def __init__(
        self, d_in: int, d_out: int, dropout: float = 0.1, layer_norm_eps: Optional[float] = 1e-12
    ):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        self.skip = nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.norm = (
            nn.LayerNorm(d_out, eps=layer_norm_eps)
            if layer_norm_eps is not None else nn.Identity()
        )

    def forward(self, x):
        y = self.ff(self.dropout(x)) + self.skip(x)
        return self.norm(y)


class CompressionModel(nn.Module):
    """Encoder -> (optional) decoder: compresses an embedding to `d_comp` dimensions and,
    if decoder layers are configured, reconstructs the original embedding from the code."""

    def __init__(
        self,
        d_in: int,
        d_comp: int,
        encoder_layers: int,
        decoder_layers: int,
        dropout: float,
        layer_norm_eps: Optional[float],
    ):
        super().__init__()

        # Encoder: keep width d_in until the last layer, which projects to d_comp
        # (no dropout or LayerNorm on the final projection).
        enc_layers: List[nn.Module] = []
        for i in range(encoder_layers):
            last = i == encoder_layers - 1
            enc_layers.append(
                FeedForwardLayer(
                    d_in,
                    d_comp if last else d_in,
                    dropout if not last else 0.0,
                    None if last else layer_norm_eps,
                )
            )
        self.encoder = nn.Sequential(*enc_layers)

        # Decoder (optional): the first layer maps d_comp back to d_in, the rest stay at d_in.
        dec_layers: List[nn.Module] = []
        for i in range(decoder_layers):
            last = i == decoder_layers - 1
            d_prev = d_comp if i == 0 else d_in
            dec_layers.append(
                FeedForwardLayer(
                    d_prev,
                    d_in,
                    dropout if not last else 0.0,
                    None if last else layer_norm_eps,
                )
            )
        self.decoder = nn.Sequential(*dec_layers) if dec_layers else None

    def forward(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        z = self.encoder(x)
        x_recon = self.decoder(z) if self.decoder is not None else None
        return z, x_recon


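# Shape sketch (illustrative, with assumed sizes): for d_in=768 and d_comp=64,
#
#   m = CompressionModel(d_in=768, d_comp=64, encoder_layers=2,
#                        decoder_layers=2, dropout=0.1, layer_norm_eps=1e-12)
#   z, x_recon = m(torch.randn(8, 768))   # z: (8, 64), x_recon: (8, 768)

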
@dataclass
class RZCompressionOutput(ModelOutput):
    loss: torch.FloatTensor
    loss_terms: Optional[Dict[str, torch.Tensor]] = None
    compressed: Optional[Dict[int, torch.FloatTensor]] = None
    reconstructed: Optional[torch.FloatTensor] = None


class RZCompressionModel(PreTrainedModel):
    config_class = RZCompressionConfig

    def __init__(self, config: RZCompressionConfig):
        super().__init__(config)

        # One compression head (encoder + optional decoder) per requested output size.
        self.compressors = nn.ModuleDict(
            {
                str(size): CompressionModel(
                    d_in=config.input_size,
                    d_comp=size,
                    encoder_layers=config.encoder_layers,
                    decoder_layers=config.decoder_layers,
                    dropout=config.dropout,
                    layer_norm_eps=config.layer_norm_eps,
                )
                for size in config.compression_sizes
            }
        )

        self.post_init()

    def get_encoders(self, unpack_single=False):
        """Return the encoder of every compression head as an nn.ModuleDict.

        If `unpack_single` is True, single-layer encoders are unwrapped from their
        nn.Sequential container.
        """
        encoders = {}
        for k, v in self.compressors.items():
            v = v.encoder
            if len(v) == 1 and unpack_single:
                v = v[0]
            encoders[k] = v
        return nn.ModuleDict(encoders)

    def save_encoders(self, path, unpack_single=False):
        """Save only the encoder weights (decoders are not included)."""
        encoders = self.get_encoders(unpack_single)
        torch.save(encoders.state_dict(), path)

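    # Note (illustrative): because `save_encoders` stores only encoder weights, reloading
    # them requires an identically structured ModuleDict first, e.g. with a hypothetical
    # "encoders.pt" file:
    #
    #   fresh = RZCompressionModel(config)
    #   encoders = fresh.get_encoders()
    #   encoders.load_state_dict(torch.load("encoders.pt"))
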
    def compress(
        self,
        inputs: torch.Tensor,
        compression_sizes: List[int],
    ):
        """Encode `inputs` with the selected compression sizes (no decoding, no loss)."""
        return {d: self.compressors[str(d)].encoder(inputs) for d in compression_sizes}

    def forward(self, embedding, return_dict=True, compute_loss=True):
        # Run every compression head; collect compressed codes and any reconstructions.
        compressed, recons = {}, []
        for size, module in self.compressors.items():
            z, rec = module(embedding)
            compressed[int(size)] = z
            if rec is not None:
                recons.append(rec)
        recon_stack = torch.stack(recons, dim=1) if recons else None

        if compute_loss:
            loss_total, terms = compute_losses(embedding, compressed, recon_stack, self.config)
        else:
            loss_total, terms = torch.zeros((), device=embedding.device), {}

        if not return_dict:
            return compressed, recon_stack, loss_total, terms

        return RZCompressionOutput(
            loss=loss_total,
            loss_terms=terms,
            compressed=compressed,
            reconstructed=recon_stack,
        )
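

# Example usage (illustrative sketch; assumes an RZCompressionConfig constructed with,
# e.g., input_size=768 and compression_sizes=[32, 64], plus default loss weights):
#
#   config = RZCompressionConfig(input_size=768, compression_sizes=[32, 64])
#   model = RZCompressionModel(config)
#   embeddings = torch.randn(16, 768)
#   out = model(embeddings)                   # RZCompressionOutput with loss + loss_terms
#   out.compressed[64].shape                  # -> torch.Size([16, 64])
#   small = model.compress(embeddings, [32])  # encoder-only pass, no loss computed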