hsiangyualex committed on
Commit f97a499 · verified · 1 Parent(s): af10d58

Upload 64 files

This view is limited to 50 files because it contains too many changes. See raw diff
app.py ADDED
@@ -0,0 +1,156 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '3'
+import gradio as gr
+import yaml
+import torch
+import torch.nn.functional as F
+import numpy as np
+import pandas as pd
+from models.modules.networks import PromptAttentionUNet, HighResEnhancer
+from models.modules.biomedclip import BiomedCLIPTextEncoder
+from monai.inferers import sliding_window_inference
+from markers import breast_markers, prostatic_markers, pancreatic_markers
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# load the conversion (stage II) model
+# load the cfg file for conversion
+cfg_file = 'configs/{}.cfg'.format('convertion')
+with open(cfg_file, 'r') as f:
+    cfg = yaml.safe_load(f)
+print("successfully loaded config file: ", cfg)
+
+# conversion models
+convertion_ckpt = './checkpoint/stage_ii.pkl'
+convertion_net = PromptAttentionUNet(in_channels=cfg['MODEL']['IMC_IN'], out_channels=cfg['MODEL']['IMC_OUT'], channels=(128, 256, 512, 1024, 2048))
+prompt_model = BiomedCLIPTextEncoder(device=device)
+
+# load state_dict
+state_dict = torch.load(convertion_ckpt, map_location='cpu')['generator']
+# remove all the 'module.' prefix
+state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
+convertion_net.load_state_dict(state_dict)
+
+# load the translation (stage I) model
+cfg_file = 'configs/{}.cfg'.format('translation')
+with open(cfg_file, 'r') as f:
+    cfg = yaml.safe_load(f)
+print("successfully loaded config file: ", cfg)
+translation_ckpt = './checkpoint/stage_i.pkl'
+
+imc_net = HighResEnhancer(model_name=cfg['MODEL']['TIMM_MODEL'],
+                          in_channels=cfg['MODEL']['IMC_IN'],
+                          out_channels=cfg['MODEL']['IMC_OUT'],
+                          norm=cfg['MODEL']['NORM'],
+                          use_dilated_bottleneck=True)
+
+# load state_dict for IMC
+state_dict = torch.load(translation_ckpt, map_location='cpu')['imc_G']
+# remove all the 'module.0.' prefix
+state_dict = {k.replace('module.0.', ''): v for k, v in state_dict.items()}
+# remove the key "sobel.filter" in the state_dict
+state_dict.pop('sobel.filter.weight')
+imc_net.load_state_dict(state_dict, strict=False)
+
+convertion_net.eval().to(device)
+imc_net.eval().to(device)
+
+# load the metadata for demo data
+df = pd.read_csv('./test_data/test_metadata.csv')
+breast_df = df[df['source'] == 'BreastCancer_V2']
+prostatic_df = df[df['source'] == 'ProstaticCancer_V2']
+pancreatic_df = df[df['source'] == 'PancreaticCancer_V2']
+
+
+def load_image(pair_index):
+    # select the item from the dataframe and convert to Series using `squeeze()`
+    item = df[df['name'] == pair_index].squeeze()
+    data = np.load(item['path'])['arr_0']
+    x1 = data[:, :, 0]
+    x2 = data[:, :, 1]
+    return gr.Image(value=x1), gr.Image(value=x2)
+
+
+def generate_imc(x1, x2, marker_name):
+    # stage I
+    inputs = np.concatenate([x1, x2[:, :, 2:3]], axis=2)
+    # normalize to [0, 1]
+    inputs = inputs / 255.0
+    # to tensor
+    inputs = torch.from_numpy(inputs.transpose(2, 0, 1)).unsqueeze(0).float()
+    # rescale to [-1, 1]
+    inputs = 2 * inputs - 1
+    output = sliding_window_inference(inputs.to(device), roi_size=(320, 320), sw_batch_size=2, predictor=imc_net, overlap=0.5)
+    output = F.tanh(output)
+    # to numpy
+    pred_nuclei = output[0].detach().cpu().numpy().transpose(1, 2, 0)
+    pred_nuclei = (pred_nuclei + 1) / 2  # normalize to [0, 1]
+    # stage II
+    nuclei_inputs = torch.from_numpy(pred_nuclei).permute(2, 0, 1).unsqueeze(0).float()
+    # rescale to [-1, 1]
+    nuclei_inputs = 2 * nuclei_inputs - 1
+    prompt_in = torch.as_tensor(prompt_model([marker_name])).to(device)
+    output = F.tanh(convertion_net(nuclei_inputs.to(device), prompt_in))
+    marker = output[0].detach().cpu().numpy().transpose(1, 2, 0)
+    marker = (marker + 1) / 2  # normalize to [0, 1]
+    # visualization
+    vis = np.concatenate([marker, np.zeros_like(pred_nuclei, dtype=np.float32), pred_nuclei], axis=2)
+    # normalize to [0, 255] and convert to uint8
+    vis = (vis * 255).astype(np.uint8)
+    return gr.Image(value=vis)
+
+
+# Function to update the second dropdown based on the first dropdown's selection
+def update_dropdown_by_tissue(selected_category):
+    if selected_category == "Breast":
+        image_selector = gr.Dropdown(choices=breast_df['name'].values.tolist(), value=breast_df['name'].values[0], interactive=True)
+        marker_selector = gr.Dropdown(choices=breast_markers, value=breast_markers[0], interactive=True)
+    elif selected_category == "Pancreatic":
+        image_selector = gr.Dropdown(choices=pancreatic_df['name'].values.tolist(), value=pancreatic_df['name'].values[0], interactive=True)
+        marker_selector = gr.Dropdown(choices=pancreatic_markers, value=pancreatic_markers[0], interactive=True)
+    elif selected_category == "Prostatic":
+        image_selector = gr.Dropdown(choices=prostatic_df['name'].values.tolist(), value=prostatic_df['name'].values[0], interactive=True)
+        marker_selector = gr.Dropdown(choices=prostatic_markers, value=prostatic_markers[0], interactive=True)
+    return [image_selector, marker_selector]
+
+
+# Create the Gradio interface
+def create_gradio_ui():
+    with gr.Blocks() as demo:
+        with gr.Tab("Mbi2Spi"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    with gr.Row():
+                        # image visualizer
+                        brightfield = gr.Image(label="Brightfield Image", type="numpy", interactive=False)
+                        aux = gr.Image(type="numpy", visible=False, interactive=False)
+
+                    with gr.Row():
+                        with gr.Column():
+                            # tissue selector (Breast, Pancreatic, Prostatic)
+                            tissue_selector = gr.Dropdown(choices=["Breast", "Pancreatic", "Prostatic"], label="Select Tissue Type")
+                            # marker selector
+                            marker_selector = gr.Dropdown(label="Marker Selector", interactive=False)
+
+                        with gr.Column():
+                            # image selector
+                            image_selector = gr.Dropdown(label="Brightfield Selector", interactive=False)
+                            # update the image selector based on the tissue type
+                            tissue_selector.change(update_dropdown_by_tissue, inputs=tissue_selector, outputs=[image_selector, marker_selector])
+
+                with gr.Column(scale=1):
+                    output_image = gr.Image(label="Generated Image", type="numpy")
+                    button1 = gr.Button("Predict IMC")
+
+            # Load the selected image and update the brightfield and auxiliary images
+            image_selector.change(load_image, inputs=image_selector, outputs=[brightfield, aux])
+
+            # Event handler for button click
+            button1.click(generate_imc, inputs=[brightfield, aux, marker_selector], outputs=output_image)
+
+    return demo
+
+# Launch the demo
+if __name__ == '__main__':
+    demo = create_gradio_ui()
+    demo.launch(show_error=True)
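For reference, a minimal standalone sketch of the preprocessing convention that generate_imc applies before sliding-window inference (illustrative only; dummy arrays stand in for the demo images):

    import numpy as np
    import torch

    x1 = np.zeros((1024, 1024, 3), dtype=np.uint8)   # brightfield image, HWC, values in [0, 255]
    x2 = np.zeros((1024, 1024, 3), dtype=np.uint8)   # auxiliary image; only channel 2 is used
    inputs = np.concatenate([x1, x2[:, :, 2:3]], axis=2) / 255.0               # (H, W, 4), scaled to [0, 1]
    inputs = torch.from_numpy(inputs.transpose(2, 0, 1)).unsqueeze(0).float()  # (1, 4, H, W)
    inputs = 2 * inputs - 1                                                    # rescaled to [-1, 1], matching the tanh outputs
    print(inputs.shape)                                                        # torch.Size([1, 4, 1024, 1024])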
base/__init__.py ADDED
File without changes
base/base_modules.py ADDED
@@ -0,0 +1,301 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+norm_dict = {'BATCH': nn.BatchNorm2d, 'INSTANCE': nn.InstanceNorm2d, 'GROUP': nn.GroupNorm}
+NUM_GROUPS = 16
+__all__ = ['ConvNorm', 'ConvBlock', 'ConvBottleNeck', 'ResBlock', 'ResBottleneck', 'PromptResBlock', 'PromptResBottleneck', 'PromptAttentionModule', 'norm_dict', 'SobelEdge']
+
+
+class Identity(nn.Module):
+    """
+    Identity mapping for building a residual connection
+    """
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x
+
+
+class ConvNorm(nn.Module):
+    """
+    Convolution and normalization
+    """
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, leaky=True, norm='INSTANCE', activation=True):
+        super().__init__()
+        # determine basic attributes
+        self.norm_type = norm
+        padding = (kernel_size - 1) // 2
+
+        # activation, support LeakyReLU and common ReLU
+        if activation:
+            self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
+            # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
+        else:
+            self.act = None
+
+        # instantiate layers
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
+
+        if norm in ['BATCH', 'INSTANCE']:
+            norm_layer = norm_dict[norm]
+            self.norm = norm_layer(out_channels)
+        elif norm == 'GROUP':
+            norm_layer = norm_dict[norm]
+            self.norm = norm_layer(NUM_GROUPS, in_channels)
+        elif norm == 'NONE':
+            self.norm = nn.Identity()
+        else:
+            raise NotImplementedError(f'Normalization type {norm} not implemented')
+
+    def basic_forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        if self.act:
+            x = self.act(x)
+        return x
+
+    def group_forward(self, x):
+        x = self.norm(x)
+        if self.act:
+            x = self.act(x)
+        x = self.conv(x)
+        return x
+
+    def forward(self, x):
+        if self.norm_type in ['BATCH', 'INSTANCE']:
+            return self.basic_forward(x)
+        else:
+            return self.group_forward(x)
+
+
+class PromptAttentionModule(nn.Module):
+    def __init__(self, in_channels: int, prompt_channels: int, mid_channels: int) -> None:
+        super().__init__()
+        self.gap = nn.AdaptiveAvgPool2d(1)
+        self.conv_down = nn.Linear(in_channels, mid_channels)
+        self.prompt_down = nn.Linear(prompt_channels, mid_channels)
+        self.fc = nn.Linear(2 * mid_channels, in_channels)
+
+    def forward(self, x: torch.Tensor, prompt_in: torch.Tensor):
+        """
+        Args:
+            x: (B, C_im, H, W)
+            prompt_in: (B, C_prompt)
+        """
+        x_gap = self.gap(x).squeeze(-1).squeeze(-1)  # (B, C_im)
+        x_gap = self.conv_down(x_gap)  # (B, C_mid)
+        prompt_down = self.prompt_down(prompt_in)  # (B, C_mid)
+        gating = torch.cat([x_gap, prompt_down], dim=-1)  # (B, 2 * C_mid)
+        gating = F.sigmoid(self.fc(F.relu(gating)))[..., None, None]  # (B, C_im, 1, 1)
+        return x * gating
+
+
+class ConvBlock(nn.Module):
+    """
+    Convolutional blocks
+    """
+    def __init__(self, in_channels, out_channels, stride=1, leaky=False, norm='INSTANCE'):
+        super().__init__()
+        self.norm_type = norm
+        # activation, support LeakyReLU and common ReLU
+        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
+        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
+
+        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
+        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.conv2(out)
+
+        if self.norm_type != 'GROUP':
+            out = self.act(out)
+
+        return out
+
+
+class ResBlock(nn.Module):
+    """
+    Residual blocks
+    """
+    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
+        super().__init__()
+        self.norm_type = norm
+        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
+        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
+        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
+
+        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
+        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)
+
+        need_map = in_channels != out_channels or stride != 1
+        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()
+
+    def forward(self, x):
+        identity = x
+        out = self.conv1(x)
+        out = self.conv2(out)
+        identity = self.id(identity)
+
+        out = out + identity
+        if self.norm_type != 'GROUP':
+            out = self.act(out)
+
+        if self.dropout:
+            out = self.dropout(out)
+
+        return out
+
+
+class ConvBottleNeck(nn.Module):
+    """
+    Convolutional bottleneck blocks
+    """
+    def __init__(self, in_channels, out_channels, stride=1, leaky=False, norm='INSTANCE'):
+        super().__init__()
+        self.norm_type = norm
+        middle_channels = in_channels // 4
+        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
+        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
+
+        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
+        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
+        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = self.conv3(out)
+
+        if self.norm_type != 'GROUP':
+            out = self.act(out)
+
+        return out
+
+
+class ResBottleneck(nn.Module):
+    """
+    Residual bottleneck blocks
+    """
+    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
+        super().__init__()
+        self.norm_type = norm
+        middle_channels = in_channels // 4
+        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
+        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
+        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
+
+        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
+        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
+        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)
+
+        need_map = in_channels != out_channels or stride != 1
+        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()
+
+    def forward(self, x):
+        identity = x
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = self.conv3(out)
+        identity = self.id(identity)
+
+        out = out + identity
+        if self.norm_type != 'GROUP':
+            out = self.act(out)
+
+        if self.dropout:
+            out = self.dropout(out)
+
+        return out
+
+
+class PromptResBlock(nn.Module):
+    """
+    Prompt-conditioned residual blocks
+    """
+    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
+        super().__init__()
+        self.norm_type = norm
+        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
+        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
+        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
+
+        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
+        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)
+        self.attn = PromptAttentionModule(out_channels, 512, out_channels // 4)
+
+        need_map = in_channels != out_channels or stride != 1
+        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()
+
+    def forward(self, x, prompt_in):
+        identity = x
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = self.attn(out, prompt_in)
+        identity = self.id(identity)
+
+        out = out + identity
+        if self.norm_type != 'GROUP':
+            out = self.act(out)
+
+        if self.dropout:
+            out = self.dropout(out)
+
+        return out
+
+
+class PromptResBottleneck(nn.Module):
+    """
+    Prompt-conditioned residual bottleneck blocks
+    """
+    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
+        super().__init__()
+        self.norm_type = norm
+        middle_channels = in_channels // 4
+        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
+        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
+        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
+
+        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
+        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
+        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)
+        self.attn = PromptAttentionModule(out_channels, 512, out_channels // 4)
+
+        need_map = in_channels != out_channels or stride != 1
+        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()
+
+    def forward(self, x, prompt_in):
+        identity = x
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = self.conv3(out)
+        out = self.attn(out, prompt_in)
+        identity = self.id(identity)
+
+        out = out + identity
+        if self.norm_type != 'GROUP':
+            out = self.act(out)
+
+        if self.dropout:
+            out = self.dropout(out)
+
+        return out
+
+
+class SobelEdge(nn.Module):
+    def __init__(self, input_dim, channels, kernel_size=3, stride=1):
+        super().__init__()
+        conv = getattr(nn, 'Conv%dd' % input_dim)
+        self.filter = conv(channels, channels, kernel_size, stride, padding=(kernel_size - 1) // 2,
+                           groups=channels, bias=False)
+        sobel = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
+        sobel_kernel = torch.tensor(sobel, dtype=torch.float32).unsqueeze(0).expand([channels, 1] + [kernel_size] * input_dim)
+        self.filter.weight = nn.Parameter(sobel_kernel, requires_grad=False)
+
+    def forward(self, x):
+        with torch.no_grad():
+            out = self.filter(x)
+        return out
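For reference, a minimal shape check of the prompt-conditioned blocks defined above (illustrative only; the random 512-dim prompt stands in for a BiomedCLIP text embedding):

    import torch
    from base.base_modules import PromptResBlock

    block = PromptResBlock(in_channels=64, out_channels=128, stride=2, norm='INSTANCE')
    x = torch.randn(2, 64, 64, 64)   # (B, C_in, H, W)
    prompt = torch.randn(2, 512)     # (B, C_prompt)
    out = block(x, prompt)
    print(out.shape)                 # torch.Size([2, 128, 32, 32])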
base/base_segmentation.py ADDED
@@ -0,0 +1,230 @@
+import os
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.cuda.amp import GradScaler
+from abc import ABC, abstractmethod
+from utils.iteration.iterator import MetricMeter
+from utils.ddp_utils import gather_object_across_processes
+
+
+class BaseSegmentationModel(ABC):
+    """
+    This class is an abstract base class (ABC) for segmentation models.
+    To create a subclass, you need to implement the following four methods:
+    -- <__init__>: initialize the class.
+    -- <set_input>: unpack data from dataset.
+    -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+    -- <evaluate_one_step>: performance evaluation.
+    """
+    def __init__(self, cfg, num_classes, amp=False):
+        # initialize training CUDA devices
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        # training configuration
+        self.cfg = cfg
+        self.num_classes = num_classes
+        self.is_mixed = amp
+        self.scaler = GradScaler()
+        self.start_epoch = -1
+
+        # initialize networks, criterion, optimizer and scheduler
+        self.network = None
+        self.criterion = None
+        self.optimizer = None
+        self.scheduler = None
+
+        # visualization
+        self.visual_names = []
+        self.loss_names = []
+
+    def train(self):
+        self.network.train()
+        return self
+
+    def eval(self):
+        self.network.eval()
+        return self
+
+    def training(self):
+        return self.network.training
+
+    def initialize_metric_meter(self, class_list):
+        self.class_list = class_list
+        self.metric_meter = MetricMeter(metrics=['dice', 'hd95', 'asd'], class_names=class_list, subject_names=['name'])
+        self.train_loss = MetricMeter(metrics=self.loss_names, class_names=['train'])
+        self.val_loss = MetricMeter(metrics=['loss'], class_names=['val'])
+
+    def update_loss_meter(self, print=False):
+        loss_dict = {}
+        for loss_name in self.loss_names:
+            try:
+                loss_value = float(getattr(self, loss_name))
+                loss_list = gather_object_across_processes(loss_value)
+                loss_value = np.mean(loss_list)
+            except:
+                continue
+            loss_dict['train_{}'.format(loss_name)] = loss_value
+        self.train_loss.update(loss_dict)
+        stats = self.train_loss.report(print_stats=print, mean_only=True)
+        return stats
+
+    @abstractmethod
+    def set_input(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def optimize_parameters(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def evaluate_one_step(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def load_networks(self, ckpt_path, resume_training=False):
+        checkpoint = torch.load(ckpt_path, map_location=self.device)
+        print('Load ckpt weight: {}'.format(ckpt_path))
+        self.network.load_state_dict(checkpoint['net'])
+        if resume_training:
+            print('Load training config for breakpoint continuation')
+            self.optimizer.load_state_dict(checkpoint['optimizer'])
+            self.scheduler.load_state_dict(checkpoint['scheduler'])
+            self.scaler.load_state_dict(checkpoint['scaler'])
+            self.start_epoch = checkpoint['epoch']
+
+    def save_networks(self, epoch_index, save_dir):
+        if dist.get_rank() == 0:
+            checkpoint = {
+                "net": self.network.state_dict(),
+                'optimizer': self.optimizer.state_dict(),
+                'scheduler': self.scheduler.state_dict(),
+                'scaler': self.scaler.state_dict(),
+                "epoch": epoch_index
+            }
+            torch.save(checkpoint,
+                       os.path.join(save_dir, 'Epoch_{}.pkl'.format(epoch_index + 1)))
+
+
+class MultiNetworkSegmentationModel(ABC):
+    """
+    This class is an abstract base class (ABC) for segmentation models.
+    To create a subclass, you need to implement the following four methods:
+    -- <__init__>: initialize the class.
+    -- <set_input>: unpack data from dataset.
+    -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+    -- <evaluate_one_step>: performance evaluation.
+    """
+    def __init__(self, cfg, num_classes, amp=False):
+        # initialize training CUDA devices
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        # training configuration
+        self.cfg = cfg
+        self.num_classes = num_classes
+        self.is_mixed = amp
+        self.scaler = GradScaler()
+        self.start_epoch = -1
+
+        # initialize networks, criterion, optimizer and scheduler
+        self.net_names = []
+
+        # visualization
+        self.visual_names = []
+        self.loss_names = []
+
+    def train(self):
+        for name in self.net_names:
+            net = getattr(self, name)
+            net.train()
+        return self
+
+    def eval(self):
+        for name in self.net_names:
+            net = getattr(self, name)
+            net.eval()
+        return self
+
+    def training(self):
+        return getattr(self, self.net_names[0]).training
+
+    def initialize_metric_meter(self, class_list):
+        self.class_list = class_list
+        self.metric_meter = MetricMeter(metrics=['dice', 'hd95', 'asd'], class_names=class_list, subject_names=['name'])
+        self.train_loss = MetricMeter(metrics=self.loss_names, class_names=['train'])
+        self.val_loss = MetricMeter(metrics=['loss'], class_names=['val'])
+
+    def update_loss_meter(self, print=False):
+        loss_dict = {}
+        for loss_name in self.loss_names:
+            try:
+                loss_value = float(getattr(self, loss_name))
+                loss_list = gather_object_across_processes(loss_value)
+                loss_value = np.mean(loss_list)
+            except:
+                continue
+            loss_dict['train_{}'.format(loss_name)] = loss_value
+        self.train_loss.update(loss_dict)
+        stats = self.train_loss.report(print_stats=print, mean_only=True)
+        return stats
+
+    @abstractmethod
+    def set_input(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def optimize_parameters(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def evaluate_one_step(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def load_networks(self, ckpt_path, resume_training=False, strict=True):
+        checkpoint = torch.load(ckpt_path, map_location=self.device)
+        print('Load ckpt weight: {}'.format(ckpt_path))
+        if resume_training:
+            print('Load training config for breakpoint continuation')
+            self.scaler.load_state_dict(checkpoint['scaler'])
+            self.start_epoch = checkpoint['epoch']
+        for name in self.net_names:
+            try:
+                getattr(self, name).load_state_dict(checkpoint[name], strict=strict)
+                if resume_training:
+                    getattr(self, '{}_optimizer'.format(name)).load_state_dict(checkpoint['{}_optimizer'.format(name)])
+                    getattr(self, '{}_scheduler'.format(name)).load_state_dict(checkpoint['{}_scheduler'.format(name)])
+            except:
+                print('Failed to load network: {}'.format(name))
+
+    def load_single_network(self, ckpt_path, net_name, resume_training=False, strict=True):
+        checkpoint = torch.load(ckpt_path, map_location=self.device)
+        print('Load ckpt weight: {}'.format(ckpt_path))
+        if resume_training:
+            print('Load training config for breakpoint continuation')
+            self.scaler.load_state_dict(checkpoint['scaler'])
+            self.start_epoch = checkpoint['epoch']
+        getattr(self, net_name).load_state_dict(checkpoint[net_name], strict=strict)
+        if resume_training:
+            getattr(self, '{}_optimizer'.format(net_name)).load_state_dict(checkpoint['{}_optimizer'.format(net_name)])
+            getattr(self, '{}_scheduler'.format(net_name)).load_state_dict(checkpoint['{}_scheduler'.format(net_name)])
+
+    def save_networks(self, epoch_index, save_dir):
+        if dist.get_rank() == 0:
+            checkpoint = {}
+            for name in self.net_names:
+                checkpoint[name] = getattr(self, name).state_dict()
+                checkpoint['{}_optimizer'.format(name)] = getattr(self, '{}_optimizer'.format(name)).state_dict()
+                checkpoint['{}_scheduler'.format(name)] = getattr(self, '{}_scheduler'.format(name)).state_dict()
+            checkpoint['scaler'] = self.scaler.state_dict()
+            checkpoint['epoch'] = epoch_index
+            torch.save(checkpoint, os.path.join(save_dir, 'Epoch_{}.pkl'.format(epoch_index)))
+
+    def save_best_networks(self, epoch_index, save_dir):
+        if dist.get_rank() == 0:
+            checkpoint = {}
+            for name in self.net_names:
+                checkpoint[name] = getattr(self, name).state_dict()
+                checkpoint['{}_optimizer'.format(name)] = getattr(self, '{}_optimizer'.format(name)).state_dict()
+                checkpoint['{}_scheduler'.format(name)] = getattr(self, '{}_scheduler'.format(name)).state_dict()
+            checkpoint['scaler'] = self.scaler.state_dict()
+            checkpoint['epoch'] = epoch_index
+            torch.save(checkpoint, os.path.join(save_dir, 'Epoch_best.pkl'))
base/base_wandb_model.py ADDED
@@ -0,0 +1,107 @@
+import wandb
+import torch
+import numpy as np
+from monai.visualize import blend_images
+
+
+class WandBModel:
+    """
+    Enable WandB features to the model using multiple inheritance
+    """
+    def __init__(self, *args, **kwargs):
+        # the following attributes should be initialized by class `BaseSegmentationModel`
+        self.visual_pairs = None
+        self.train_loss = None
+        self.val_loss = None
+        self.metric_meter = None
+        self.name = None
+        # the following attributes should be initialized by the child class
+        self.val_table = None
+
+    def volume2videos(self, time_dim=3, tag=''):
+        """
+        Convert 3D volumes to video in favor of WandB logging
+        Args:
+            time_dim: the spatial dimension to be converted as the time dimension, default is the axial axis (dim 3)
+            tag: extra information for logging
+        """
+        videos = []
+        for image_pair in self.visual_pairs:
+            try:
+                pair_name = getattr(self, image_pair['name'])
+                image = getattr(self, image_pair['image'])
+                mask = getattr(self, image_pair['mask'])
+                vis_type = image_pair['type']
+            except:
+                continue
+            for i in range(image.shape[0]):  # iterate over the batch dim
+                image2save = image[i, ...]
+                mask2save = mask[i, ...]
+                item_name = pair_name[i]
+                # detach the tensor, format [C, H, W, D]
+                image_numpy = image2save.detach()
+                mask_numpy = mask2save.detach()
+                if mask_numpy.shape[0] > 1:
+                    mask_numpy = torch.argmax(mask_numpy, dim=0, keepdim=True)
+                # (C, H, W, D), torch.Tensor on device
+                pair_blend = blend_images(image_numpy, mask_numpy, alpha=0.5) * 255
+                # permute the axes to (time, channel, height, width)
+                spatial_dim = list(range(1, len(pair_blend.shape[1:]) + 1))
+                spatial_dim.remove(time_dim)
+                pair_blend = pair_blend.permute([time_dim, 0] + spatial_dim).cpu().numpy().astype(np.uint8)
+                # record in the wandb.Video class
+                video = wandb.Video(pair_blend, fps=8, caption='{}_{}{}'.format(item_name, vis_type, tag))
+                videos.append(video)
+        return videos
+
+    def log_scaler(self, key, value, step=None):
+        """
+        Log manually defined scalar data
+        """
+        wandb.log({key: np.round(value, decimals=4)}, step=step)
+
+    def log_train_loss(self, step=None):
+        """
+        Log train loss
+        """
+        data_dict = self.train_loss.pop_data(True)
+        for key, value in data_dict.items():
+            wandb.log({'train/{}'.format(key): value}, step=step)
+
+    def log_val_loss(self, step=None):
+        """
+        Log val loss
+        """
+        data_dict = self.val_loss.pop_data(True)
+        for key, value in data_dict.items():
+            wandb.log({'val/{}'.format(key): value}, step=step)
+
+    def log_metrics(self, step=None):
+        """
+        Log validation metrics as wandb.Table
+        """
+        df = self.metric_meter.to_df()
+        wandb.log({'val/metrics': wandb.Table(dataframe=df)}, step=step)
+
+    def log_vis(self, key, step=None, time_dim=3, tag=''):
+        """
+        Log training intermediate visualizations
+        """
+        videos = self.volume2videos(time_dim, tag)
+        wandb.log({key: videos}, step=step)
+
+    def update_val_visualization(self, time_dim=3, tag=''):
+        """
+        Update the validation visualization to buffer, called every step of evaluation
+        """
+        videos = self.volume2videos(time_dim, tag)
+        self.val_table.add_data(self.name, *videos)
+
+    def log_val_visualization(self, step=None):
+        """
+        Log validation visualization
+        """
+        wandb.log({'val/visualization': self.val_table}, step=step)
+        # re-initialize the table for next logging
+        del self.val_table
+        self.val_table = wandb.Table(columns=['ID'] + [pair['type'] for pair in self.visual_pairs])
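For reference, the mixin above is intended to be combined with one of the segmentation base classes via multiple inheritance, as its docstring suggests; a minimal sketch (illustrative only, class name hypothetical, method bodies omitted):

    from base.base_segmentation import MultiNetworkSegmentationModel
    from base.base_wandb_model import WandBModel

    class TranslationModel(MultiNetworkSegmentationModel, WandBModel):
        def __init__(self, cfg, num_classes, amp=False):
            MultiNetworkSegmentationModel.__init__(self, cfg, num_classes, amp)
            WandBModel.__init__(self)

        def set_input(self, batch): ...
        def optimize_parameters(self): ...
        def evaluate_one_step(self, batch): ...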
checkpoint/stage_i.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aaad39f42a0f916ba44b39071dcfbf1145ee43f6f5a269e3f4364b81d361d794
+size 494807162
checkpoint/stage_ii.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:155209bbe587905366b100cf2e8fadc9e8b9c672a0920eb848fcb80a3fcd5e8c
+size 425297586
ckpt/BiomedCLIP/biomed-vlp-eval.svg ADDED
ckpt/BiomedCLIP/biomed_clip_example.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
ckpt/BiomedCLIP/config.json ADDED
@@ -0,0 +1,17 @@
+{
+  "architectures": [
+    "BertForMaskedLM"
+  ],
+  "model_type": "bert",
+  "attention_probs_dropout_prob": 0.1,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "max_position_embeddings": 512,
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "type_vocab_size": 2,
+  "vocab_size": 30522
+}
ckpt/BiomedCLIP/open_clip_config.json ADDED
@@ -0,0 +1,31 @@
+{
+  "model_cfg": {
+    "embed_dim": 512,
+    "vision_cfg": {
+      "timm_model_name": "vit_base_patch16_224",
+      "timm_model_pretrained": false,
+      "timm_pool": "",
+      "timm_proj": "linear",
+      "image_size": 224
+    },
+    "text_cfg": {
+      "hf_model_name": "./ckpt/BiomedCLIP/",
+      "hf_tokenizer_name": "./ckpt/BiomedCLIP/",
+      "hf_proj_type": "mlp",
+      "hf_pooler_type": "cls_last_hidden_state_pooler",
+      "context_length": 77
+    }
+  },
+  "preprocess_cfg": {
+    "mean": [
+      0.48145466,
+      0.4578275,
+      0.40821073
+    ],
+    "std": [
+      0.26862954,
+      0.26130258,
+      0.27577711
+    ]
+  }
+}
ckpt/BiomedCLIP/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+{
+  "cls_token": "[CLS]",
+  "mask_token": "[MASK]",
+  "pad_token": "[PAD]",
+  "sep_token": "[SEP]",
+  "unk_token": "[UNK]"
+}
ckpt/BiomedCLIP/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
ckpt/BiomedCLIP/tokenizer_config.json ADDED
@@ -0,0 +1,15 @@
+{
+  "clean_up_tokenization_spaces": true,
+  "cls_token": "[CLS]",
+  "do_basic_tokenize": true,
+  "do_lower_case": true,
+  "mask_token": "[MASK]",
+  "model_max_length": 1000000000000000019884624838656,
+  "never_split": null,
+  "pad_token": "[PAD]",
+  "sep_token": "[SEP]",
+  "strip_accents": null,
+  "tokenize_chinese_chars": true,
+  "tokenizer_class": "BertTokenizer",
+  "unk_token": "[UNK]"
+}
ckpt/BiomedCLIP/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
configs/confocal.cfg ADDED
@@ -0,0 +1,36 @@
+MODEL:
+  IMG_SIZE: 1024
+  CROP_SIZE: 320
+  CONFOCAL_IN: 3 # 3-channel microscope file
+  CONFOCAL_OUT: 1 # nuclei
+  IMC_IN: 4 # 3-channel microscope file + 1 channel confocal
+  IMC_OUT: 1 # nuclei
+  GRAD_CKPT: True
+  TIMM_MODEL: none
+  NORM: INSTANCE
+  CONFOCAL_PATH: none
+  IMC_PATH: none
+TRAIN:
+  LR_G: 0.0002
+  LR_D: 0.0002
+  DECAY: 0.0
+  BETA1: 0.5
+  EARLY_STAGE: 0
+  BURN_IN: 0
+  BURN: 500
+  RAMPUP: 1000
+  EPOCHS: 1000
+  BATCHSIZE: 16
+  CROP_SAMPLE_NUM: 16
+  RATIO: 0.2
+  SEED: 42
+  PERTURB_PROB: 0.1
+  IMC_RATIO: 100.0
+  CON_RATIO: 100.0
+  SIM_RATIO: 50.0
+  EDGE_RATIO: 100.0
+  ADV_RATIO: 1.0
+  CLR_RATIO: 0.0
+  FREQ_RATIO: 0.00001
+TEST:
+  BATCHSIZE: 32
configs/confocal_marker.cfg ADDED
@@ -0,0 +1,26 @@
+MODEL:
+  IMG_SIZE: 1024
+  CROP_SIZE: 320
+  IMC_IN: 1
+  IMC_OUT: 1
+  GRAD_CKPT: True
+  PRETRAIN: none
+TRAIN:
+  LR_G: 0.002
+  LR_D: 0.002
+  DECAY: 0.0
+  BETA1: 0.5
+  EARLY_STAGE: 0
+  BURN_IN: 0
+  BURN: 100
+  RAMPUP: 100
+  EPOCHS: 100
+  BATCHSIZE: 8
+  CROP_SAMPLE_NUM: 8
+  RATIO: 0.2
+  SEED: 42
+  IMC_RATIO: 100.0
+  EDGE_RATIO: 10.0
+  ADV_RATIO: 1.0
+TEST:
+  BATCHSIZE: 16
configs/convertion.cfg ADDED
@@ -0,0 +1,25 @@
+MODEL:
+  IMG_SIZE: 1024
+  CROP_SIZE: 320
+  IMC_IN: 1
+  IMC_OUT: 1
+  GRAD_CKPT: True
+TRAIN:
+  LR_G: 0.002
+  LR_D: 0.002
+  DECAY: 0.0
+  BETA1: 0.5
+  EARLY_STAGE: 0
+  BURN_IN: 0
+  BURN: 100
+  RAMPUP: 100
+  EPOCHS: 100
+  BATCHSIZE: 16
+  CROP_SAMPLE_NUM: 8
+  RATIO: 0.2
+  SEED: 42
+  IMC_RATIO: 100.0
+  EDGE_RATIO: 10.0
+  ADV_RATIO: 1.0
+TEST:
+  BATCHSIZE: 64
configs/extend_1.cfg ADDED
@@ -0,0 +1,36 @@
+MODEL:
+  IMG_SIZE: 1024
+  CROP_SIZE: 320
+  CONFOCAL_IN: 3 # 3-channel microscope file
+  CONFOCAL_OUT: 1 # nuclei
+  IMC_IN: 4 # 3-channel microscope file + 1 channel confocal
+  IMC_OUT: 1 # nuclei
+  GRAD_CKPT: True
+  TIMM_MODEL: none
+  NORM: INSTANCE
+  CONFOCAL_PATH: none
+  IMC_PATH: none
+TRAIN:
+  LR_G: 0.0002
+  LR_D: 0.0002
+  DECAY: 0.0
+  BETA1: 0.5
+  EARLY_STAGE: 0
+  BURN_IN: 0
+  BURN: 500
+  RAMPUP: 1000
+  EPOCHS: 1000
+  BATCHSIZE: 16
+  CROP_SAMPLE_NUM: 16
+  RATIO: 0.2
+  SEED: 42
+  PERTURB_PROB: 0.1
+  IMC_RATIO: 100.0
+  CON_RATIO: 100.0
+  SIM_RATIO: 50.0
+  EDGE_RATIO: 100.0
+  ADV_RATIO: 1.0
+  CLR_RATIO: 0.0
+  FREQ_RATIO: 0.00001
+TEST:
+  BATCHSIZE: 32
configs/extend_2.cfg ADDED
@@ -0,0 +1,26 @@
+MODEL:
+  IMG_SIZE: 1024
+  CROP_SIZE: 320
+  IMC_IN: 1
+  IMC_OUT: 1
+  GRAD_CKPT: True
+  PRETRAIN: /mnt/shared_storage/zhaoxiangyu/experiments/IMC_translation_v2/checkpoints/convertion/convertion_0918-task_convertion-ratio_0.2/Epoch_39.pkl
+TRAIN:
+  LR_G: 0.002
+  LR_D: 0.002
+  DECAY: 0.0
+  BETA1: 0.5
+  EARLY_STAGE: 0
+  BURN_IN: 0
+  BURN: 100
+  RAMPUP: 100
+  EPOCHS: 100
+  BATCHSIZE: 8
+  CROP_SAMPLE_NUM: 8
+  RATIO: 0.2
+  SEED: 42
+  IMC_RATIO: 100.0
+  EDGE_RATIO: 10.0
+  ADV_RATIO: 1.0
+TEST:
+  BATCHSIZE: 16
configs/full.cfg ADDED
@@ -0,0 +1,39 @@
+MODEL:
+  IMG_SIZE: 1024
+  CROP_SIZE: 320
+  CONFOCAL_IN: 3 # 3-channel microscope file
+  CONFOCAL_OUT: 1 # nuclei
+  IMC_IN: 4 # 3-channel microscope file + 1 channel confocal
+  IMC_OUT: 1 # nuclei
+  CONVERTION_IN: 1
+  CONVERTION_OUT: 1
+  GRAD_CKPT: True
+  TIMM_MODEL: none
+  NORM: INSTANCE
+  CONFOCAL_PATH: none
+  IMC_PATH: none
+  CONVERTION_PATH: /mnt/shared_storage/zhaoxiangyu/experiments/IMC_translation_v2/checkpoints/convertion/convertion_0918-task_convertion-ratio_0.2/Epoch_39.pkl
+TRAIN:
+  LR_G: 0.0002
+  LR_D: 0.0002
+  DECAY: 0.0
+  BETA1: 0.5
+  EARLY_STAGE: 0
+  BURN_IN: 0
+  BURN: 500
+  RAMPUP: 1000
+  EPOCHS: 1000
+  BATCHSIZE: 16
+  CROP_SAMPLE_NUM: 16
+  RATIO: 0.2
+  SEED: 42
+  PERTURB_PROB: 0.1
+  IMC_RATIO: 100.0
+  CON_RATIO: 100.0
+  SIM_RATIO: 50.0
+  EDGE_RATIO: 100.0
+  ADV_RATIO: 1.0
+  CLR_RATIO: 0.0
+  FREQ_RATIO: 0.00001
+TEST:
+  BATCHSIZE: 32
configs/imc.cfg ADDED
@@ -0,0 +1,36 @@
+MODEL:
+  IMG_SIZE: 1024
+  CROP_SIZE: 320
+  CONFOCAL_IN: 3 # 3-channel microscope file
+  CONFOCAL_OUT: 1 # nuclei
+  IMC_IN: 4 # 3-channel microscope file + 1 channel confocal
+  IMC_OUT: 1 # nuclei
+  GRAD_CKPT: True
+  TIMM_MODEL: none
+  NORM: INSTANCE
+  CONFOCAL_PATH: none
+  IMC_PATH: none
+TRAIN:
+  LR_G: 0.0002
+  LR_D: 0.0002
+  DECAY: 0.0
+  BETA1: 0.5
+  EARLY_STAGE: 0
+  BURN_IN: 0
+  BURN: 500
+  RAMPUP: 1000
+  EPOCHS: 1000
+  BATCHSIZE: 16
+  CROP_SAMPLE_NUM: 16
+  RATIO: 0.2
+  SEED: 42
+  PERTURB_PROB: 0.1
+  IMC_RATIO: 100.0
+  CON_RATIO: 100.0
+  SIM_RATIO: 50.0
+  EDGE_RATIO: 100.0
+  ADV_RATIO: 1.0
+  CLR_RATIO: 0.0
+  FREQ_RATIO: 0.00001
+TEST:
+  BATCHSIZE: 32
configs/translation.cfg ADDED
@@ -0,0 +1,36 @@
+MODEL:
+  IMG_SIZE: 1024
+  CROP_SIZE: 320
+  CONFOCAL_IN: 3 # 3-channel microscope file
+  CONFOCAL_OUT: 1 # nuclei
+  IMC_IN: 4 # 3-channel microscope file + 1 channel confocal
+  IMC_OUT: 1 # nuclei
+  GRAD_CKPT: True
+  TIMM_MODEL: none
+  NORM: INSTANCE
+  CONFOCAL_PATH: none
+  IMC_PATH: none
+TRAIN:
+  LR_G: 0.0002
+  LR_D: 0.0002
+  DECAY: 0.0
+  BETA1: 0.5
+  EARLY_STAGE: 0
+  BURN_IN: 0
+  BURN: 500
+  RAMPUP: 1000
+  EPOCHS: 1000
+  BATCHSIZE: 16
+  CROP_SAMPLE_NUM: 16
+  RATIO: 0.2
+  SEED: 42
+  PERTURB_PROB: 0.1
+  IMC_RATIO: 100.0
+  CON_RATIO: 100.0
+  SIM_RATIO: 50.0
+  EDGE_RATIO: 100.0
+  ADV_RATIO: 1.0
+  CLR_RATIO: 0.0
+  FREQ_RATIO: 0.00001
+TEST:
+  BATCHSIZE: 32
markers.py ADDED
@@ -0,0 +1,136 @@
+breast_markers = [
+    'HER2', 'TAPAN8', 'CD15', 'CD206', 'CD11b', 'HLA_DR', 'H3', 'CD8a', 'ISG15', 'CD14',
+    'ZC3HV1', 'Collagen1', 'CD4', 'CD66b', 'ALDH1', 'FOXP3', 'SMA', 'CD24', 'CD44', 'CD54',
+    'PPARG', 'CD31', 'PD1', 'CD19', 'CD69', 'PKCD', 'Ki67', 'ER', 'CD11c', 'CD27',
+    'LPS', 'CD11a', 'PR', 'CD3', 'CD68', 'CD83', 'LTA', 'IFI6', 'CD45', 'CDH1', 'CD62L'
+]
+
+pancreatic_markers = [
+    'PGAM1', 'CD44', 'Amy2A', 'PGK1', 'PGAM5', 'CD99', 'CoL1', 'TALDO', 'ALDOB', 'ALDO',
+    'HK2', 'HK3', 'TPI', 'PKM', 'LDH', 'CK7', 'PDPN', 'HK1', 'NSE', 'AMF',
+    'PFKM', 'CD45', 'PGAM4', 'GAPDH', 'CD31', 'ECAD', 'PGAM2', 'aSMA', 'LDHB'
+]
+
+prostatic_markers = [
+    'CXCR4', 'EGFR', 'LAG-3', 'CD278', 'PSMA', 'CD15', 'CD134', 'CTLA4', 'Nestin', 'CD16',
+    'CD56', 'PD-1', 'CD11b', 'CD66a', 'CXCL12', 'CCR7', 'IDO', 'CD73', 'CD33', 'VEGF',
+    'CD8a', 'aSMA', 'CD14', 'AMACR', 'CD20', 'Ki-67', 'CD4', 'SOX-9', 'B7-H4', 'CD11C',
+    'IFNgamma', 'CD25', 'Pan-Keratin', 'Pan-Actin', 'CD45AR', 'CD74', 'CD276', 'HLA-DR', 'CD31', 'CD45RO',
+    'TGFbeta', 'CD366', 'CD19', 'PSA', 'Foxp3', 'EpCAM', 'GranzymeB', 'BCL-2', 'ARG1', 'CD27',
+    'hFAP', 'PDL-2', 'Keratin8', 'PDL-1', 'CD127', 'CD304', 'CD3', 'CD68', 'AR', 'CD45',
+    'Vista', 'CD62L', 'CD163', 'pan-actin'
+]
models/modules/biomedclip.py ADDED
@@ -0,0 +1,114 @@
+import json
+import torch.nn as nn
+from open_clip.factory import *
+
+
+# def create_model_and_transforms(
+#         model_name: str,
+#         config: str,
+#         device: Union[str, torch.device] = 'cpu',
+#         cache_dir: Optional[str] = None,
+#         force_preprocess_cfg: Optional[Dict[str, Any]] = None,
+# ):
+#     force_preprocess_cfg = force_preprocess_cfg or {}
+#     preprocess_cfg = asdict(PreprocessCfg())
+#     with open(config, 'r') as f:
+#         config = json.load(f)
+
+#     checkpoint_path = os.path.join(cache_dir, 'open_clip_pytorch_model.bin')
+#     preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
+#     model_cfg = config['model_cfg']
+
+#     if isinstance(device, str):
+#         device = torch.device(device)
+#     print(f'Loaded {model_name} model config.')
+
+#     # load pretrained weights for HF text model IFF no CLIP weights being loaded
+#     model_cfg['text_cfg']['hf_model_pretrained'] = False
+
+#     model = CustomTextCLIP(**model_cfg)
+#     model.to(device=device)
+
+#     print(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
+#     load_checkpoint(model, checkpoint_path)
+
+#     # set image preprocessing configuration in model attributes for convenience
+#     if getattr(model.visual, 'image_size', None) is not None:
+#         # use image_size set on model creation (via config or force_image_size arg)
+#         force_preprocess_cfg['size'] = model.visual.image_size
+#     set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))
+
+#     pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
+
+#     preprocess_train = image_transform_v2(
+#         pp_cfg,
+#         is_train=True,
+#         aug_cfg=None,
+#     )
+#     preprocess_val = image_transform_v2(
+#         pp_cfg,
+#         is_train=False,
+#     )
+
+#     return model, preprocess_train, preprocess_val
+
+
+def get_my_tokenizer(
+        config: str,
+        context_length: Optional[int] = None,
+        **kwargs,
+):
+    with open(config, 'r') as f:
+        config = json.load(f)
+
+    text_config = config['model_cfg']['text_cfg']
+    if 'tokenizer_kwargs' in text_config:
+        tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
+    else:
+        tokenizer_kwargs = kwargs
+
+    if context_length is None:
+        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)
+
+    if 'hf_tokenizer_name' in text_config:
+        tokenizer = HFTokenizer(
+            text_config['hf_tokenizer_name'],
+            context_length=context_length,
+            **tokenizer_kwargs,
+        )
+    else:
+        tokenizer = SimpleTokenizer(
+            context_length=context_length,
+            **tokenizer_kwargs,
+        )
+
+    return tokenizer
+
+
+class BiomedCLIPTextEncoder(nn.Module):
+    def __init__(self, device: torch.device) -> None:
+        super().__init__()
+        # self.model, _, _ = create_model_and_transforms(
+        #     model_name='hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224',
+        #     # config='./ckpt/BiomedCLIP/open_clip_config.json',
+        #     cache_dir='./ckpt/BiomedCLIP/'
+        # )
+        self.model, _, _ = create_model_and_transforms('hf-hub:hsiangyualex/biomedclip4imc')
+        self.model.eval()
+        self.model.to(device)
+        for param in self.model.parameters():
+            param.requires_grad = False
+        # self.tokenizer = get_my_tokenizer(config='./ckpt/BiomedCLIP/open_clip_config.json')
+        self.tokenizer = get_tokenizer('hf-hub:hsiangyualex/biomedclip4imc')
+        self.device = device
+
+    @torch.no_grad()
+    def forward(self, prompts):
+        """
+        Args:
+            prompts: a series of protein names
+        """
+        prompts = [f"An imaging mass cytometry staining image of {prompt} protein." for prompt in prompts]
+        prompts = self.tokenizer(prompts).to(self.device)
+        text_features = self.model.encode_text(prompts).detach()
+        return text_features
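For reference, a minimal usage sketch of the text encoder above (illustrative only; it assumes the hf-hub checkpoint 'hsiangyualex/biomedclip4imc' is reachable):

    import torch
    from models.modules.biomedclip import BiomedCLIPTextEncoder

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder = BiomedCLIPTextEncoder(device=device)
    features = encoder(['HER2', 'CD31'])  # marker names are wrapped into IMC staining prompts
    print(features.shape)                 # torch.Size([2, 512]), matching embed_dim in open_clip_config.json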
models/modules/dct.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ try:
6
+ # PyTorch 1.7.0 and newer versions
7
+ import torch.fft
8
+
9
+ def dct1_rfft_impl(x):
10
+ return torch.view_as_real(torch.fft.rfft(x, dim=1))
11
+
12
+ def dct_fft_impl(v):
13
+ return torch.view_as_real(torch.fft.fft(v, dim=1))
14
+
15
+ def idct_irfft_impl(V):
16
+ return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
17
+ except ImportError:
18
+ # PyTorch 1.6.0 and older versions
19
+ def dct1_rfft_impl(x):
20
+ return torch.rfft(x, 1)
21
+
22
+ def dct_fft_impl(v):
23
+ return torch.rfft(v, 1, onesided=False)
24
+
25
+ def idct_irfft_impl(V):
26
+ return torch.irfft(V, 1, onesided=False)
27
+
28
+
29
+
30
+ def dct1(x):
31
+ """
32
+ Discrete Cosine Transform, Type I
33
+
34
+ :param x: the input signal
35
+ :return: the DCT-I of the signal over the last dimension
36
+ """
37
+ x_shape = x.shape
38
+ x = x.view(-1, x_shape[-1])
39
+ x = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)
40
+
41
+ return dct1_rfft_impl(x)[:, :, 0].view(*x_shape)
42
+
43
+
44
+ def idct1(X):
45
+ """
46
+ The inverse of DCT-I, which is just a scaled DCT-I
47
+
48
+ Our definition if idct1 is such that idct1(dct1(x)) == x
49
+
50
+ :param X: the input signal
51
+ :return: the inverse DCT-I of the signal over the last dimension
52
+ """
53
+ n = X.shape[-1]
54
+ return dct1(X) / (2 * (n - 1))
55
+
56
+
57
+ def dct(x, norm=None):
58
+ """
59
+ Discrete Cosine Transform, Type II (a.k.a. the DCT)
60
+
61
+ For the meaning of the parameter `norm`, see:
62
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
63
+
64
+ :param x: the input signal
65
+ :param norm: the normalization, None or 'ortho'
66
+ :return: the DCT-II of the signal over the last dimension
67
+ """
68
+ x_shape = x.shape
69
+ N = x_shape[-1]
70
+ x = x.contiguous().view(-1, N)
71
+
72
+ v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
73
+
74
+ Vc = dct_fft_impl(v)
75
+
76
+ k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
77
+ W_r = torch.cos(k)
78
+ W_i = torch.sin(k)
79
+
80
+ V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
81
+
82
+ if norm == 'ortho':
83
+ V[:, 0] /= np.sqrt(N) * 2
84
+ V[:, 1:] /= np.sqrt(N / 2) * 2
85
+
86
+ V = 2 * V.view(*x_shape)
87
+
88
+ return V
89
+
90
+
91
+ def idct(X, norm=None):
92
+ """
93
+ The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
94
+
95
+ Our definition of idct is that idct(dct(x)) == x
96
+
97
+ For the meaning of the parameter `norm`, see:
98
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
99
+
100
+ :param X: the input signal
101
+ :param norm: the normalization, None or 'ortho'
102
+ :return: the inverse DCT-II of the signal over the last dimension
103
+ """
104
+
105
+ x_shape = X.shape
106
+ N = x_shape[-1]
107
+
108
+ X_v = X.contiguous().view(-1, x_shape[-1]) / 2
109
+
110
+ if norm == 'ortho':
111
+ X_v[:, 0] *= np.sqrt(N) * 2
112
+ X_v[:, 1:] *= np.sqrt(N / 2) * 2
113
+
114
+ k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
115
+ W_r = torch.cos(k)
116
+ W_i = torch.sin(k)
117
+
118
+ V_t_r = X_v
119
+ V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
120
+
121
+ V_r = V_t_r * W_r - V_t_i * W_i
122
+ V_i = V_t_r * W_i + V_t_i * W_r
123
+
124
+ V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
125
+
126
+ v = idct_irfft_impl(V)
127
+ x = v.new_zeros(v.shape)
128
+ x[:, ::2] += v[:, :N - (N // 2)]
129
+ x[:, 1::2] += v.flip([1])[:, :N // 2]
130
+
131
+ return x.view(*x_shape)
132
+
133
+
134
+ def dct_2d(x, norm=None):
135
+ """
136
+ 2-dimentional Discrete Cosine Transform, Type II (a.k.a. the DCT)
137
+
138
+ For the meaning of the parameter `norm`, see:
139
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
140
+
141
+ :param x: the input signal
142
+ :param norm: the normalization, None or 'ortho'
143
+ :return: the DCT-II of the signal over the last 2 dimensions
144
+ """
145
+ X1 = dct(x, norm=norm)
146
+ X2 = dct(X1.transpose(-1, -2), norm=norm)
147
+ return X2.transpose(-1, -2)
148
+
149
+
150
+ def idct_2d(X, norm=None):
151
+ """
152
+ The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
153
+
154
+ Our definition of idct is that idct_2d(dct_2d(x)) == x
155
+
156
+ For the meaning of the parameter `norm`, see:
157
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
158
+
159
+ :param X: the input signal
160
+ :param norm: the normalization, None or 'ortho'
161
+ :return: the DCT-II of the signal over the last 2 dimensions
162
+ """
163
+ x1 = idct(X, norm=norm)
164
+ x2 = idct(x1.transpose(-1, -2), norm=norm)
165
+ return x2.transpose(-1, -2)
166
+
167
+
168
+ def dct_3d(x, norm=None):
169
+ """
170
+ 3-dimentional Discrete Cosine Transform, Type II (a.k.a. the DCT)
171
+
172
+ For the meaning of the parameter `norm`, see:
173
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
174
+
175
+ :param x: the input signal
176
+ :param norm: the normalization, None or 'ortho'
177
+ :return: the DCT-II of the signal over the last 3 dimensions
178
+ """
179
+ X1 = dct(x, norm=norm)
180
+ X2 = dct(X1.transpose(-1, -2), norm=norm)
181
+ X3 = dct(X2.transpose(-1, -3), norm=norm)
182
+ return X3.transpose(-1, -3).transpose(-1, -2)
183
+
184
+
185
+ def idct_3d(X, norm=None):
186
+ """
187
+ The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III
188
+
189
+ Our definition of idct is that idct_3d(dct_3d(x)) == x
190
+
191
+ For the meaning of the parameter `norm`, see:
192
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
193
+
194
+ :param X: the input signal
195
+ :param norm: the normalization, None or 'ortho'
196
+ :return: the DCT-II of the signal over the last 3 dimensions
197
+ """
198
+ x1 = idct(X, norm=norm)
199
+ x2 = idct(x1.transpose(-1, -2), norm=norm)
200
+ x3 = idct(x2.transpose(-1, -3), norm=norm)
201
+ return x3.transpose(-1, -3).transpose(-1, -2)
202
+
203
+
204
+ class LinearDCT(nn.Linear):
205
+ """Implement any DCT as a linear layer; in practice this executes around
206
+ 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
207
+ increase memory usage.
208
+ :param in_features: size of expected input
209
+ :param type: which dct function in this file to use"""
210
+ def __init__(self, in_features, type, norm=None, bias=False):
211
+ self.type = type
212
+ self.N = in_features
213
+ self.norm = norm
214
+ super(LinearDCT, self).__init__(in_features, in_features, bias=bias)
215
+
216
+ def reset_parameters(self):
217
+ # initialise using dct function
218
+ I = torch.eye(self.N)
219
+ if self.type == 'dct1':
220
+ self.weight.data = dct1(I).data.t()
221
+ elif self.type == 'idct1':
222
+ self.weight.data = idct1(I).data.t()
223
+ elif self.type == 'dct':
224
+ self.weight.data = dct(I, norm=self.norm).data.t()
225
+ elif self.type == 'idct':
226
+ self.weight.data = idct(I, norm=self.norm).data.t()
227
+ self.weight.requires_grad = False # don't learn this!
228
+
229
+
230
+ def apply_linear_2d(x, linear_layer):
231
+ """Can be used with a LinearDCT layer to do a 2D DCT.
232
+ :param x: the input signal
233
+ :param linear_layer: any PyTorch Linear layer
234
+ :return: result of linear layer applied to last 2 dimensions
235
+ """
236
+ X1 = linear_layer(x)
237
+ X2 = linear_layer(X1.transpose(-1, -2))
238
+ return X2.transpose(-1, -2)
239
+
240
+ def apply_linear_3d(x, linear_layer):
241
+ """Can be used with a LinearDCT layer to do a 3D DCT.
242
+ :param x: the input signal
243
+ :param linear_layer: any PyTorch Linear layer
244
+ :return: result of linear layer applied to last 3 dimensions
245
+ """
246
+ X1 = linear_layer(x)
247
+ X2 = linear_layer(X1.transpose(-1, -2))
248
+ X3 = linear_layer(X2.transpose(-1, -3))
249
+ return X3.transpose(-1, -3).transpose(-1, -2)
250
+
251
+
252
+ class DCTHelper(nn.Module):
253
+ """
254
+ Implement DCT operations and corresponding masking.
255
+ """
256
+ def __init__(self, side_length: int, norm: str = None, cutoff: float = 0.8, data_range: tuple = (-1.0, 1.0)):
257
+ """
258
+ Args:
259
+ side_length: the side length of the image
260
+ norm: the normalization, None or 'ortho'
261
+ cutoff: the cutoff frequency ratio for low-pass filtering
262
+ """
263
+ super().__init__()
264
+ self.dct = LinearDCT(side_length, 'dct')
265
+ self.idct = LinearDCT(side_length, 'idct')
266
+ mask = self.create_circular_mask(side_length, side_length, radius=side_length * cutoff, center=(0, 0))
267
+ self.register_buffer('mask', torch.from_numpy(mask).float()[None, None, ...])
268
+ self.data_range = data_range
269
+
270
+ @staticmethod
271
+ def create_circular_mask(h, w, center=None, radius=None):
272
+ if center is None: # use the middle of the image
273
+ center = (int(w/2), int(h/2))
274
+ if radius is None: # use the smallest distance between the center and image walls
275
+ radius = min(center[0], center[1], w-center[0], h-center[1])
276
+
277
+ Y, X = np.ogrid[:h, :w]
278
+ dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2)
279
+
280
+ mask = dist_from_center <= radius
281
+ return mask
282
+
283
+ def run_dct(self, x):
284
+ return apply_linear_2d(x, self.dct)
285
+
286
+ def run_idct(self, x):
287
+ return apply_linear_2d(x, self.idct)
288
+
289
+ def forward(self, x, mode: str = 'dct'):
290
+ if mode == 'dct':
291
+ return self.run_dct(x)
292
+ elif mode == 'idct':
293
+ return self.run_idct(x)
294
+ else:
295
+ raise ValueError(f"Invalid mode: {mode}")
296
+
297
+ if __name__ == '__main__':
298
+ x = torch.Tensor(1000,4096)
299
+ x.normal_(0,1)
300
+ linear_dct = LinearDCT(4096, 'dct')
301
+ error = torch.abs(dct(x) - linear_dct(x))
302
+ assert error.max() < 1e-3, (error, error.max())
303
+ linear_idct = LinearDCT(4096, 'idct')
304
+ error = torch.abs(idct(x) - linear_idct(x))
305
+ assert error.max() < 1e-3, (error, error.max())
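A minimal usage sketch of the helpers above, assuming the functions and classes defined in this file are in scope: DCTHelper can low-pass filter a batch of square images by masking the DCT coefficients. The 256 side length and batch shape below are illustrative assumptions, not values taken from this repository's configs.

import torch

helper = DCTHelper(side_length=256, cutoff=0.8)        # DCT/IDCT linear layers plus a corner-centred circular mask
images = torch.randn(2, 3, 256, 256)                   # (B, C, H, W); H == W == side_length is required
freq = helper(images, mode='dct')                      # 2D DCT-II over the last two dimensions
low_passed = helper(freq * helper.mask, mode='idct')   # zero out high frequencies, then transform back
assert low_passed.shape == images.shape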
models/modules/networks.py ADDED
@@ -0,0 +1,714 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from base.base_modules import *
5
+ from timm.models import create_model
6
+ from functools import partial
7
+
8
+
9
+ class Backbone(nn.Module):
10
+ """
11
+ Model backbone to extract features
12
+ """
13
+ def __init__(self,
14
+ input_channels: int = 3,
15
+ channels: tuple = (32, 64, 128, 256, 512),
16
+ strides: tuple = (2, 2, 2, 2),
17
+ use_dropout: bool = False,
18
+ norm: str = 'BATCH',
19
+ leaky: bool = True):
20
+ """
21
+ Args:
22
+ input_channels: the number of input channels
23
+ channels: length-5 tuple, define the number of channels in each stage
24
+ strides: tuple, define the stride in each stage
25
+ use_dropout: bool, whether to use dropout
26
+ norm: str, normalization type
27
+ leaky: bool, whether to use leaky relu
28
+ """
29
+ super().__init__()
30
+ self.nb_filter = channels
31
+ self.strides = strides + (5 - len(strides)) * (1,)
32
+ res_unit = ResBlock if channels[-1] <= 320 else ResBottleneck
33
+
34
+ self.conv0_0 = nn.Sequential(
35
+ nn.Conv2d(input_channels, channels[0], kernel_size=7, stride=self.strides[0], padding=3),
36
+ nn.GroupNorm(1, channels[0]) if norm == 'GROUP' else nn.BatchNorm2d(channels[0]) if norm == 'BATCH' else nn.InstanceNorm2d(channels[0]),
37
+ nn.LeakyReLU() if leaky else nn.ReLU(),
38
+ )
39
+ self.conv1_0 = res_unit(self.nb_filter[0], self.nb_filter[1], self.strides[1], use_dropout=use_dropout, norm=norm, leaky=leaky)
40
+ self.conv2_0 = res_unit(self.nb_filter[1], self.nb_filter[2], self.strides[2], use_dropout=use_dropout, norm=norm, leaky=leaky)
41
+ self.conv3_0 = res_unit(self.nb_filter[2], self.nb_filter[3], self.strides[3], use_dropout=use_dropout, norm=norm, leaky=leaky)
42
+ self.conv4_0 = res_unit(self.nb_filter[3], self.nb_filter[4], self.strides[4], use_dropout=use_dropout, norm=norm, leaky=leaky)
43
+
44
+ def forward(self, x):
45
+ x0_0 = self.conv0_0(x)
46
+ x1_0 = self.conv1_0(x0_0)
47
+ x2_0 = self.conv2_0(x1_0)
48
+ x3_0 = self.conv3_0(x2_0)
49
+ x4_0 = self.conv4_0(x3_0)
50
+ return x0_0, x1_0, x2_0, x3_0, x4_0
51
+
52
+
53
+ class TimmBackbone(nn.Module):
54
+ """
55
+ Timm backbone to extract features, utilizing pretrained weights
56
+ """
57
+ def __init__(self, model_name) -> None:
58
+ super().__init__()
59
+ self.backbone = create_model(model_name, pretrained=True, features_only=True)
60
+ self.determine_nb_filters()
61
+
62
+ def determine_nb_filters(self):
63
+ dummy = torch.randn(1, 3, 256, 256)
64
+ out = self.backbone(dummy)
65
+ nb_filters = []
66
+ for o in out:
67
+ nb_filters.append(o.size(1))
68
+ self.nb_filter = nb_filters
69
+
70
+ def forward(self, inputs):
71
+ return self.backbone(inputs)
72
+
73
+
74
+ class UNet(nn.Module):
75
+ def __init__(self,
76
+ model_name: str = None,
77
+ in_channels: int = 1,
78
+ out_channels: int = None,
79
+ channels: tuple = (64, 128, 256, 320, 512),
80
+ strides: tuple = (2, 2, 2, 2, 2),
81
+ use_dropout: bool = False,
82
+ norm: str = 'INSTANCE',
83
+ leaky: bool = True,
84
+ use_dilated_bottleneck: bool = False):
85
+ """
86
+ Args:
87
+ model_name: timm model name
88
+ input_channels: the number of input channels
89
+ in_channels: the number of output channels
90
+ channels: length-5 tuple, define the number of channels in each stage
91
+ strides: tuple, define the stride in each stage
92
+ use_dropout: bool, whether to use dropout
93
+ norm: str, normalization type
94
+ leaky: bool, whether to use leaky relu
95
+ """
96
+ super().__init__()
97
+ if model_name not in [None, 'none', 'None']:
98
+ # use a timm backbone, overriding the other backbone arguments
99
+ self.backbone = TimmBackbone(model_name)
100
+ else:
101
+ self.backbone = Backbone(input_channels=in_channels, channels=channels, strides=strides,
102
+ use_dropout=use_dropout, norm=norm, leaky=leaky)
103
+ nb_filter = self.backbone.nb_filter
104
+ res_unit = ResBlock if nb_filter[-1] <= 512 else ResBottleneck
105
+
106
+ # decoder
107
+ self.conv3_1 = res_unit(nb_filter[3] + nb_filter[4], nb_filter[3], use_dropout=use_dropout, norm=norm, leaky=leaky)
108
+ self.conv2_2 = res_unit(nb_filter[2] + nb_filter[3], nb_filter[2], use_dropout=use_dropout, norm=norm, leaky=leaky)
109
+ self.conv1_3 = res_unit(nb_filter[1] + nb_filter[2], nb_filter[1], use_dropout=use_dropout, norm=norm, leaky=leaky)
110
+ self.conv0_4 = res_unit(nb_filter[0] + nb_filter[1], nb_filter[0], use_dropout=use_dropout, norm=norm, leaky=leaky)
111
+
112
+ # dilated bottleneck: optional
113
+ if use_dilated_bottleneck:
114
+ self.dilation = nn.Sequential(
115
+ nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=1),
116
+ nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
117
+ nn.LeakyReLU() if leaky else nn.ReLU(),
118
+ nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=2, dilation=2),
119
+ nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
120
+ nn.LeakyReLU() if leaky else nn.ReLU(),
121
+ nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=5, dilation=5),
122
+ nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
123
+ nn.LeakyReLU() if leaky else nn.ReLU(),
124
+ nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=1),
125
+ nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
126
+ nn.LeakyReLU() if leaky else nn.ReLU(),
127
+ nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=2, dilation=2),
128
+ nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
129
+ nn.LeakyReLU() if leaky else nn.ReLU(),
130
+ nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=5, dilation=5),
131
+ nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
132
+ nn.LeakyReLU() if leaky else nn.ReLU(),
133
+ )
134
+ else:
135
+ self.dilation = nn.Identity()
136
+
137
+ if out_channels is not None:
138
+ self.convds0 = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1, bias=False)
139
+ else:
140
+ self.convds0 = None
141
+
142
+ def upsample(self, inputs, target):
143
+ return F.interpolate(inputs, size=target.shape[2:], mode='bilinear', align_corners=False)
144
+
145
+ def extract_features(self, x):
146
+ x0, x1, x2, x3, x4 = self.backbone(x)
147
+
148
+ x4 = self.dilation(x4)
149
+
150
+ x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1))
151
+ x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1))
152
+ x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1))
153
+ x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1))
154
+ return x4, x0_4
155
+
156
+ def forward(self, x):
157
+ size = x.shape[2:]
158
+ x0, x1, x2, x3, x4 = self.backbone(x)
159
+
160
+ x4 = self.dilation(x4)
161
+
162
+ x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1))
163
+ x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1))
164
+ x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1))
165
+ x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1))
166
+ if self.convds0 is not None:
167
+ x_out = self.convds0(x0_4)
168
+ out = F.interpolate(x_out, size=size, mode='bilinear', align_corners=False)
169
+ else:
170
+ out = x0_4
171
+ return out
172
+
173
+ def freeze(self):
174
+ # freeze the network
175
+ for p in self.parameters():
176
+ p.requires_grad = False
177
+
178
+ def unfreeze(self):
179
+ # unfreeze the network to allow parameter update
180
+ for p in self.parameters():
181
+ p.requires_grad = True
182
+
183
+
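A shape-level sketch of how the UNet above might be instantiated; the channel tuple and the timm model name are illustrative assumptions, not values taken from this repository's configs.

import torch

net = UNet(in_channels=3, out_channels=1, channels=(32, 64, 128, 256, 512), norm='INSTANCE')
y = net(torch.randn(1, 3, 256, 256))    # logits are interpolated back to the input resolution
assert y.shape == (1, 1, 256, 256)
# a pretrained 3-channel encoder could be used instead, e.g. UNet(model_name='resnet34', out_channels=1),
# in which case the timm backbone overrides the channels/strides arguments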
184
+ class PromptAttentionUNet(nn.Module):
185
+ def __init__(self,
186
+ model_name: str = None,
187
+ in_channels: int = 1,
188
+ out_channels: int = None,
189
+ channels: tuple = (64, 128, 256, 320, 512),
190
+ strides: tuple = (2, 2, 2, 2, 2),
191
+ use_dropout: bool = False,
192
+ norm: str = 'INSTANCE',
193
+ leaky: bool = True,
194
+ use_dilated_bottleneck: bool = False):
195
+ """
196
+ Args:
197
+ model_name: timm model name
198
+ input_channels: the number of input channels
199
+ in_channels: the number of output channels
200
+ channels: length-5 tuple, define the number of channels in each stage
201
+ strides: tuple, define the stride in each stage
202
+ use_dropout: bool, whether to use dropout
203
+ norm: str, normalization type
204
+ leaky: bool, whether to use leaky relu
205
+ """
206
+ super().__init__()
207
+ if model_name not in [None, 'none', 'None']:
208
+ # use a timm backbone, overriding the other backbone arguments
209
+ self.backbone = TimmBackbone(model_name)
210
+ else:
211
+ self.backbone = Backbone(input_channels=in_channels, channels=channels, strides=strides,
212
+ use_dropout=use_dropout, norm=norm, leaky=leaky)
213
+ nb_filter = self.backbone.nb_filter
214
+ res_unit = PromptResBlock if nb_filter[-1] <= 512 else PromptResBottleneck
215
+
216
+ # decoder
217
+ self.conv3_1 = res_unit(nb_filter[3] + nb_filter[4], nb_filter[3], use_dropout=use_dropout, norm=norm, leaky=leaky)
218
+ self.conv2_2 = res_unit(nb_filter[2] + nb_filter[3], nb_filter[2], use_dropout=use_dropout, norm=norm, leaky=leaky)
219
+ self.conv1_3 = res_unit(nb_filter[1] + nb_filter[2], nb_filter[1], use_dropout=use_dropout, norm=norm, leaky=leaky)
220
+ self.conv0_4 = res_unit(nb_filter[0] + nb_filter[1], nb_filter[0], use_dropout=use_dropout, norm=norm, leaky=leaky)
221
+
222
+ # dilated bottleneck: optional
223
+ if use_dilated_bottleneck:
224
+ self.dilation = nn.Sequential(
225
+ nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=1),
226
+ nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
227
+ nn.LeakyReLU() if leaky else nn.ReLU(),
228
+ nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=2, dilation=2),
229
+ nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
230
+ nn.LeakyReLU() if leaky else nn.ReLU(),
231
+ nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=5, dilation=5),
232
+ nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
233
+ nn.LeakyReLU() if leaky else nn.ReLU(),
234
+ nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=1),
235
+ nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
236
+ nn.LeakyReLU() if leaky else nn.ReLU(),
237
+ nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=2, dilation=2),
238
+ nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
239
+ nn.LeakyReLU() if leaky else nn.ReLU(),
240
+ nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=5, dilation=5),
241
+ nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
242
+ nn.LeakyReLU() if leaky else nn.ReLU(),
243
+ )
244
+ else:
245
+ self.dilation = nn.Identity()
246
+
247
+ if out_channels is not None:
248
+ self.convds0 = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1, bias=False)
249
+
250
+ def upsample(self, inputs, target):
251
+ return F.interpolate(inputs, size=target.shape[2:], mode='bilinear', align_corners=False)
252
+
253
+ def extract_features(self, x):
254
+ x0, x1, x2, x3, x4 = self.backbone(x)
255
+
256
+ x4 = self.dilation(x4)
257
+
258
+ x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1))
259
+ x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1))
260
+ x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1))
261
+ x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1))
262
+ return x4, x0_4
263
+
264
+ def forward(self, x, prompt_in):
265
+ size = x.shape[2:]
266
+ x0, x1, x2, x3, x4 = self.backbone(x)
267
+
268
+ x4 = self.dilation(x4)
269
+
270
+ x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1), prompt_in)
271
+ x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1), prompt_in)
272
+ x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1), prompt_in)
273
+ x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1), prompt_in)
274
+ x_out = self.convds0(x0_4)
275
+ out = F.interpolate(x_out, size=size, mode='bilinear', align_corners=False)
276
+ return out
277
+
278
+ def freeze(self):
279
+ # freeze the network
280
+ for p in self.parameters():
281
+ p.requires_grad = False
282
+
283
+ def unfreeze(self):
284
+ # unfreeze the network to allow parameter update
285
+ for p in self.parameters():
286
+ p.requires_grad = True
287
+
288
+
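A sketch of the prompt-conditioned forward pass of PromptAttentionUNet; the 512-dimensional prompt is an assumption (matching the prompt_channels=512 used by PromptAttentionModule later in this file) and would normally come from a text encoder.

import torch

net = PromptAttentionUNet(in_channels=3, out_channels=1, channels=(32, 64, 128, 256, 512))
prompt = torch.randn(1, 512)                    # assumed prompt-embedding shape, e.g. from a CLIP-style text encoder
y = net(torch.randn(1, 3, 256, 256), prompt)    # every decoder block receives the prompt
assert y.shape == (1, 1, 256, 256)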
289
+ class CLIPDrivenUNet(nn.Module):
290
+ def __init__(self, encoding: str, model_name: str = None, in_channels: int = 1, out_channels: int = 1, channels: tuple = (32, 64, 128, 256, 512),
291
+ strides: tuple = (2, 2, 2, 2, 2), norm: str = 'INSTANCE', leaky: bool = True) -> None:
292
+ super().__init__()
293
+ self.encoding = encoding
294
+ self.num_classes = out_channels
295
+ self.backbone = UNet(model_name=model_name, in_channels=in_channels, out_channels=None, channels=channels,
296
+ strides=strides, use_dropout=False, norm=norm, leaky=leaky)
297
+ self.gap = nn.AdaptiveAvgPool2d(1)
298
+ self.precls_conv = nn.Sequential(
299
+ nn.InstanceNorm2d(32),
300
+ nn.LeakyReLU(),
301
+ nn.Conv2d(32, 8, kernel_size=1)
302
+ )
303
+
304
+ self.weight_nums = [8*8, 8*8, 8*1]
305
+ self.bias_nums = [8, 8, 1]
306
+ self.controller = nn.Conv2d(256 + channels[-1], sum(self.weight_nums + self.bias_nums), kernel_size=1, stride=1, padding=0)
307
+ if encoding == 'CLIP':
308
+ self.register_buffer('protein_embedding', torch.randn(self.num_classes, 512))
309
+ self.text_to_vision = nn.Linear(512, 256)
310
+ elif encoding == 'RAND':
311
+ self.register_buffer('protein_embedding', torch.randn(self.num_classes, 256))
312
+
313
+ def parse_dynamic_params(self, params, channels, weight_nums, bias_nums):
314
+ assert params.dim() == 2
315
+ assert len(weight_nums) == len(bias_nums)
316
+ assert params.size(1) == sum(weight_nums) + sum(bias_nums)
317
+
318
+ num_insts = params.size(0)
319
+ num_layers = len(weight_nums)
320
+
321
+ params_splits = list(torch.split_with_sizes(
322
+ params, weight_nums + bias_nums, dim=1
323
+ ))
324
+
325
+ weight_splits = params_splits[:num_layers]
326
+ bias_splits = params_splits[num_layers:]
327
+
328
+ for l in range(num_layers):
329
+ if l < num_layers - 1:
330
+ weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1)
331
+ bias_splits[l] = bias_splits[l].reshape(num_insts * channels)
332
+ else:
333
+ weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1)
334
+ bias_splits[l] = bias_splits[l].reshape(num_insts * 1)
335
+ # print(weight_splits[l].shape, bias_splits[l].shape)
336
+
337
+ return weight_splits, bias_splits
338
+
339
+ def heads_forward(self, features, weights, biases, num_insts):
340
+ n_layers = len(weights)
341
+ x = features
342
+ for i, (w, b) in enumerate(zip(weights, biases)):
343
+ x = F.conv2d(
344
+ x, w, bias=b,
345
+ stride=1, padding=0,
346
+ groups=num_insts
347
+ )
348
+ if i < n_layers - 1:
349
+ x = F.leaky_relu(x)
350
+ return x
351
+
352
+ def forward(self, x_in):
353
+ out_shape = x_in.shape[2:]
354
+ dec4, out = self.backbone.extract_features(x_in) # dec4: (B, channels[-1], H, W), out: (B, channels[0], H, W)
355
+
356
+ if self.encoding == 'RAND':
357
+ task_encoding = self.protein_embedding[..., None, None] # (num_classes, 256, 1, 1)
358
+ elif self.encoding == 'CLIP':
359
+ task_encoding = F.leaky_relu(self.text_to_vision(self.protein_embedding))[..., None, None] # (num_classes, 256, 1, 1)
360
+ else:
361
+ raise NotImplementedError
362
+ x_feat = self.gap(dec4)
363
+ b = x_feat.shape[0]
364
+ logits_array = []
365
+ for i in range(b):
366
+ x_cond = torch.cat([x_feat[i].unsqueeze(0).repeat(self.num_classes, 1, 1, 1), task_encoding], 1)
367
+ params = self.controller(x_cond) # (num_classes, num_params, 1, 1)
368
+ params.squeeze_(-1).squeeze_(-1) # (num_classes, num_params)
369
+
370
+ head_inputs = self.precls_conv(out[i].unsqueeze(0))
371
+ head_inputs = head_inputs.repeat(self.num_classes, 1, 1, 1) # (num_classes, 8, H, W)
372
+ N, _, H, W = head_inputs.size()
373
+ head_inputs = head_inputs.reshape(1, -1, H, W)
374
+ # print(head_inputs.shape, params.shape)
375
+ weights, biases = self.parse_dynamic_params(params, 8, self.weight_nums, self.bias_nums)
376
+
377
+ logits = self.heads_forward(head_inputs, weights, biases, N)
378
+ logits_array.append(logits.reshape(1, -1, H, W))
379
+
380
+ out = torch.cat(logits_array, dim=0)
381
+ out = F.interpolate(out, size=out_shape, mode='bilinear', align_corners=False)
382
+ # print(out.shape)
383
+ return out
384
+
385
+
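As a quick check of the dynamic-head bookkeeping in CLIPDrivenUNet: the controller predicts, for each class, the weights and biases of a three-layer 1x1-conv head (8->8, 8->8, 8->1 channels), so its output channel count is sum(weight_nums) + sum(bias_nums).

weight_nums = [8 * 8, 8 * 8, 8 * 1]
bias_nums = [8, 8, 1]
assert sum(weight_nums) + sum(bias_nums) == 153   # parameters predicted per class by self.controller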
386
+ class NLayerDiscriminator(nn.Module):
387
+ """Defines a PatchGAN discriminator"""
388
+
389
+ def __init__(self, input_nc, norm='INSTANCE', ndf=64, n_layers=3):
390
+ """Construct a PatchGAN discriminator
391
+
392
+ Parameters:
393
+ input_nc (int) -- the number of channels in input images
394
+ ndf (int) -- the number of filters in the last conv layer
395
+ n_layers (int) -- the number of conv layers in the discriminator
396
+ norm_layer -- normalization layer
397
+ """
398
+ super(NLayerDiscriminator, self).__init__()
399
+ norm_layer = norm_dict[norm]
400
+ use_bias = norm_layer == nn.InstanceNorm2d
401
+
402
+ kw = 4
403
+ padw = 1
404
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
405
+ nf_mult = 1
406
+ nf_mult_prev = 1
407
+ for n in range(1, n_layers): # gradually increase the number of filters
408
+ nf_mult_prev = nf_mult
409
+ nf_mult = min(2 ** n, 8)
410
+ sequence += [
411
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
412
+ norm_layer(ndf * nf_mult),
413
+ nn.LeakyReLU(0.2, True)
414
+ ]
415
+
416
+ nf_mult_prev = nf_mult
417
+ nf_mult = min(2 ** n_layers, 8)
418
+ sequence += [
419
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
420
+ norm_layer(ndf * nf_mult),
421
+ nn.LeakyReLU(0.2, True)
422
+ ]
423
+
424
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
425
+ self.model = nn.Sequential(*sequence)
426
+
427
+ def forward(self, input):
428
+ """Standard forward."""
429
+ return self.model(input)
430
+
431
+
432
+ class PatchDiscriminator(nn.Module):
433
+ def __init__(self, in_channels, norm_type='INSTANCE'):
434
+ super().__init__()
435
+ nb_filters = [32, 64, 128, 256, 512]
436
+ strides = [2, 2, 2, 2, 2]
437
+
438
+ self.layer1 = ConvNorm(in_channels=in_channels, out_channels=nb_filters[0], kernel_size=5, stride=strides[0], norm='NONE', leaky=True)
439
+ self.layer2 = ConvNorm(in_channels=nb_filters[0], out_channels=nb_filters[1], kernel_size=5, stride=strides[1], norm=norm_type, leaky=True)
440
+ self.layer3 = ConvNorm(in_channels=nb_filters[1], out_channels=nb_filters[2], kernel_size=5, stride=strides[2], norm=norm_type, leaky=True)
441
+ self.layer4 = ConvNorm(in_channels=nb_filters[2], out_channels=nb_filters[3], kernel_size=5, stride=strides[3], norm=norm_type, leaky=True)
442
+ self.layer5 = ConvNorm(in_channels=nb_filters[3], out_channels=nb_filters[4], kernel_size=5, stride=strides[4], norm=norm_type, leaky=True)
443
+
444
+ self.dense_pred = ConvNorm(in_channels=nb_filters[4], out_channels=1, kernel_size=3, stride=1, norm='NONE', activation=False)
445
+
446
+ def forward(self, inputs):
447
+ x1 = self.layer1(inputs)
448
+ x2 = self.layer2(x1)
449
+ x3 = self.layer3(x2)
450
+ x4 = self.layer4(x3)
451
+ x5 = self.layer5(x4)
452
+ output = self.dense_pred(x5)
453
+ output_list = [x1, x2, x3, x4, x5, output]
454
+ return output_list
455
+
456
+
457
+ class PromptPatchDiscriminator(nn.Module):
458
+ def __init__(self, in_channels, norm_type='INSTANCE'):
459
+ super().__init__()
460
+ nb_filters = [32, 64, 128, 256, 512]
461
+ strides = [2, 2, 2, 2, 2]
462
+
463
+ self.layer1 = ConvNorm(in_channels=in_channels, out_channels=nb_filters[0], kernel_size=5, stride=strides[0], norm='NONE', leaky=True)
464
+ self.layer2 = ConvNorm(in_channels=nb_filters[0], out_channels=nb_filters[1], kernel_size=5, stride=strides[1], norm=norm_type, leaky=True)
465
+ self.layer3 = ConvNorm(in_channels=nb_filters[1], out_channels=nb_filters[2], kernel_size=5, stride=strides[2], norm=norm_type, leaky=True)
466
+ self.layer4 = ConvNorm(in_channels=nb_filters[2], out_channels=nb_filters[3], kernel_size=5, stride=strides[3], norm=norm_type, leaky=True)
467
+ self.layer5 = ConvNorm(in_channels=nb_filters[3], out_channels=nb_filters[4], kernel_size=5, stride=strides[4], norm=norm_type, leaky=True)
468
+
469
+ self.attn4 = PromptAttentionModule(in_channels=nb_filters[3], prompt_channels=512, mid_channels=nb_filters[3] // 4)
470
+ self.attn5 = PromptAttentionModule(in_channels=nb_filters[4], prompt_channels=512, mid_channels=nb_filters[4] // 4)
471
+
472
+ self.dense_pred = ConvNorm(in_channels=nb_filters[4], out_channels=1, kernel_size=3, stride=1, norm='NONE', activation=False)
473
+
474
+ def forward(self, inputs, prompt_in):
475
+ x1 = self.layer1(inputs)
476
+ x2 = self.layer2(x1)
477
+ x3 = self.layer3(x2)
478
+ x4 = self.layer4(x3)
479
+ x4 = self.attn4(x4, prompt_in)
480
+ x5 = self.layer5(x4)
481
+ x5 = self.attn5(x5, prompt_in)
482
+ output = self.dense_pred(x5)
483
+ output_list = [x1, x2, x3, x4, x5, output]
484
+ return output_list
485
+
486
+
487
+ class MultiScaleDiscriminator(nn.Module):
488
+ def __init__(self, in_channels, norm='INSTANCE', num_D=3):
489
+ super(MultiScaleDiscriminator, self).__init__()
490
+ self.num_D = num_D
491
+ module = PatchDiscriminator
492
+
493
+ for i in range(num_D):
494
+ netD = module(in_channels, norm)
495
+ setattr(self, 'layer' + str(i), netD)
496
+
497
+ self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
498
+
499
+ def singleD_forward(self, model, input):
500
+ return model(input)
501
+
502
+ def forward(self, input):
503
+ num_D = self.num_D
504
+ result = []
505
+ input_downsampled = input
506
+ for i in range(num_D):
507
+ model = getattr(self, 'layer' + str(num_D - 1 - i))
508
+ result.append(self.singleD_forward(model, input_downsampled))
509
+ if i != (num_D - 1):
510
+ input_downsampled = self.downsample(input_downsampled)
511
+ return result
512
+
513
+
514
+ class PromptMultiScaleDiscriminator(nn.Module):
515
+ def __init__(self, in_channels, norm='INSTANCE', num_D=3):
516
+ super(PromptMultiScaleDiscriminator, self).__init__()
517
+ self.num_D = num_D
518
+ module = PromptPatchDiscriminator
519
+
520
+ for i in range(num_D):
521
+ netD = module(in_channels, norm)
522
+ setattr(self, 'layer' + str(i), netD)
523
+
524
+ self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
525
+
526
+ def singleD_forward(self, model, input, prompt_in):
527
+ return model(input, prompt_in)
528
+
529
+ def forward(self, input, prompt_in):
530
+ num_D = self.num_D
531
+ result = []
532
+ input_downsampled = input
533
+ for i in range(num_D):
534
+ model = getattr(self, 'layer' + str(num_D - 1 - i))
535
+ result.append(self.singleD_forward(model, input_downsampled, prompt_in))
536
+ if i != (num_D - 1):
537
+ input_downsampled = self.downsample(input_downsampled)
538
+ return result
539
+
540
+
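A sketch of what the multi-scale discriminators return, assuming ConvNorm from base.base_modules behaves as used above: one entry per scale, and each entry is the list of intermediate feature maps plus the final patch prediction from PatchDiscriminator.

import torch

disc = MultiScaleDiscriminator(in_channels=3, num_D=3)
preds = disc(torch.randn(1, 3, 256, 256))
assert len(preds) == 3        # full, 1/2 and 1/4 resolution inputs
assert len(preds[0]) == 6     # [x1, x2, x3, x4, x5, patch prediction]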
541
+ class HighResEnhancer(nn.Module):
542
+ """
543
+ A global-local network for high-resolution generation, with an edge branch to enhance boundary information.
544
+ """
545
+ def __init__(self,
546
+ model_name: str = None,
547
+ in_channels: int = 1,
548
+ out_channels: int = None,
549
+ coarse_channels: tuple = (16, 32, 64, 128, 256),
550
+ channels: tuple = (32, 64, 128, 256, 512),
551
+ use_dropout: bool = False,
552
+ norm: str = 'INSTANCE',
553
+ leaky: bool = True,
554
+ use_dilated_bottleneck: bool = False):
555
+ super().__init__()
556
+ # define basic blocks
557
+ self.norm = norm
558
+ self.leaky = leaky
559
+ norm_layer = self.get_norm_layer()
560
+ act_layer = self.get_act_layer()
561
+ res_unit = ResBlock if channels[-1] <= 512 else ResBottleneck
562
+
563
+ # check input channels
564
+ assert channels[1] == coarse_channels[2], 'coarse_channels[2] must equal channels[1] so the coarse and fine branches can be fused.'
565
+
566
+ # downsample and edge information extraction:
567
+ # the downsample operation provides the input for coarse branch
568
+ self.downsample = nn.AvgPool2d(3, stride=2, padding=1)
569
+ # the sobel filter is operated on the downsampled image to provide edge information
570
+ self.sobel = SobelEdge(input_dim=2, channels=in_channels)
571
+ self.sobel_conv = nn.Sequential(
572
+ nn.Conv2d(in_channels, channels[0], kernel_size=3, stride=2, padding=1),
573
+ norm_layer(channels[0]),
574
+ act_layer()
575
+ )
576
+
577
+ # coarse generator: in_channels -> coarse_channels[2]
578
+ # input stride: 0
579
+ # output stride: 4 (as input is already 2x downsampled)
580
+ self.coarse = nn.Sequential(
581
+ nn.Conv2d(in_channels, coarse_channels[0], kernel_size=3, stride=2, padding=1),
582
+ norm_layer(coarse_channels[0]),
583
+ act_layer(),
584
+ res_unit(coarse_channels[0], coarse_channels[1], stride=2),
585
+ res_unit(coarse_channels[1], coarse_channels[2], stride=2),
586
+ res_unit(coarse_channels[2], coarse_channels[3], stride=2),
587
+ res_unit(coarse_channels[3], coarse_channels[4], stride=1),
588
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
589
+ res_unit(coarse_channels[4], coarse_channels[3], stride=1),
590
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
591
+ res_unit(coarse_channels[3], coarse_channels[2], stride=1),
592
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
593
+ res_unit(coarse_channels[2], coarse_channels[2], stride=1),
594
+ )
595
+
596
+ # fine generator: used to enhance the generation for better details
597
+ # 1. simple encoder: channels[0] -> channels[1]
598
+ # input stride: 0
599
+ # output stride: 4
600
+ self.fine_encoder = nn.Sequential(
601
+ nn.Conv2d(in_channels, channels[0], kernel_size=3, stride=2, padding=1),
602
+ norm_layer(channels[0]),
603
+ act_layer(),
604
+ nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=2, padding=1),
605
+ norm_layer(channels[1]),
606
+ act_layer()
607
+ )
608
+ # 2. bottleneck: channels[1] -> channels[4]
609
+ # input stride: 4
610
+ # output stride: 16
611
+ self.bottleneck = nn.Sequential(
612
+ res_unit(channels[1], channels[2], stride=2),
613
+ res_unit(channels[2], channels[3], stride=2),
614
+ res_unit(channels[3], channels[4], stride=1),
615
+ res_unit(channels[4], channels[4], stride=1),
616
+ )
617
+ if use_dilated_bottleneck:
618
+ self.bottleneck.add_module('dilated_block_1',
619
+ nn.Sequential(
620
+ nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=1, dilation=1),
621
+ norm_layer(channels[4]),
622
+ act_layer()
623
+ ))
624
+ self.bottleneck.add_module('dilated_block_2',
625
+ nn.Sequential(
626
+ nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=2, dilation=2),
627
+ norm_layer(channels[4]),
628
+ act_layer()
629
+ ))
630
+ self.bottleneck.add_module('dilated_block_3',
631
+ nn.Sequential(
632
+ nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=5, dilation=5),
633
+ norm_layer(channels[4]),
634
+ act_layer()
635
+ ))
636
+ self.bottleneck.add_module('dilated_block_4',
637
+ nn.Sequential(
638
+ nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=1, dilation=1),
639
+ norm_layer(channels[4]),
640
+ act_layer()
641
+ ))
642
+ self.bottleneck.add_module('dilated_block_5',
643
+ nn.Sequential(
644
+ nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=2, dilation=2),
645
+ norm_layer(channels[4]),
646
+ act_layer()
647
+ ))
648
+ self.bottleneck.add_module('dilated_block_6',
649
+ nn.Sequential(
650
+ nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=5, dilation=5),
651
+ norm_layer(channels[4]),
652
+ act_layer()
653
+ ))
654
+
655
+ # 3. simple decoder: channels[4] -> channels[0]
656
+ # input stride: 16
657
+ # output stride: 2
658
+ self.decoder = nn.Sequential(
659
+ res_unit(channels[4], channels[3], stride=1),
660
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
661
+ res_unit(channels[3], channels[2], stride=1),
662
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
663
+ res_unit(channels[2], channels[1], stride=1),
664
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
665
+ res_unit(channels[1], channels[0], stride=1),
666
+ )
667
+
668
+ # output operation that combines both feature branch and edge branch
669
+ # input stride: 2
670
+ # output stride: 0
671
+ self.output = nn.Sequential(
672
+ nn.Conv2d(2 * channels[0], channels[0], kernel_size=3, stride=1, padding=1),
673
+ norm_layer(channels[0]),
674
+ act_layer(),
675
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
676
+ nn.Conv2d(channels[0], out_channels, kernel_size=1, stride=1, bias=False)
677
+ )
678
+
679
+ def get_norm_layer(self):
680
+ if self.norm == 'INSTANCE':
681
+ return partial(nn.InstanceNorm2d, affine=False)
682
+ elif self.norm == 'BATCH':
683
+ return partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
684
+ elif self.norm == 'GROUP':
685
+ return partial(nn.GroupNorm, num_groups=8)
686
+ else:
687
+ raise NotImplementedError(f'Normalization layer {self.norm} is not implemented.')
688
+
689
+ def get_act_layer(self):
690
+ if self.leaky:
691
+ return partial(nn.LeakyReLU, inplace=False)
692
+ else:
693
+ return partial(nn.ReLU, inplace=False)
694
+
695
+ def forward(self, inputs):
696
+ """
697
+ Args:
698
+ inputs: (B, C, H, W), input IMC image
699
+ """
700
+ # downsample and edge information extraction
701
+ downsampled = self.downsample(inputs) # 0 -> 2x stride
702
+ edge = self.sobel(inputs)
703
+ edge = self.sobel_conv(edge)
704
+
705
+ # coarse generator
706
+ coarse = self.coarse(downsampled) # 2x stride -> 4x stride
707
+ # fine generator
708
+ fine = self.fine_encoder(inputs) # 0x stride -> 4x stride
709
+ # add coarse and fine information together
710
+ fine = self.bottleneck(fine + coarse) # 4x stride -> 16x stride
711
+ fine = self.decoder(fine) # 16x stride -> 2x stride
712
+ # output operation
713
+ output = self.output(torch.cat([edge, fine], dim=1))
714
+ return output
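Closing the file, a shape-level sketch of the two-branch generator above; out_channels must be passed explicitly, and the default channel tuples already satisfy the channels[1] == coarse_channels[2] constraint (the input size below is an illustrative assumption).

import torch

gen = HighResEnhancer(in_channels=3, out_channels=1, norm='INSTANCE')
pred = gen(torch.randn(1, 3, 256, 256))   # Sobel edge branch plus coarse/fine branches fused by the output head
assert pred.shape == (1, 1, 256, 256)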
test_data/1.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d97725edcd248d40aca15fc720f6caf46e55e5f2eab28fa7a28a0e8a1448dc80
3
+ size 1890089
test_data/10.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e43c1e80c83dc7898163b54338485fb092c3470326914cd697d700970ba247a
3
+ size 1935806
test_data/11.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:144aac1b1e0d4566133eeb62d65e26fe29d430f082e9fcb0b4fd1794df43a406
3
+ size 1920270
test_data/12.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bb067a4445326aa36775a728d4d0bdf8ea622f3dc1683b4d1d14e84b31b4e98
3
+ size 1286013
test_data/13.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba481f70e969558838d17e608013cd838d858700fa628b857766ea44060cb96c
3
+ size 1858792
test_data/14.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd5f4c6da4a8092095f9749869480281835a572b11086b82c1c1a6e230792071
3
+ size 1851990
test_data/15.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7518a1c2aa2432375794330262570d23631d9a2ebaa4ce924a9ad49df87218b1
3
+ size 1905786
test_data/16.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4d452f027d2d05088142810a0e3ab9d5692898685182b6bfd0a64ebc1d033ee
3
+ size 1894100
test_data/17.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:43d29ce1f026a7b7f19746dba61e69e6682514164d02dae2f43575eb8f779b77
3
+ size 1966934
test_data/18.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a8cec8db943c7f24a6cb63e0d935db729f883ff7ea2fafe72859bcbc9371711
3
+ size 1894208
test_data/19.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84795f8d8f41e27e06c4fa1fa0e1a46e753ff8359b84f6fcc334d50ce28bf144
3
+ size 1901645
test_data/2.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6327b7e06cfc10eb075d45715dfb2a1807a7899bafe6d52c5eb5422332121f51
3
+ size 1918917
test_data/20.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66ab36131f7630743676ca7d60c4c52e518f296c17b613449b1a45a7c565bfdd
3
+ size 1834266
test_data/21.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fbb4bad0e82ac5fe66f8e56c5a3c45eefe46c9c96274f258d81f5a8da4f196a
3
+ size 1898715
test_data/22.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15ec9e41c7ec036284dea853035976a8957f75d2974e9821f0f59e082adce622
3
+ size 1898663
test_data/23.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:173f0e81a90fe67458947d317e60c5d5227760a30f84bd172913f29c51604bfe
3
+ size 1772117
test_data/24.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4bc4ddfe353eabc21369b1615b6b1800ace7fce3f052b15e2fe5a04e897a9cf
3
+ size 1933801
test_data/25.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0db0ddefd47d481b373aceb9aa9f6e9fdb671a256435e5e1e6cb78c1f5a650c7
3
+ size 1971978
test_data/26.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a2b0f85da6d5c34c205a54017b18b0c9dbeea04b061f45070cbc4b1dca36e70
3
+ size 1802038
test_data/27.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e67efcfe7eec68477e088104f31732719a0f8b3ca86c92cc5e87f1ab1b465370
3
+ size 1633565
test_data/28.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:deb5dbffd27146b0c65378c82707219376c28d40de9d267acdb3f941fb8f3f87
3
+ size 1462921
test_data/29.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f4e36c7e5203a3097929ffc987bd802ba0d4b7e2d4641a22623938bea0e4a94
3
+ size 1919319
test_data/3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fce25bb09db8ed1d7d1537573ea86b614c095fe0398227a2cfbbaac70ac190f2
3
+ size 1987452