# roberta_zinc_compression_head / modeling_compression.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List
from dataclasses import dataclass
from transformers import PreTrainedModel
from transformers.utils import ModelOutput
from .configuration_compression import CompressionConfig

def cosine_pairwise(embeddings):
    """Return the (n, n) matrix of pairwise cosine similarities for a batch of n embeddings."""
    return F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)

def cov(tensor, rowvar=True, bias=False):
    """Estimate a covariance matrix (mirrors np.cov)."""
    tensor = tensor if rowvar else tensor.transpose(-1, -2)
    tensor = tensor - tensor.mean(dim=-1, keepdim=True)
    # 1/(N-1) for the unbiased estimator (bias=False), 1/N otherwise
    factor = 1 / (tensor.shape[-1] - int(not bool(bias)))
    return factor * tensor @ tensor.transpose(-1, -2).conj()

def remove_diag(x):
    """Drop the diagonal from a square (n, n) matrix, returning shape (n, n - 1)."""
    n = x.shape[0]
    return x.masked_select(~torch.eye(n, dtype=torch.bool, device=x.device)).view(n, n - 1)

def corrcoef(tensor, rowvar=True):
    """Get Pearson product-moment correlation coefficients (mirrors np.corrcoef)."""
    covariance = cov(tensor, rowvar=rowvar)
    variance = covariance.diagonal(0, -1, -2)
    if variance.is_complex():
        variance = variance.real
    stddev = variance.sqrt()
    covariance /= stddev.unsqueeze(-1)
    covariance /= stddev.unsqueeze(-2)
    # clamp to [-1, 1] to absorb floating-point drift
    if covariance.is_complex():
        covariance.real.clip_(-1, 1)
        covariance.imag.clip_(-1, 1)
    else:
        covariance.clip_(-1, 1)
    return covariance

def compute_correlation(base_sims, compressed_sims, rm_diag=True):
    """Mean correlation distance, 1 - Pearson r, between matched rows of two similarity matrices."""
    if rm_diag:
        base_sims = remove_diag(base_sims)
        compressed_sims = remove_diag(compressed_sims)
    inputs = torch.stack([base_sims, compressed_sims], dim=1)
    return (1 - corrcoef(inputs)[:, 0, 1]).mean()
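# Sanity-check sketch (illustrative, not part of the original file): identical
# similarity matrices should give a correlation distance of zero.
#   sims = cosine_pairwise(torch.randn(8, 32))
#   compute_correlation(sims, sims)  # -> tensor(0.) up to numerical error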

def loss_function(base_sims, compressed_sims, k_vals):
    """Correlation-distance loss over all pairs, plus one term per top-k neighborhood in k_vals."""
    outputs = [compute_correlation(base_sims, compressed_sims)]
    if k_vals:
        # rank neighbors by base similarity, dropping each row's self-similarity
        base_ranks = base_sims.argsort(-1, descending=True)[:, 1:]
        for k in k_vals:
            base_sims_k = torch.gather(base_sims, 1, base_ranks[:, :k])
            compressed_sims_k = torch.gather(compressed_sims, 1, base_ranks[:, :k])
            outputs.append(compute_correlation(base_sims_k, compressed_sims_k, rm_diag=False))
    return torch.stack(outputs).unsqueeze(0)
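# Usage sketch (illustrative values, not from the original file): with
# k_vals=[5, 10] the result has shape (1, 3): one global term plus one
# term per top-k neighborhood.
#   base = cosine_pairwise(torch.randn(16, 64))
#   comp = cosine_pairwise(torch.randn(16, 8))
#   loss_function(base, comp, [5, 10]).shape  # -> torch.Size([1, 3])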

class FeedForward(nn.Module):
    """Gated (SwiGLU-style) feed-forward projection from d_in to d_out."""
    def __init__(self, d_in, d_out):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_out * 2)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x):
        x = self.fc1(x)
        # split into value and gate halves, then apply a SiLU gate
        x1, x2 = x.chunk(2, dim=-1)
        x = self.fc2(F.silu(x1) * x2)
        return x

class CompressionHead(nn.Module):
    """Projects an embedding from d_in to d_out via the gated feed-forward block plus a linear skip connection."""
    def __init__(self, d_in, d_out, dropout=0.1):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        self.skip = nn.Linear(d_in, d_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout(x)
        x = self.ff(x) + self.skip(x)
        return x

@dataclass
class CompressionModelOutput(ModelOutput):
    """Output container: total loss, per-head losses, the input embedding, and one compressed embedding per head."""
    loss: Optional[torch.FloatTensor] = None
    losses: Optional[List[torch.FloatTensor]] = None
    base_embedding: Optional[torch.FloatTensor] = None
    compressed_embeddings: Optional[List[torch.FloatTensor]] = None

class CompressionModel(PreTrainedModel):
    """Holds one CompressionHead per target size in config.compression_sizes."""
    config_class = CompressionConfig

    def __init__(self, config):
        super().__init__(config)
        self.heads = nn.ModuleList([CompressionHead(config.input_size, i, config.dropout)
                                    for i in config.compression_sizes])

    def forward(self, embedding, compute_loss=True, return_dict=True):
        outputs = []
        losses = None
        loss = None
        if compute_loss:
            losses = []
            emb_sims = cosine_pairwise(embedding)
        for head in self.heads:
            compressed_embedding = head(embedding)
            outputs.append(compressed_embedding)
            if compute_loss:
                comp_sims = cosine_pairwise(compressed_embedding)
                loss = loss_function(emb_sims, comp_sims, self.config.loss_k_vals)
                losses.append(loss)
        if compute_loss:
            # sum the per-head loss terms into a single scalar
            loss = torch.cat(losses).sum()
        if not return_dict:
            return (loss, losses, embedding, outputs)
return CompressionModelOutput(loss=loss,
losses=losses,
base_embedding=embedding,
compressed_embeddings=outputs)
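
# Minimal smoke-test sketch (not part of the original file). It assumes that
# CompressionConfig accepts these keyword arguments and exposes them as
# attributes of the same name, matching how they are read above
# (config.input_size, config.compression_sizes, config.dropout, config.loss_k_vals).
if __name__ == "__main__":
    config = CompressionConfig(input_size=512,
                               compression_sizes=[256, 128, 64, 32],
                               dropout=0.1,
                               loss_k_vals=[5, 10])
    model = CompressionModel(config)
    embedding = torch.randn(16, 512)  # batch of base encoder embeddings
    output = model(embedding)
    print(output.loss, [e.shape for e in output.compressed_embeddings])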