File size: 3,006 Bytes
ed221d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_erbb1_mlp import Erbb1MlpConfig


# ─── building blocks ────────────────────────────────────────────
class FeedForward(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_out * 2)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x):
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=-1)
        return self.fc2(F.silu(x1) * x2)


class FeedForwardLayer(nn.Module):
    def __init__(self, d_in, d_out, dropout=0.1, layer_norm_eps=1e-12):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        self.skip = nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_out, eps=layer_norm_eps) if layer_norm_eps else nn.Identity() 

    def forward(self, x):
        y = self.ff(self.dropout(x)) + self.skip(x)
        return self.norm(y)


# ─── HF Model wrapper ───────────────────────────────────────────
@dataclass
class Erbb1MlpOutput(ModelOutput):
    loss: torch.FloatTensor | None = None
    prediction: torch.FloatTensor = None         # denormalised
    prediction_norm: torch.FloatTensor = None    # normalised


class Erbb1MlpModel(PreTrainedModel):
    config_class = Erbb1MlpConfig

    def __init__(self, config: Erbb1MlpConfig):
        super().__init__(config)

        layers = [FeedForwardLayer(config.d_in, config.d_hidden, 0.0, config.layer_norm_eps)]
        layers += [
            FeedForwardLayer(config.d_hidden, config.d_hidden, config.dropout, config.layer_norm_eps)
            for _ in range(config.n_layers - 1)
        ]
        self.body = nn.Sequential(*layers)
        self.out_proj = nn.Linear(config.d_hidden, 1)

        # stats for de-normalising (stored in state dict)
        mean = torch.tensor(config.dataset_mean or 0.0, dtype=torch.float32)
        std = torch.tensor(config.dataset_std or 1.0, dtype=torch.float32)
        self.register_buffer("target_mean", mean, persistent=True)
        self.register_buffer("target_std", std, persistent=True)

        self.post_init()

    def forward(self, embedding, labels=None, return_dict=True):
        x = self.body(embedding)
        pred_norm = self.out_proj(x).squeeze(-1)
        pred = pred_norm * self.target_std + self.target_mean

        loss = None
        if labels is not None:
            loss = F.mse_loss(pred_norm, labels)

        if not return_dict:
            return (loss, pred, pred_norm)

        return Erbb1MlpOutput(
            loss=loss,
            prediction=pred,
            prediction_norm=pred_norm,
        )