import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from transformers import PreTrainedModel
from transformers.utils import ModelOutput
from .configuration_erbb1_mlp import Erbb1MlpConfig
# ─── building blocks ────────────────────────────────────────────
class FeedForward(nn.Module):
    """Gated feed-forward block: one wide projection split in two, SiLU-gated."""

    def __init__(self, d_in, d_out):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_out * 2)  # single projection, later split into two halves
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x):
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=-1)  # SiLU-activate one half and use it to gate the other
        return self.fc2(F.silu(x1) * x2)
class FeedForwardLayer(nn.Module):
    """FeedForward plus dropout, a residual path (projected if shapes differ) and LayerNorm."""

    def __init__(self, d_in, d_out, dropout=0.1, layer_norm_eps=1e-12):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        self.skip = nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_out, eps=layer_norm_eps) if layer_norm_eps else nn.Identity()

    def forward(self, x):
        y = self.ff(self.dropout(x)) + self.skip(x)
        return self.norm(y)
# ─── HF Model wrapper ───────────────────────────────────────────
@dataclass
class Erbb1MlpOutput(ModelOutput):
    loss: torch.FloatTensor | None = None
    prediction: torch.FloatTensor | None = None       # denormalised
    prediction_norm: torch.FloatTensor | None = None  # normalised
class Erbb1MlpModel(PreTrainedModel):
    config_class = Erbb1MlpConfig

    def __init__(self, config: Erbb1MlpConfig):
        super().__init__(config)
        # input layer (no dropout on the raw embedding), then hidden-to-hidden layers
        layers = [FeedForwardLayer(config.d_in, config.d_hidden, 0.0, config.layer_norm_eps)]
        layers += [
            FeedForwardLayer(config.d_hidden, config.d_hidden, config.dropout, config.layer_norm_eps)
            for _ in range(config.n_layers - 1)
        ]
        self.body = nn.Sequential(*layers)
        self.out_proj = nn.Linear(config.d_hidden, 1)
        # stats for de-normalising predictions (buffers, so they travel with the state dict)
        mean = torch.tensor(config.dataset_mean or 0.0, dtype=torch.float32)
        std = torch.tensor(config.dataset_std or 1.0, dtype=torch.float32)
        self.register_buffer("target_mean", mean, persistent=True)
        self.register_buffer("target_std", std, persistent=True)
        self.post_init()
    def forward(self, embedding, labels=None, return_dict=True):
        x = self.body(embedding)
        pred_norm = self.out_proj(x).squeeze(-1)
        pred = pred_norm * self.target_std + self.target_mean  # undo target normalisation
        loss = None
        if labels is not None:
            loss = F.mse_loss(pred_norm, labels)  # loss is computed in normalised space
        if not return_dict:
            return (loss, pred, pred_norm)
        return Erbb1MlpOutput(
            loss=loss,
            prediction=pred,
            prediction_norm=pred_norm,
        )
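# ─── usage sketch ───────────────────────────────────────────────
# A minimal smoke test, illustrative only. It assumes Erbb1MlpConfig
# accepts the fields read above (d_in, d_hidden, n_layers, dropout,
# layer_norm_eps, dataset_mean, dataset_std); the concrete values here
# are made up. Because of the relative import at the top, run this as a
# module from the package root (python -m <package>.modeling_erbb1_mlp),
# not as a standalone script.
if __name__ == "__main__":
    config = Erbb1MlpConfig(
        d_in=1280,           # assumed embedding width
        d_hidden=512,
        n_layers=3,
        dropout=0.1,
        layer_norm_eps=1e-12,
        dataset_mean=6.5,    # illustrative normalisation stats
        dataset_std=1.2,
    )
    model = Erbb1MlpModel(config).eval()
    emb = torch.randn(4, config.d_in)     # batch of 4 input embeddings
    with torch.no_grad():
        out = model(emb, labels=torch.randn(4))
    print(out.loss, out.prediction.shape)  # scalar MSE, torch.Size([4])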