Commit b90946c by entropy (verified) · 1 parent: 2eab706

Upload model
README.md ADDED
@@ -0,0 +1,199 @@
---
library_name: transformers
tags: []
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->



## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]

## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]


#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary



## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
config.json ADDED
@@ -0,0 +1,46 @@
{
  "architectures": [
    "DecomposerModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_decomposer.DecomposerConfig",
    "AutoModel": "modeling_decomposer.DecomposerModel"
  },
  "comp_sizes": [
    768,
    512,
    256,
    128,
    64,
    32
  ],
  "corr_k_vals": [
    10,
    100
  ],
  "corr_loss_type": "pearson",
  "corr_weight": 1.0,
  "cosine_weight": 1.0,
  "dropout": 0.1,
  "input_size": 768,
  "layer_norm_eps": 1e-12,
  "model_type": "embedding_decomposer",
  "mse_weight": 0.0,
  "n_comp_layers": 4,
  "n_head_layers": 1,
  "n_output": 2,
  "n_refs_batch": 3072,
  "n_refs_total": 0,
  "n_shared_layers": 8,
  "output_sizes": [
    768,
    512,
    256,
    128,
    64,
    32
  ],
  "shared_dim": 1024,
  "torch_dtype": "float32",
  "transformers_version": "4.51.3"
}
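The `auto_map` entries above register the custom classes shipped in this repo, so the checkpoint is meant to be loaded through the Auto classes with remote code enabled. A minimal loading sketch; `"entropy/<this-repo>"` is a placeholder for the model's actual Hub repo id:

from transformers import AutoConfig, AutoModel

repo_id = "entropy/<this-repo>"  # placeholder; substitute the actual Hub repo id

# trust_remote_code is required because config.json maps AutoConfig/AutoModel
# to configuration_decomposer.DecomposerConfig / modeling_decomposer.DecomposerModel.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)  # DecomposerModel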
configuration_decomposer.py ADDED
@@ -0,0 +1,67 @@
from typing import List, Optional
from transformers import PretrainedConfig


class DecomposerConfig(PretrainedConfig):
    """
    Config for the embedding-decomposition model.

    Args:
        input_size (int): input embedding size.
        comp_sizes (List[int]): compressed embedding sizes.
        output_sizes (List[int]): desired output dims (for the two blocks).
        n_comp_layers (int): number of FeedForwardLayers per compression head.
        shared_dim (int): common hidden size after input projection.
        n_shared_layers (int): number of FeedForwardLayers in the shared trunk.
        n_head_layers (int): number of FeedForwardLayers in the output head.
        dropout (float): dropout prob in *every* non-final layer.
        layer_norm_eps (float|None): epsilon for LayerNorm (None → no LN).
        n_output (int): number of output embeddings.
        n_refs_batch (int): number of reference embeddings to sample per batch.
        n_refs_total (int): total number of reference embeddings; set to 0 to
            skip creating the reference-embedding tables.
        cosine_weight (float): weight of the one-to-one cosine-similarity loss.
        mse_weight (float): weight of the one-to-one MSE loss.
        corr_weight (float): weight of the pairwise-correlation loss.
        corr_loss_type (str): correlation loss type, "pearson" or "mse".
        corr_k_vals (List[int]): top-k values for weighting the correlation loss.
    """
    model_type = "embedding_decomposer"

    def __init__(
        self,
        input_size: int = 768,
        comp_sizes: List[int] = (768, 512, 256, 128, 64, 32),
        output_sizes: List[int] = (768, 512, 256, 128, 64, 32),
        n_comp_layers: int = 4,
        shared_dim: int = 1024,
        n_shared_layers: int = 8,
        n_head_layers: int = 1,
        dropout: float = 0.1,
        layer_norm_eps: Optional[float] = 1e-12,
        n_output: int = 2,
        n_refs_batch: int = 128,
        n_refs_total: int = 2000,
        cosine_weight: float = 1.0,
        mse_weight: float = 1.0,
        corr_weight: float = 1.0,
        corr_loss_type: str = "pearson",  # "pearson" or "mse"
        corr_k_vals: List[int] = (10, 100),  # tuple default avoids a shared mutable list
        **kwargs,
    ):
        self.input_size = input_size
        self.comp_sizes = list(comp_sizes)
        self.output_sizes = list(output_sizes)
        self.n_comp_layers = n_comp_layers
        self.shared_dim = shared_dim
        self.n_shared_layers = n_shared_layers
        self.n_head_layers = n_head_layers
        self.dropout = dropout
        self.layer_norm_eps = layer_norm_eps
        self.n_output = n_output
        self.n_refs_batch = n_refs_batch
        self.n_refs_total = n_refs_total
        self.cosine_weight = cosine_weight
        self.mse_weight = mse_weight
        self.corr_weight = corr_weight
        self.corr_loss_type = corr_loss_type
        self.corr_k_vals = list(corr_k_vals)
        super().__init__(**kwargs)
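For orientation, the committed config.json corresponds to constructing this class with just three non-default arguments. A sketch, not part of the commit, assuming the file is importable locally (on the Hub it is loaded for you via `trust_remote_code`):

from configuration_decomposer import DecomposerConfig

# Reproduces the committed config.json; all other fields keep their defaults.
config = DecomposerConfig(
    n_refs_batch=3072,
    n_refs_total=0,   # 0 disables the reference-embedding tables
    mse_weight=0.0,   # only the cosine and correlation losses remain active
)
assert config.comp_sizes == [768, 512, 256, 128, 64, 32]
print(config.model_type)  # "embedding_decomposer"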
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4fa2d69d2ea6349ff666b7418c088f3b0a394cc772d956b1771df3dca2a42e52
size 291771256
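The file above is a Git LFS pointer, not the weights themselves; the `oid` and `size` identify the real payload. A small sketch for checking a resolved download against the pointer, assuming the weights have been fetched locally as `model.safetensors`:

import hashlib

# Hash the downloaded file in chunks and compare against the LFS pointer.
h = hashlib.sha256()
size = 0
with open("model.safetensors", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
        size += len(chunk)
print(size)           # expect 291771256
print(h.hexdigest())  # expect 4fa2d69d2ea6...2a42e52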
modeling_decomposer.py ADDED
@@ -0,0 +1,388 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_decomposer import DecomposerConfig


def pairwise_cosine(x: torch.Tensor) -> torch.Tensor:
    """
    x : [B,d] or [N,B,d]
    returns a square similarity matrix:
        [B,B] or [N,B,B]
    """
    x = F.normalize(x, p=2, dim=-1)
    return torch.matmul(x, x.transpose(-1, -2))


def cross_cosine(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    a : [M,d] or [N,M,d]
    b : [K,d] (reference set - no extra axis)
    returns:
        [M,K] or [N,M,K]
    """
    a_n = F.normalize(a, 2, -1)
    b_n = F.normalize(b, 2, -1)

    if a.ndim == 2:                                   # [M,d]
        return a_n @ b_n.T                            # [M,K]

    if a.ndim == 3:                                   # [N,M,d]
        return torch.einsum("nmd,kd->nmk", a_n, b_n)  # [N,M,K]

    raise ValueError("cross_cosine: unexpected tensor rank.")


def _drop_diag(M: torch.Tensor) -> torch.Tensor:
    """
    Remove the main diagonal from each similarity matrix.
    Works for 2-D [B,B] or 3-D [N,B,B] tensors.
    """
    if M.ndim == 2:
        n = M.size(0)
        return M.masked_select(
            ~torch.eye(n, dtype=torch.bool, device=M.device)
        ).view(n, n - 1)

    if M.ndim == 3:
        n = M.size(1)
        mask = torch.eye(n, dtype=torch.bool, device=M.device).unsqueeze(0)  # [1,B,B]
        return M.masked_select(~mask).view(M.size(0), n, n - 1)

    raise ValueError("_drop_diag expects a 2-D or 3-D tensor")


def rowwise_pearson(ref: torch.Tensor,
                    pred: torch.Tensor,
                    *,
                    rm_diag: bool = True) -> torch.Tensor:
    """
    Pearson correlation row by row; supports 2-D or 3-D inputs of identical shape.
    Returns the mean correlation error (0 → perfect).
    """
    if rm_diag:
        ref = _drop_diag(ref)
        pred = _drop_diag(pred)

    ref_z = F.normalize(ref - ref.mean(-1, keepdim=True), p=2, dim=-1)
    pred_z = F.normalize(pred - pred.mean(-1, keepdim=True), p=2, dim=-1)
    loss = 1 - (ref_z * pred_z).sum(-1).mean(-1)
    if loss.ndim == 0:
        loss = loss.unsqueeze(0)
    return loss


def similarity_mse(ref: torch.Tensor,
                   pred: torch.Tensor,
                   *,
                   rm_diag: bool = True) -> torch.Tensor:
    if rm_diag:
        ref, pred = _drop_diag(ref), _drop_diag(pred)

    if pred.ndim == 2:
        loss = F.mse_loss(pred, ref).unsqueeze(0)
    elif pred.ndim == 3:
        loss = F.mse_loss(pred,
                          ref.expand_as(pred),
                          reduction="none"
                          ).reshape(pred.size(0), -1).mean(-1)
    else:
        raise ValueError("similarity_mse expects a 2-D or 3-D tensor")

    return loss


def sim_loss(pred: torch.Tensor,   # [N,B,d] or [B,d]
             targ: torch.Tensor,   # [B,d] (ground truth)
             ref: Optional[torch.Tensor],
             k_vals: Optional[List[int]],
             loss_type: str = "pearson") -> torch.Tensor:
    """
    Returns a stacked tensor of losses, one column per loss term
    (1 + len(k_vals) columns in total).
    If `ref` is given we compute cross-similarities pred↔ref / targ↔ref,
    otherwise self-similarities pred↔pred / targ↔targ.
    """
    loss_fn = rowwise_pearson if loss_type == "pearson" else similarity_mse

    if ref is None:  # self-similarity
        p_sim, t_sim = pairwise_cosine(pred), pairwise_cosine(targ)
        rm_diag = True
    else:            # cross-similarity vs. a fixed reference set
        p_sim, t_sim = cross_cosine(pred, ref), cross_cosine(targ, ref)
        rm_diag = False

    losses = [loss_fn(t_sim, p_sim, rm_diag=rm_diag)]

    if k_vals:
        # ranks based on target similarities (works for 2-D or 3-D)
        ranks = t_sim.argsort(-1, descending=True)
        start = 1 if rm_diag else 0
        for k in k_vals:
            idx = ranks[..., start:start + k]
            t_k = torch.gather(t_sim, -1, idx)
            if p_sim.ndim == 2:
                p_k = torch.gather(p_sim, -1, idx)
            elif p_sim.ndim == 3:
                p_k = torch.gather(p_sim, -1, idx.repeat(p_sim.size(0), 1, 1))
            losses.append(loss_fn(t_k, p_k, rm_diag=False))

    return torch.stack(losses, 1)  # [N, 1 + len(k_vals)]


# ─────────────────────────────── building blocks ──────────────────────────────
class FeedForward(nn.Module):
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_out * 2)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style gating: split the projection, gate one half with SiLU
        x1, x2 = self.fc1(x).chunk(2, -1)
        return self.fc2(F.silu(x1) * x2)


class FeedForwardLayer(nn.Module):
    def __init__(self,
                 d_in: int,
                 d_out: int,
                 *,
                 dropout: float = 0.1,
                 ln_eps: Optional[float] = 1e-12):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        self.skip = nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()
        self.drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_out, eps=ln_eps) if ln_eps else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.ff(self.drop(x)) + self.skip(x))


class OutputLinear(nn.Module):
    def __init__(self,
                 input_size: int,
                 n_head_layers: int,
                 n_output: int,
                 output_sizes: List[int],
                 dropout: float = 0.1,
                 ln_eps: Optional[float] = 1e-12):
        super().__init__()
        self.n_output = n_output
        ff_layers = [FeedForwardLayer(input_size, input_size, dropout=dropout,
                                      ln_eps=None if i == n_head_layers - 1 else ln_eps)
                     for i in range(n_head_layers)]
        self.ff = nn.Sequential(*ff_layers)
        self.layers = nn.ModuleDict({str(d): nn.Linear(input_size, d * n_output)
                                     for d in output_sizes})

    def forward(self, inputs: torch.Tensor, sizes: List[int]):
        inputs = self.ff(inputs)
        # one fused linear over all requested sizes, then slice per size
        weights = torch.cat([self.layers[str(i)].weight for i in sizes])
        biases = torch.cat([self.layers[str(i)].bias for i in sizes])
        outputs = F.linear(inputs, weights, biases)
        output_dict = {}
        current = 0

        slice_sizes = [d * self.n_output for d in sizes]
        for size in slice_sizes:
            p = outputs[:, :, current:current + size]
            p = p.view(p.size(0), p.size(1), self.n_output, size // self.n_output)
            output_dict[size // self.n_output] = p
            current += size
        return output_dict


def get_compression_heads(d_in, comp_sizes, n_layers, add_input_identity=False):
    compression_heads = nn.ModuleDict({})
    for d in comp_sizes:
        enc_layers = []
        for i in range(n_layers):
            last = i == n_layers - 1
            enc_layers.append(
                FeedForwardLayer(
                    d_in,
                    d if last else d_in,
                    dropout=0.0,
                    ln_eps=None if last else 1e-12,
                )
            )
        compression_heads[str(d)] = nn.Sequential(*enc_layers)
    if add_input_identity:
        compression_heads[str(d_in)] = nn.Identity()

    return compression_heads


# ───────────────────────────── output dataclass ───────────────────────────────
@dataclass
class DecomposerOutput(ModelOutput):
    loss: torch.FloatTensor
    loss_terms: Optional[Dict[str, torch.Tensor]] = None
    decomp: Optional[Dict[int, torch.FloatTensor]] = None  # {size: [B, n_sizes, n_output, size]}
    ref_idxs: Optional[torch.LongTensor] = None


# ──────────────────────────────── main model ──────────────────────────────────
class DecomposerModel(PreTrainedModel):
    """Maps an embedding to *n_output* building-block embeddings for every
    requested `output_size`. All loops are left intact for clarity."""
    config_class = DecomposerConfig

    # ---------------------------------------------------------------- init
    def __init__(self, config: DecomposerConfig):
        super().__init__(config)

        # compression heads, so training does not need every embedding size saved
        self.compression_heads = get_compression_heads(config.input_size,
                                                       config.comp_sizes,
                                                       config.n_comp_layers,
                                                       add_input_identity=True)
        # input → shared_dim
        self.in_proj = nn.ModuleDict({
            str(d): FeedForwardLayer(d, config.shared_dim,
                                     dropout=config.dropout,
                                     ln_eps=config.layer_norm_eps)
            for d in config.comp_sizes
        })

        # shared trunk
        blk = lambda: FeedForwardLayer(config.shared_dim,
                                       config.shared_dim,
                                       dropout=config.dropout,
                                       ln_eps=config.layer_norm_eps)
        self.trunk = nn.Sequential(*[blk() for _ in range(config.n_shared_layers)])

        # shared_dim → each output size × n_output
        self.out_proj = OutputLinear(self.config.shared_dim,
                                     self.config.n_head_layers,
                                     config.n_output,
                                     config.output_sizes,
                                     config.dropout,
                                     config.layer_norm_eps)

        # reference embeddings (optional corr loss)
        self.ref_emb = nn.ModuleDict({
            str(d): nn.Embedding(config.n_refs_total, d)
            for d in config.output_sizes if config.n_refs_total
        })

        self.post_init()

    # ---------------------------------------------------------------- forward
    def compress(self,
                 inputs: torch.Tensor,  # [B, input_size]
                 comp_sizes: List[int]):
        compressed = {d: self.compression_heads[str(d)](inputs) for d in comp_sizes}
        return compressed

    def decompose(self,
                  inputs: Dict[int, torch.Tensor],  # {size: [B, size]}
                  output_sizes: List[int]):
        hiddens = []
        for input_size in self.config.comp_sizes:
            if input_size not in inputs:
                continue

            h = self.in_proj[str(input_size)](inputs[input_size])  # [B, shared_dim]
            hiddens.append(h)

        hiddens = torch.stack(hiddens, dim=0)  # [n_sizes, B, shared_dim]
        hiddens = self.trunk(hiddens)

        preds = self.out_proj(hiddens, output_sizes)  # {size: [n_sizes, B, n_output, size]}
        return preds

    def load_targets(self,
                     bb1_ids: torch.LongTensor,   # [B,]
                     bb2_ids: torch.LongTensor):  # [B,]
        targets = {}
        for size in self.config.output_sizes:
            embedding = self.ref_emb[str(size)]
            targets[size] = torch.stack([embedding(bb1_ids), embedding(bb2_ids)], dim=1)
        return targets

    def compute_loss(self,
                     inputs: Dict[int, torch.Tensor],
                     preds: Dict[int, torch.Tensor],
                     targets: Dict[int, torch.Tensor],
                     ref_idxs: Optional[torch.LongTensor] = None):
        device = next(iter(preds.values())).device
        loss_terms: Dict[str, torch.Tensor] = {}
        loss_total = torch.zeros((), device=device)
        cfg = self.config
        for out_size in cfg.output_sizes:
            p = preds[out_size]    # [n_sizes, B, n_out, d]
            t = targets[out_size]  # [B, n_out, d]

            # 1) cosine to target ------------------------------------
            if cfg.cosine_weight > 0:
                cos = 1 - F.cosine_similarity(p, t, dim=-1).view(p.size(0), -1).mean(-1)
                loss_total += cfg.cosine_weight * cos.sum()
                for i, in_size in enumerate(cfg.comp_sizes):
                    loss_terms[f"{in_size}->{out_size}_cos"] = cos[i]

            # 2) mse to target ---------------------------------------
            if cfg.mse_weight > 0:
                mse = F.mse_loss(p, t.expand_as(p), reduction="none").view(p.size(0), -1).mean(-1)
                loss_total += cfg.mse_weight * mse.sum()
                for i, in_size in enumerate(cfg.comp_sizes):
                    loss_terms[f"{in_size}->{out_size}_mse"] = mse[i]

            # 3) correlation losses ----------------------------------
            if cfg.corr_weight:
                flat_p = p.flatten(1, 2)
                flat_t = t.flatten(0, 1)

                with torch.no_grad():
                    ref = self.ref_emb[str(out_size)](ref_idxs)

                ref_corr = sim_loss(flat_p, flat_t, ref,
                                    cfg.corr_k_vals, cfg.corr_loss_type).mean(-1)
                loss_total += cfg.corr_weight * ref_corr.sum()
                for i, in_size in enumerate(cfg.comp_sizes):
                    loss_terms[f"{in_size}->{out_size}_corr_ref"] = ref_corr[i]

        return loss_total, loss_terms

    def forward(self,
                embedding: torch.Tensor,   # [B, size]
                bb1_id: torch.LongTensor,  # [B,]
                bb2_id: torch.LongTensor,  # [B,]
                *,
                ref_idxs: Optional[torch.LongTensor] = None,
                return_preds: bool = False,
                compute_loss: bool = True,
                return_dict: bool = True) -> DecomposerOutput:

        cfg = self.config
        device = embedding.device

        if cfg.corr_weight and cfg.n_refs_total and ref_idxs is None:
            ref_idxs = torch.randint(cfg.n_refs_total,
                                     (cfg.n_refs_batch,),
                                     device=device)

        with torch.no_grad():
            compressed_inputs = self.compress(embedding, cfg.comp_sizes)

        if cfg.input_size in cfg.comp_sizes:
            compressed_inputs[cfg.input_size] = embedding

        preds = self.decompose(compressed_inputs, cfg.output_sizes)

        loss_total = None
        loss_terms = {}
        if compute_loss:
            # targets are only needed for the loss (and only exist when
            # n_refs_total > 0), so they are loaded here, not unconditionally
            targets = self.load_targets(bb1_id, bb2_id)
            loss_total, loss_terms = self.compute_loss(compressed_inputs, preds, targets, ref_idxs)

        decomp = {k: v.permute(1, 0, 2, 3) for k, v in preds.items()}

        return DecomposerOutput(loss=loss_total,
                                loss_terms=loss_terms,
                                decomp=decomp,
                                ref_idxs=ref_idxs)
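To close, a forward-pass sketch under stated assumptions: the repo id is a placeholder, the input embeddings are random stand-ins, and `compute_loss=False` because the shipped config.json sets `n_refs_total` to 0, so no reference or target embedding tables exist for loss computation:

import torch
from transformers import AutoModel

repo_id = "entropy/<this-repo>"  # placeholder for the actual Hub repo id
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()
cfg = model.config  # input_size=768, n_output=2, six comp/output sizes

B = 4
embedding = torch.randn(B, cfg.input_size)  # stand-in for real input embeddings
bb1_id = torch.zeros(B, dtype=torch.long)   # target ids; unused when compute_loss=False
bb2_id = torch.zeros(B, dtype=torch.long)

with torch.no_grad():
    out = model(embedding, bb1_id, bb2_id, compute_loss=False)

# decomp maps each output size d to a [B, n_sizes, n_output, d] tensor:
# n_output building-block embeddings per input, one set per compressed input size.
for d, t in out.decomp.items():
    print(d, tuple(t.shape))  # e.g. 768 (4, 6, 2, 768)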