# modeling_roberta_zinc_compression_encoder.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from transformers import PreTrainedModel
from transformers.utils import ModelOutput
from .configuration_roberta_zinc_compression_encoder import RZCompressionConfig
# pairwise cosine ----------------------------------------------------------------
def pairwise_cosine(x: torch.Tensor) -> torch.Tensor:
    x = F.normalize(x, p=2, dim=-1)
    return x @ x.t()  # [B, B]
# remove diagonal -----------------------------------------------------------------
def drop_diag(M: torch.Tensor) -> torch.Tensor:
    n = M.size(0)
    return M.masked_select(~torch.eye(n, dtype=torch.bool, device=M.device)).view(n, n - 1)
# pearson row-wise ----------------------------------------------------------------
def rowwise_pearson(ref: torch.Tensor, comp: torch.Tensor, rm_diag: bool=True) -> torch.Tensor:
    if rm_diag:
        ref = drop_diag(ref)
        comp = drop_diag(comp)
    ref_z = F.normalize(ref - ref.mean(dim=1, keepdim=True), p=2, dim=1)
    cmp_z = F.normalize(comp - comp.mean(dim=1, keepdim=True), p=2, dim=1)
    return 1 - (ref_z * cmp_z).sum(dim=1).mean()  # 0 = perfect corr
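# Illustrative sketch (not part of the model code) of how the three helpers above compose;
# the batch size and dimension below are assumptions chosen for clarity:
#
#   x = torch.randn(4, 16)          # 4 embeddings of dim 16
#   sims = pairwise_cosine(x)       # (4, 4) cosine matrix, diagonal == 1
#   drop_diag(sims)                 # (4, 3): each row without its self-similarity
#   rowwise_pearson(sims, sims)     # ~0.0, i.e. identical rows correlate perfectly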
# aggregate loss ------------------------------------------------------------------
def compute_losses(
    embedding: torch.Tensor,              # (batch_size, d)
    compressed: Dict[int, torch.Tensor],  # Dict[size, (batch_size, size)]
    recon_stack: torch.Tensor | None,     # (batch_size, n_heads, d)
    cfg: RZCompressionConfig,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Return (total_loss, terms_dict); terms holds detached per-term tensors for logging."""
    device = embedding.device
    loss_total = torch.zeros((), device=device)
    terms: Dict[str, torch.Tensor] = {}
    # ---- base similarities (detach to save mem) ---------------------------
    with torch.no_grad():
        base_sims = pairwise_cosine(embedding)
        ranks = base_sims.argsort(-1, descending=True)
    # ======================================================================
    # 1) encoder / compressed losses
    # ======================================================================
    for size, z in compressed.items():
        tag = f"cmp{size}"
        comp_sims = pairwise_cosine(z)
        # plain MSE --------------------------------------------------------
        if cfg.mse_loss_weight:
            mse = F.mse_loss(drop_diag(base_sims), drop_diag(comp_sims))
            loss_total += cfg.mse_loss_weight * mse
            terms[f"{tag}_mse"] = mse.detach()
        # top-k MSE --------------------------------------------------------
        if cfg.topk_mse_loss_weight and cfg.topk_values:
            tk_vals = []
            for k in cfg.topk_values:
                idx = ranks[:, 1 : k + 1]
                ref_k = torch.gather(base_sims, 1, idx)
                cmp_k = torch.gather(comp_sims, 1, idx)
                tk_mse = F.mse_loss(ref_k, cmp_k)
                tk_vals.append(tk_mse)
                terms[f"{tag}_top{k}"] = tk_mse.detach()
            tk_agg = torch.stack(tk_vals).mean()
            loss_total += cfg.topk_mse_loss_weight * tk_agg
            terms[f"{tag}_topk_mean"] = tk_agg.detach()
        # Pearson ----------------------------------------------------------
        if cfg.pearson_loss_weight:
            pr = rowwise_pearson(base_sims, comp_sims)
            loss_total += cfg.pearson_loss_weight * pr
            terms[f"{tag}_pearson"] = pr.detach()
        if cfg.pearson_loss_weight and cfg.topk_values:
            pr_vals = []
            for k in cfg.topk_values:
                idx = ranks[:, 1 : k + 1]
                ref_k = torch.gather(base_sims, 1, idx)
                cmp_k = torch.gather(comp_sims, 1, idx)
                pr = rowwise_pearson(ref_k, cmp_k, rm_diag=False)
                pr_vals.append(pr)
                terms[f"{tag}_pearson_top{k}"] = pr.detach()
            pr_agg = torch.stack(pr_vals).sum()
            loss_total += cfg.pearson_loss_weight * pr_agg
    # ======================================================================
    # 2) decoder losses
    # ======================================================================
    if recon_stack is not None:
        # cosine -----------------------------------------------------------
        if cfg.decoder_cosine_weight:
            cos_loss = 1 - F.cosine_similarity(
                recon_stack,
                embedding.unsqueeze(1).expand_as(recon_stack),
                dim=-1,
            ).mean()
            loss_total += cfg.decoder_cosine_weight * cos_loss
            terms["dec_cosine"] = cos_loss.detach()
    return loss_total, terms
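# Sketch of a standalone call to compute_losses (illustrative only; in normal use it is
# invoked from RZCompressionModel.forward with self.config). The SimpleNamespace below
# only mirrors the config attributes this function reads; all values are assumptions:
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(
#       mse_loss_weight=1.0, topk_mse_loss_weight=1.0, topk_values=[5, 25],
#       pearson_loss_weight=1.0, decoder_cosine_weight=1.0,
#   )
#   emb = torch.randn(32, 768)
#   loss, terms = compute_losses(emb, {64: torch.randn(32, 64)}, None, cfg)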
# ─── basic blocks ───────────────────────────────────────────────
class FeedForward(nn.Module):
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_out * 2)
        self.fc2 = nn.Linear(d_out, d_out)
    def forward(self, x):
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=-1)
        return self.fc2(F.silu(x1) * x2)
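# Note: FeedForward is a SwiGLU-style gated block: fc1 expands d_in -> 2 * d_out, the
# result is chunked into a SiLU-gated half and a linear half, and fc2 maps d_out -> d_out.
# Minimal shape sketch (the dimensions below are illustrative assumptions, not model defaults):
#
#   ff = FeedForward(d_in=768, d_out=128)
#   ff(torch.randn(4, 768)).shape    # torch.Size([4, 128])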
class FeedForwardLayer(nn.Module):
    def __init__(
        self, d_in: int, d_out: int, dropout: float = 0.1, layer_norm_eps: Optional[float] = 1e-12
    ):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        self.skip = nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.norm = (
            nn.LayerNorm(d_out, eps=layer_norm_eps)
            if layer_norm_eps is not None else nn.Identity()
        )
    def forward(self, x):
        y = self.ff(self.dropout(x)) + self.skip(x)
        return self.norm(y)
# ─── pure PyTorch compressor ────────────────────────────────────
class CompressionModel(nn.Module):
"""
Encoder β†’ (optional) Decoder.
"""
def __init__(
self,
d_in: int,
d_comp: int,
encoder_layers: int,
decoder_layers: int,
dropout: float,
layer_norm_eps: Optional[float],
):
super().__init__()
enc_layers: List[nn.Module] = []
for i in range(encoder_layers):
last = i == encoder_layers - 1
enc_layers.append(
FeedForwardLayer(
d_in,
d_comp if last else d_in,
dropout if not last else 0.0,
None if last else layer_norm_eps,
)
)
self.encoder = nn.Sequential(*enc_layers)
# optional decoder
dec_layers: List[nn.Module] = []
for i in range(decoder_layers):
last = i == decoder_layers - 1
d_prev = d_comp if i==0 else d_in
dec_layers.append(
FeedForwardLayer(
d_prev,
d_in,
dropout if not last else 0.0,
None if last else layer_norm_eps,
)
)
self.decoder = nn.Sequential(*dec_layers) if dec_layers else None
def forward(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
z = self.encoder(x)
x_recon = self.decoder(z) if self.decoder is not None else None
return z, x_recon
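# Shape sketch for CompressionModel (the sizes below are illustrative assumptions):
#
#   m = CompressionModel(d_in=768, d_comp=64, encoder_layers=2, decoder_layers=2,
#                        dropout=0.1, layer_norm_eps=1e-12)
#   z, x_recon = m(torch.randn(8, 768))   # z: (8, 64), x_recon: (8, 768)
#
# With decoder_layers == 0 the decoder is None and x_recon is returned as None.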
# ─── HF wrapper ─────────────────────────────────────────────────
@dataclass
class RZCompressionOutput(ModelOutput):
    loss: torch.FloatTensor
    loss_terms: Dict[str, torch.Tensor] | None = None
    compressed: Dict[int, torch.FloatTensor] | None = None
    reconstructed: torch.FloatTensor | None = None
class RZCompressionModel(PreTrainedModel):
    config_class = RZCompressionConfig
    def __init__(self, config: RZCompressionConfig):
        super().__init__(config)
        self.compressors = nn.ModuleDict(
            {
                str(size): CompressionModel(
                    d_in=config.input_size,
                    d_comp=size,
                    encoder_layers=config.encoder_layers,
                    decoder_layers=config.decoder_layers,
                    dropout=config.dropout,
                    layer_norm_eps=config.layer_norm_eps,
                )
                for size in config.compression_sizes
            }
        )
        self.post_init()
    def get_encoders(self, unpack_single=False):
        encoders = {}
        for k, v in self.compressors.items():
            v = v.encoder
            if len(v) == 1 and unpack_single:
                # unpack from nn.Sequential if only a single layer
                v = v[0]
            encoders[k] = v
        encoders = nn.ModuleDict(encoders)
        return encoders
    def save_encoders(self, path, unpack_single=False):
        encoders = self.get_encoders(unpack_single)
        torch.save(encoders.state_dict(), path)
    def compress(self,
                 inputs: torch.Tensor,
                 compression_sizes: List[int]):
        compressed = {d: self.compressors[str(d)].encoder(inputs) for d in compression_sizes}
        return compressed
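    # Usage sketch for compress() (illustrative; assumes 512 and 64 appear in
    # config.compression_sizes and that inputs has width config.input_size == 768):
    #
    #   out = model.compress(torch.randn(8, 768), compression_sizes=[512, 64])
    #   out[512].shape, out[64].shape    # (8, 512), (8, 64)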
    def forward(self, embedding, return_dict=True, compute_loss=True):
        # ---------- forward passes ------------------------------------------------
        compressed, recons = {}, []
        for size, module in self.compressors.items():
            z, rec = module(embedding)
            compressed[int(size)] = z
            if rec is not None:
                recons.append(rec)
        recon_stack = torch.stack(recons, dim=1) if recons else None
        # ---------- losses --------------------------------------------------------
        if compute_loss:
            loss_total, terms = compute_losses(embedding, compressed, recon_stack, self.config)
        else:
            loss_total, terms = torch.zeros((), device=embedding.device), {}
        if not return_dict:
            return compressed, recon_stack, loss_total, terms
        return RZCompressionOutput(
            loss=loss_total,
            loss_terms=terms,
            compressed=compressed,
            reconstructed=recon_stack,
        )
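# Minimal end-to-end smoke test, kept under a __main__ guard so importing the module stays
# side-effect free. This is a sketch, not part of the released model: the config values
# below (input_size, compression_sizes, layer counts, loss weights) are assumptions chosen
# for illustration; real values come from the checkpoint's config.json. Because the config
# import above is relative, run this as part of the package (python -m ...), not as a script.
if __name__ == "__main__":
    cfg = RZCompressionConfig(
        input_size=768,
        compression_sizes=[32, 64],
        encoder_layers=2,
        decoder_layers=2,
        dropout=0.1,
        layer_norm_eps=1e-12,
        mse_loss_weight=1.0,
        topk_mse_loss_weight=1.0,
        topk_values=[5],
        pearson_loss_weight=1.0,
        decoder_cosine_weight=1.0,
    )
    model = RZCompressionModel(cfg)
    emb = torch.randn(16, cfg.input_size)
    out = model(emb)
    print(out.loss, {size: z.shape for size, z in out.compressed.items()})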