Upload 64 files
This view is limited to 50 files because it contains too many changes.
- app.py +156 -0
- base/__init__.py +0 -0
- base/base_modules.py +301 -0
- base/base_segmentation.py +230 -0
- base/base_wandb_model.py +107 -0
- checkpoint/stage_i.pkl +3 -0
- checkpoint/stage_ii.pkl +3 -0
- ckpt/BiomedCLIP/biomed-vlp-eval.svg +1 -0
- ckpt/BiomedCLIP/biomed_clip_example.ipynb +0 -0
- ckpt/BiomedCLIP/config.json +17 -0
- ckpt/BiomedCLIP/open_clip_config.json +31 -0
- ckpt/BiomedCLIP/special_tokens_map.json +7 -0
- ckpt/BiomedCLIP/tokenizer.json +0 -0
- ckpt/BiomedCLIP/tokenizer_config.json +15 -0
- ckpt/BiomedCLIP/vocab.txt +0 -0
- configs/confocal.cfg +36 -0
- configs/confocal_marker.cfg +26 -0
- configs/convertion.cfg +25 -0
- configs/extend_1.cfg +36 -0
- configs/extend_2.cfg +26 -0
- configs/full.cfg +39 -0
- configs/imc.cfg +36 -0
- configs/translation.cfg +36 -0
- markers.py +136 -0
- models/modules/biomedclip.py +114 -0
- models/modules/dct.py +305 -0
- models/modules/networks.py +714 -0
- test_data/1.npz +3 -0
- test_data/10.npz +3 -0
- test_data/11.npz +3 -0
- test_data/12.npz +3 -0
- test_data/13.npz +3 -0
- test_data/14.npz +3 -0
- test_data/15.npz +3 -0
- test_data/16.npz +3 -0
- test_data/17.npz +3 -0
- test_data/18.npz +3 -0
- test_data/19.npz +3 -0
- test_data/2.npz +3 -0
- test_data/20.npz +3 -0
- test_data/21.npz +3 -0
- test_data/22.npz +3 -0
- test_data/23.npz +3 -0
- test_data/24.npz +3 -0
- test_data/25.npz +3 -0
- test_data/26.npz +3 -0
- test_data/27.npz +3 -0
- test_data/28.npz +3 -0
- test_data/29.npz +3 -0
- test_data/3.npz +3 -0
app.py
ADDED
@@ -0,0 +1,156 @@
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import gradio as gr
import yaml
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from models.modules.networks import PromptAttentionUNet, HighResEnhancer
from models.modules.biomedclip import BiomedCLIPTextEncoder
from monai.inferers import sliding_window_inference
from markers import breast_markers, prostatic_markers, pancreatic_markers

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load the convertion model
# load the cfg file for convertion
cfg_file = 'configs/{}.cfg'.format('convertion')
with open(cfg_file, 'r') as f:
    cfg = yaml.safe_load(f)
print("successfully loaded config file: ", cfg)

# convertion models
convertion_ckpt = './checkpoint/stage_ii.pkl'
convertion_net = PromptAttentionUNet(in_channels=cfg['MODEL']['IMC_IN'], out_channels=cfg['MODEL']['IMC_OUT'], channels=(128, 256, 512, 1024, 2048))
prompt_model = BiomedCLIPTextEncoder(device=device)

# load state_dict
state_dict = torch.load(convertion_ckpt, map_location='cpu')['generator']
# remove all the 'module.' prefix
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
convertion_net.load_state_dict(state_dict)

# load the translation model
cfg_file = 'configs/{}.cfg'.format('translation')
with open(cfg_file, 'r') as f:
    cfg = yaml.safe_load(f)
print("successfully loaded config file: ", cfg)
translation_ckpt = './checkpoint/stage_i.pkl'

imc_net = HighResEnhancer(model_name=cfg['MODEL']['TIMM_MODEL'],
                          in_channels=cfg['MODEL']['IMC_IN'],
                          out_channels=cfg['MODEL']['IMC_OUT'],
                          norm=cfg['MODEL']['NORM'],
                          use_dilated_bottleneck=True)

# load state_dict for IMC
state_dict = torch.load(translation_ckpt, map_location='cpu')['imc_G']
# remove all the 'module.0.' prefix
state_dict = {k.replace('module.0.', ''): v for k, v in state_dict.items()}
# remove the key "sobel.filter" in the state_dict
state_dict.pop('sobel.filter.weight')
imc_net.load_state_dict(state_dict, strict=False)

convertion_net.eval().to(device)
imc_net.eval().to(device)

# load the metadata for demo data
df = pd.read_csv('./test_data/test_metadata.csv')
breast_df = df[df['source'] == 'BreastCancer_V2']
prostatic_df = df[df['source'] == 'ProstaticCancer_V2']
pancreatic_df = df[df['source'] == 'PancreaticCancer_V2']


def load_image(pair_index):
    # select the item from the dataframe and convert to Series using `squeeze()`
    item = df[df['name'] == pair_index].squeeze()
    data = np.load(item['path'])['arr_0']
    x1 = data[:, :, 0]
    x2 = data[:, :, 1]
    return gr.Image(value=x1), gr.Image(value=x2)


def generate_imc(x1, x2, marker_name):
    # stage I
    inputs = np.concatenate([x1, x2[:, :, 2:3]], axis=2)
    # normalize to [0, 1]
    inputs = inputs / 255.0
    # to tensor
    inputs = torch.from_numpy(inputs.transpose(2, 0, 1)).unsqueeze(0).float()
    # rescale to [-1, 1]
    inputs = 2 * inputs - 1
    output = sliding_window_inference(inputs.to(device), roi_size=(320, 320), sw_batch_size=2, predictor=imc_net, overlap=0.5)
    output = F.tanh(output)
    # to numpy
    pred_nuclei = output[0].detach().cpu().numpy().transpose(1, 2, 0)
    pred_nuclei = (pred_nuclei + 1) / 2  # normalize to [0, 1]
    # stage II
    nuclei_inputs = torch.from_numpy(pred_nuclei).permute(2, 0, 1).unsqueeze(0).float()
    # rescale to [-1, 1]
    nuclei_inputs = 2 * nuclei_inputs - 1
    prompt_in = torch.as_tensor(prompt_model([marker_name])).to(device)
    output = F.tanh(convertion_net(nuclei_inputs.to(device), prompt_in))
    marker = output[0].detach().cpu().numpy().transpose(1, 2, 0)
    marker = (marker + 1) / 2  # normalize to [0, 1]
    # visualization
    vis = np.concatenate([marker, np.zeros_like(pred_nuclei, dtype=np.float32), pred_nuclei], axis=2)
    # normalize to [0, 255] and convert to uint8
    vis = (vis * 255).astype(np.uint8)
    return gr.Image(value=vis)


# Function to update the second dropdown based on the first dropdown's selection
def update_dropdown_by_tissue(selected_category):
    if selected_category == "Breast":
        image_selector = gr.Dropdown(choices=breast_df['name'].values.tolist(), value=breast_df['name'].values[0], interactive=True)
        marker_selector = gr.Dropdown(choices=breast_markers, value=breast_markers[0], interactive=True)
    elif selected_category == "Pancreatic":
        image_selector = gr.Dropdown(choices=pancreatic_df['name'].values.tolist(), value=pancreatic_df['name'].values[0], interactive=True)
        marker_selector = gr.Dropdown(choices=pancreatic_markers, value=pancreatic_markers[0], interactive=True)
    elif selected_category == "Prostatic":
        image_selector = gr.Dropdown(choices=prostatic_df['name'].values.tolist(), value=prostatic_df['name'].values[0], interactive=True)
        marker_selector = gr.Dropdown(choices=prostatic_markers, value=prostatic_markers[0], interactive=True)
    return [image_selector, marker_selector]


# Create the Gradio interface
def create_gradio_ui():
    with gr.Blocks() as demo:
        with gr.Tab("Mbi2Spi"):
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Row():
                        # image visualizer
                        brightfield = gr.Image(label="Brightfield Image", type="numpy", interactive=False)
                        aux = gr.Image(type="numpy", visible=False, interactive=False)

                    with gr.Row():
                        with gr.Column():
                            # tissue selector (Breast, Pancreatic, Prostatic)
                            tissue_selector = gr.Dropdown(choices=["Breast", "Pancreatic", "Prostatic"], label="Select Tissue Type")
                            # marker selector
                            marker_selector = gr.Dropdown(label="Marker Selector", interactive=False)

                        with gr.Column():
                            # image selector
                            image_selector = gr.Dropdown(label="Brightfield Selector", interactive=False)
                            # update the image selector based on the tissue type
                            tissue_selector.change(update_dropdown_by_tissue, inputs=tissue_selector, outputs=[image_selector, marker_selector])

                with gr.Column(scale=1):
                    output_image = gr.Image(label="Generated Image", type="numpy")
                    button1 = gr.Button("Predict IMC")

            # Load the selected image and update the input image and infrared image
            image_selector.change(load_image, inputs=image_selector, outputs=[brightfield, aux])

            # Event handler for button click
            button1.click(generate_imc, inputs=[brightfield, aux, marker_selector], outputs=output_image)

    return demo

# Launch the demo
if __name__ == '__main__':
    demo = create_gradio_ui()
    demo.launch(show_error=True)
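For reference, the same two-stage pipeline can be exercised without the Gradio UI. The sketch below is illustrative only: it assumes the model-loading code at the top of app.py has already run (so device, imc_net, convertion_net and prompt_model exist), it assumes the demo channels can be replicated to the H x W x 3 uint8 arrays that generate_imc receives from Gradio, and '1.npz' and 'CD45' are placeholder choices. Feeding the tanh output of stage I straight into stage II is equivalent to the [0, 1] round trip in generate_imc.

import numpy as np
import torch
import torch.nn.functional as F
from monai.inferers import sliding_window_inference

data = np.load('./test_data/1.npz')['arr_0']
x1 = np.repeat(data[:, :, 0:1], 3, axis=2)   # brightfield replicated to 3 channels (assumption)
x2 = np.repeat(data[:, :, 1:2], 3, axis=2)   # auxiliary image replicated to 3 channels (assumption)

# stage I: brightfield + one auxiliary channel -> virtual nuclei, as in generate_imc
inputs = np.concatenate([x1, x2[:, :, 2:3]], axis=2) / 255.0
inputs = 2 * torch.from_numpy(inputs.transpose(2, 0, 1)).unsqueeze(0).float() - 1
with torch.no_grad():
    nuclei = F.tanh(sliding_window_inference(inputs.to(device), roi_size=(320, 320),
                                             sw_batch_size=2, predictor=imc_net, overlap=0.5))
    # stage II: virtual nuclei + text prompt -> requested marker channel
    prompt_in = torch.as_tensor(prompt_model(['CD45'])).to(device)
    marker = F.tanh(convertion_net(nuclei, prompt_in))
print(marker.shape)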
base/__init__.py
ADDED
File without changes
base/base_modules.py
ADDED
@@ -0,0 +1,301 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


norm_dict = {'BATCH': nn.BatchNorm2d, 'INSTANCE': nn.InstanceNorm2d, 'GROUP': nn.GroupNorm}
NUM_GROUPS = 16
__all__ = ['ConvNorm', 'ConvBlock', 'ConvBottleNeck', 'ResBlock', 'ResBottleneck', 'PromptResBlock', 'PromptResBottleneck', 'PromptAttentionModule', 'norm_dict', 'SobelEdge']


class Identity(nn.Module):
    """
    Identity mapping for building a residual connection
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class ConvNorm(nn.Module):
    """
    Convolution and normalization
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, leaky=True, norm='INSTANCE', activation=True):
        super().__init__()
        # determine basic attributes
        self.norm_type = norm
        padding = (kernel_size - 1) // 2

        # activation, support PReLU and common ReLU
        if activation:
            self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
            # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        else:
            self.act = None

        # instantiate layers
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)

        if norm in ['BATCH', 'INSTANCE']:
            norm_layer = norm_dict[norm]
            self.norm = norm_layer(out_channels)
        elif norm == 'GROUP':
            norm_layer = norm_dict[norm]
            self.norm = norm_layer(NUM_GROUPS, in_channels)
        elif norm == 'NONE':
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f'Normalization type {norm} not implemented')

    def basic_forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x

    def group_forward(self, x):
        x = self.norm(x)
        if self.act:
            x = self.act(x)
        x = self.conv(x)
        return x

    def forward(self, x):
        if self.norm_type in ['BATCH', 'INSTANCE']:
            return self.basic_forward(x)
        else:
            return self.group_forward(x)


class PromptAttentionModule(nn.Module):
    def __init__(self, in_channels: int, prompt_channels: int, mid_channels: int) -> None:
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.conv_down = nn.Linear(in_channels, mid_channels)
        self.prompt_down = nn.Linear(prompt_channels, mid_channels)
        self.fc = nn.Linear(2 * mid_channels, in_channels)

    def forward(self, x: torch.Tensor, prompt_in: torch.Tensor):
        """
        Args:
            x: (B, C_im, H, W)
            prompt_in: (B, C_prompt)
        """
        x_gap = self.gap(x).squeeze(-1).squeeze(-1)  # (B, C_im)
        x_gap = self.conv_down(x_gap)  # (B, C_mid)
        prompt_down = self.prompt_down(prompt_in)  # (B, C_mid)
        gating = torch.cat([x_gap, prompt_down], dim=-1)  # (B, 2 * C_mid)
        gating = F.sigmoid(self.fc(F.relu(gating)))[..., None, None]  # (B, C_im, 1, 1)
        return x * gating


class ConvBlock(nn.Module):
    """
    Convolutional blocks
    """
    def __init__(self, in_channels, out_channels, stride=1, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        # activation, support PReLU and common ReLU
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)

        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)

        if self.norm_type != 'GROUP':
            out = self.act(out)

        return out


class ResBlock(nn.Module):
    """
    Residual blocks
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)

        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)

        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        identity = self.id(identity)

        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)

        if self.dropout:
            out = self.dropout(out)

        return out


class ConvBottleNeck(nn.Module):
    """
    Convolutional bottleneck blocks
    """
    def __init__(self, in_channels, out_channels, stride=1, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        middle_channels = in_channels // 4
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)

        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.norm_type != 'GROUP':
            out = self.act(out)

        return out


class ResBottleneck(nn.Module):
    """
    Residual bottleneck blocks
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        middle_channels = in_channels // 4
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)

        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)

        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        identity = self.id(identity)

        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)

        if self.dropout:
            out = self.dropout(out)

        return out


class PromptResBlock(nn.Module):
    """
    Residual blocks
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)

        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)
        self.attn = PromptAttentionModule(out_channels, 512, out_channels // 4)

        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x, prompt_in):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.attn(out, prompt_in)
        identity = self.id(identity)

        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)

        if self.dropout:
            out = self.dropout(out)

        return out


class PromptResBottleneck(nn.Module):
    """
    Residual bottleneck blocks
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        middle_channels = in_channels // 4
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)

        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)
        self.attn = PromptAttentionModule(out_channels, 512, out_channels // 4)

        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x, prompt_in):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.attn(out, prompt_in)
        identity = self.id(identity)

        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)

        if self.dropout:
            out = self.dropout(out)

        return out


class SobelEdge(nn.Module):
    def __init__(self, input_dim, channels, kernel_size=3, stride=1):
        super().__init__()
        conv = getattr(nn, 'Conv%dd' % input_dim)
        self.filter = conv(channels, channels, kernel_size, stride, padding=(kernel_size - 1) // 2,
                           groups=channels, bias=False)
        sobel = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
        sobel_kernel = torch.tensor(sobel, dtype=torch.float32).unsqueeze(0).expand([channels, 1] + [kernel_size] * input_dim)
        self.filter.weight = nn.Parameter(sobel_kernel, requires_grad=False)

    def forward(self, x):
        with torch.no_grad():
            out = self.filter(x)
        return out
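Shape sketch (illustrative, not part of base_modules.py): how PromptResBlock consumes a 512-dimensional text prompt embedding, following the PromptAttentionModule docstring above. The import path assumes the repository root is on PYTHONPATH; sizes are arbitrary.

import torch
from base.base_modules import PromptResBlock

block = PromptResBlock(in_channels=64, out_channels=128, stride=2, norm='INSTANCE')
x = torch.randn(2, 64, 320, 320)   # (B, C_im, H, W)
prompt = torch.randn(2, 512)       # (B, C_prompt); 512 matches the hard-coded prompt width above
out = block(x, prompt)
print(out.shape)                   # torch.Size([2, 128, 160, 160])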
base/base_segmentation.py
ADDED
@@ -0,0 +1,230 @@
import os
import numpy as np
import torch
import torch.distributed as dist
from torch.cuda.amp import GradScaler
from abc import ABC, abstractmethod
from utils.iteration.iterator import MetricMeter
from utils.ddp_utils import gather_object_across_processes


class BaseSegmentationModel(ABC):
    """
    This class is an abstract base class (ABC) for segmentation models.
    To create a subclass, you need to implement the following four methods:
    -- <__init__>: initialize the class.
    -- <set_input>: unpack data from dataset.
    -- <optimize_parameters>: calculate losses, gradients, and update network weights.
    -- <evaluate_one_step>: performance evaluation.
    """
    def __init__(self, cfg, num_classes, amp=False):
        # initialize training CUDA devices
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # training configuration
        self.cfg = cfg
        self.num_classes = num_classes
        self.is_mixed = amp
        self.scaler = GradScaler()
        self.start_epoch = -1

        # initialize networks, criterion, optimizer and scheduler
        self.network = None
        self.criterion = None
        self.optimizer = None
        self.scheduler = None

        # visualization
        self.visual_names = []
        self.loss_names = []

    def train(self):
        self.network.train()
        return self

    def eval(self):
        self.network.eval()
        return self

    def training(self):
        return self.network.training

    def initialize_metric_meter(self, class_list):
        self.class_list = class_list
        self.metric_meter = MetricMeter(metrics=['dice', 'hd95', 'asd'], class_names=class_list, subject_names=['name'])
        self.train_loss = MetricMeter(metrics=self.loss_names, class_names=['train'])
        self.val_loss = MetricMeter(metrics=['loss'], class_names=['val'])

    def update_loss_meter(self, print=False):
        loss_dict = {}
        for loss_name in self.loss_names:
            try:
                loss_value = float(getattr(self, loss_name))
                loss_list = gather_object_across_processes(loss_value)
                loss_value = np.mean(loss_list)
            except:
                continue
            loss_dict['train_{}'.format(loss_name)] = loss_value
        self.train_loss.update(loss_dict)
        stats = self.train_loss.report(print_stats=print, mean_only=True)
        return stats

    @abstractmethod
    def set_input(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def optimize_parameters(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def evaluate_one_step(self, *args, **kwargs):
        raise NotImplementedError

    def load_networks(self, ckpt_path, resume_training=False):
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        print('Load ckpt weight: {}'.format(ckpt_path))
        self.network.load_state_dict(checkpoint['net'])
        if resume_training:
            print('Load training config for breakpoint continuation')
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.scaler.load_state_dict(checkpoint['scaler'])
            self.start_epoch = checkpoint['epoch']

    def save_networks(self, epoch_index, save_dir):
        if dist.get_rank() == 0:
            checkpoint = {
                "net": self.network.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
                'scaler': self.scaler.state_dict(),
                "epoch": epoch_index
            }
            torch.save(checkpoint,
                       os.path.join(save_dir, 'Epoch_{}.pkl'.format(epoch_index + 1)))


class MultiNetworkSegmentationModel(ABC):
    """
    This class is an abstract base class (ABC) for segmentation models.
    To create a subclass, you need to implement the following four methods:
    -- <__init__>: initialize the class.
    -- <set_input>: unpack data from dataset.
    -- <optimize_parameters>: calculate losses, gradients, and update network weights.
    -- <evaluate_one_step>: performance evaluation.
    """
    def __init__(self, cfg, num_classes, amp=False):
        # initialize training CUDA devices
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # training configuration
        self.cfg = cfg
        self.num_classes = num_classes
        self.is_mixed = amp
        self.scaler = GradScaler()
        self.start_epoch = -1

        # initialize networks, criterion, optimizer and scheduler
        self.net_names = []

        # visualization
        self.visual_names = []
        self.loss_names = []

    def train(self):
        for name in self.net_names:
            net = getattr(self, name)
            net.train()
        return self

    def eval(self):
        for name in self.net_names:
            net = getattr(self, name)
            net.eval()
        return self

    def training(self):
        return getattr(self, self.net_names[0]).training

    def initialize_metric_meter(self, class_list):
        self.class_list = class_list
        self.metric_meter = MetricMeter(metrics=['dice', 'hd95', 'asd'], class_names=class_list, subject_names=['name'])
        self.train_loss = MetricMeter(metrics=self.loss_names, class_names=['train'])
        self.val_loss = MetricMeter(metrics=['loss'], class_names=['val'])

    def update_loss_meter(self, print=False):
        loss_dict = {}
        for loss_name in self.loss_names:
            try:
                loss_value = float(getattr(self, loss_name))
                loss_list = gather_object_across_processes(loss_value)
                loss_value = np.mean(loss_list)
            except:
                continue
            loss_dict['train_{}'.format(loss_name)] = loss_value
        self.train_loss.update(loss_dict)
        stats = self.train_loss.report(print_stats=print, mean_only=True)
        return stats

    @abstractmethod
    def set_input(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def optimize_parameters(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def evaluate_one_step(self, *args, **kwargs):
        raise NotImplementedError

    def load_networks(self, ckpt_path, resume_training=False, strict=True):
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        print('Load ckpt weight: {}'.format(ckpt_path))
        if resume_training:
            print('Load training config for breakpoint continuation')
            self.scaler.load_state_dict(checkpoint['scaler'])
            self.start_epoch = checkpoint['epoch']
        for name in self.net_names:
            try:
                getattr(self, name).load_state_dict(checkpoint[name], strict=strict)
                if resume_training:
                    getattr(self, '{}_optimizer'.format(name)).load_state_dict(checkpoint['{}_optimizer'.format(name)])
                    getattr(self, '{}_scheduler'.format(name)).load_state_dict(checkpoint['{}_scheduler'.format(name)])
            except:
                print('Failed to load network: {}'.format(name))

    def load_single_network(self, ckpt_path, net_name, resume_training=False, strict=True):
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        print('Load ckpt weight: {}'.format(ckpt_path))
        if resume_training:
            print('Load training config for breakpoint continuation')
            self.scaler.load_state_dict(checkpoint['scaler'])
            self.start_epoch = checkpoint['epoch']
        getattr(self, net_name).load_state_dict(checkpoint[net_name], strict=strict)
        if resume_training:
            getattr(self, '{}_optimizer'.format(net_name)).load_state_dict(checkpoint['{}_optimizer'.format(net_name)])
            getattr(self, '{}_scheduler'.format(net_name)).load_state_dict(checkpoint['{}_scheduler'.format(net_name)])

    def save_networks(self, epoch_index, save_dir):
        if dist.get_rank() == 0:
            checkpoint = {}
            for name in self.net_names:
                checkpoint[name] = getattr(self, name).state_dict()
                checkpoint['{}_optimizer'.format(name)] = getattr(self, '{}_optimizer'.format(name)).state_dict()
                checkpoint['{}_scheduler'.format(name)] = getattr(self, '{}_scheduler'.format(name)).state_dict()
            checkpoint['scaler'] = self.scaler.state_dict()
            checkpoint['epoch'] = epoch_index
            torch.save(checkpoint, os.path.join(save_dir, 'Epoch_{}.pkl'.format(epoch_index)))

    def save_best_networks(self, epoch_index, save_dir):
        if dist.get_rank() == 0:
            checkpoint = {}
            for name in self.net_names:
                checkpoint[name] = getattr(self, name).state_dict()
                checkpoint['{}_optimizer'.format(name)] = getattr(self, '{}_optimizer'.format(name)).state_dict()
                checkpoint['{}_scheduler'.format(name)] = getattr(self, '{}_scheduler'.format(name)).state_dict()
            checkpoint['scaler'] = self.scaler.state_dict()
            checkpoint['epoch'] = epoch_index
            torch.save(checkpoint, os.path.join(save_dir, 'Epoch_best.pkl'))
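Minimal subclass sketch (illustrative only): the docstring above says a concrete model implements __init__, set_input, optimize_parameters and evaluate_one_step. The network, loss, optimizer and batch keys below are placeholders, not the repository's actual models.

import torch
import torch.nn as nn
from base.base_segmentation import BaseSegmentationModel

class ToySegModel(BaseSegmentationModel):
    def __init__(self, cfg, num_classes, amp=False):
        super().__init__(cfg, num_classes, amp)
        self.network = nn.Conv2d(1, num_classes, 3, padding=1).to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=1e-4)
        self.loss_names = ['seg_loss']

    def set_input(self, batch):
        # unpack one batch from the dataloader (keys are assumptions)
        self.image = batch['image'].to(self.device)
        self.label = batch['label'].to(self.device)

    def optimize_parameters(self):
        self.optimizer.zero_grad()
        pred = self.network(self.image)
        self.seg_loss = self.criterion(pred, self.label)
        self.seg_loss.backward()
        self.optimizer.step()

    def evaluate_one_step(self):
        with torch.no_grad():
            return self.network(self.image).argmax(dim=1)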
base/base_wandb_model.py
ADDED
@@ -0,0 +1,107 @@
import wandb
import torch
import numpy as np
from monai.visualize import blend_images


class WandBModel:
    """
    Enable WandB features to the model using multiple inheritance
    """
    def __init__(self, *args, **kwargs):
        # the following attributes should be initialized by class `BaseSegmentationModel`
        self.visual_pairs = None
        self.train_loss = None
        self.val_loss = None
        self.metric_meter = None
        self.name = None
        # the following attributes should be initialized by the child class
        self.val_table = None

    def volume2videos(self, time_dim=3, tag=''):
        """
        Convert 3D volumes to video in favor of WandB logging
        Args:
            time_dim: the spatial dimension to be converted as the time dimension, default is the axial axis (dim 3)
            tag: extra information for logging
        """
        videos = []
        for image_pair in self.visual_pairs:
            try:
                pair_name = getattr(self, image_pair['name'])
                image = getattr(self, image_pair['image'])
                mask = getattr(self, image_pair['mask'])
                vis_type = image_pair['type']
            except:
                continue
            for i in range(image.shape[0]):  # deallocate the batch dim
                image2save = image[i, ...]
                mask2save = mask[i, ...]
                item_name = pair_name[i]
                # detach the tensor, format [C, H, W, D]
                image_numpy = image2save.detach()
                mask_numpy = mask2save.detach()
                if mask_numpy.shape[0] > 1:
                    mask_numpy = torch.argmax(mask_numpy, dim=0, keepdim=True)
                # (C, H, W, D), torch.Tensor on device
                pair_blend = blend_images(image_numpy, mask_numpy, alpha=0.5) * 255
                # permute the axes to (time, channel, height, width)
                spatial_dim = list(range(1, len(pair_blend.shape[1:]) + 1))
                spatial_dim.remove(time_dim)
                pair_blend = pair_blend.permute([time_dim, 0] + spatial_dim).cpu().numpy().astype(np.uint8)
                # record in the wandb.Video class
                video = wandb.Video(pair_blend, fps=8, caption='{}_{}{}'.format(item_name, vis_type, tag))
                videos.append(video)
        return videos

    def log_scaler(self, key, value, step=None):
        """
        Log manually defined scaler data
        """
        wandb.log({key: np.round(value, decimals=4)}, step=step)

    def log_train_loss(self, step=None):
        """
        Log train loss
        """
        data_dict = self.train_loss.pop_data(True)
        for key, value in data_dict.items():
            wandb.log({'train/{}'.format(key): value}, step=step)

    def log_val_loss(self, step=None):
        """
        Log val loss
        """
        data_dict = self.val_loss.pop_data(True)
        for key, value in data_dict.items():
            wandb.log({'val/{}'.format(key): value}, step=step)

    def log_metrics(self, step=None):
        """
        Log validation metrics as wandb.Table
        """
        df = self.metric_meter.to_df()
        wandb.log({'val/metrics': wandb.Table(dataframe=df)}, step=step)

    def log_vis(self, key, step=None, time_dim=3, tag=''):
        """
        Log training intermediate visualizations
        """
        videos = self.volume2videos(time_dim, tag)
        wandb.log({key: videos}, step=step)

    def update_val_visualization(self, time_dim=3, tag=''):
        """
        Update the validation visualization to buffer, called every step of evaluation
        """
        videos = self.volume2videos(time_dim, tag)
        self.val_table.add_data(self.name, *videos)

    def log_val_visualization(self, step=None):
        """
        Log validation visualization
        """
        wandb.log({'val/visualization': self.val_table}, step=step)
        # re-initialize the table for next logging
        del self.val_table
        self.val_table = wandb.Table(columns=['ID'] + [pair['type'] for pair in self.visual_pairs])
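Composition sketch (illustrative): as its docstring notes, WandBModel is meant to be mixed into a model class via multiple inheritance rather than used on its own. The class name and constructor wiring below are assumptions; a real subclass would also implement the abstract methods of the base class and populate self.visual_pairs and self.val_table.

from base.base_segmentation import MultiNetworkSegmentationModel
from base.base_wandb_model import WandBModel

class TranslationModelWithLogging(MultiNetworkSegmentationModel, WandBModel):
    def __init__(self, cfg, num_classes, amp=False):
        MultiNetworkSegmentationModel.__init__(self, cfg, num_classes, amp)
        WandBModel.__init__(self)
        # networks, losses, self.visual_pairs and self.val_table would be set up here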
checkpoint/stage_i.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aaad39f42a0f916ba44b39071dcfbf1145ee43f6f5a269e3f4364b81d361d794
size 494807162
checkpoint/stage_ii.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:155209bbe587905366b100cf2e8fadc9e8b9c672a0920eb848fcb80a3fcd5e8c
size 425297586
ckpt/BiomedCLIP/biomed-vlp-eval.svg
ADDED
ckpt/BiomedCLIP/biomed_clip_example.ipynb
ADDED
The diff for this file is too large to render.
ckpt/BiomedCLIP/config.json
ADDED
@@ -0,0 +1,17 @@
{
  "architectures": [
    "BertForMaskedLM"
  ],
  "model_type": "bert",
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}
ckpt/BiomedCLIP/open_clip_config.json
ADDED
@@ -0,0 +1,31 @@
{
  "model_cfg": {
    "embed_dim": 512,
    "vision_cfg": {
      "timm_model_name": "vit_base_patch16_224",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "image_size": 224
    },
    "text_cfg": {
      "hf_model_name": "./ckpt/BiomedCLIP/",
      "hf_tokenizer_name": "./ckpt/BiomedCLIP/",
      "hf_proj_type": "mlp",
      "hf_pooler_type": "cls_last_hidden_state_pooler",
      "context_length": 77
    }
  },
  "preprocess_cfg": {
    "mean": [
      0.48145466,
      0.4578275,
      0.40821073
    ],
    "std": [
      0.26862954,
      0.26130258,
      0.27577711
    ]
  }
}
ckpt/BiomedCLIP/special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
ckpt/BiomedCLIP/tokenizer.json
ADDED
The diff for this file is too large to render.
ckpt/BiomedCLIP/tokenizer_config.json
ADDED
@@ -0,0 +1,15 @@
{
  "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "do_basic_tokenize": true,
  "do_lower_case": true,
  "mask_token": "[MASK]",
  "model_max_length": 1000000000000000019884624838656,
  "never_split": null,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}
ckpt/BiomedCLIP/vocab.txt
ADDED
The diff for this file is too large to render.
configs/confocal.cfg
ADDED
@@ -0,0 +1,36 @@
MODEL:
  IMG_SIZE: 1024
  CROP_SIZE: 320
  CONFOCAL_IN: 3  # 3-channel microscope file
  CONFOCAL_OUT: 1  # nuclei
  IMC_IN: 4  # 3-channel microscope file + 1 channel confocal
  IMC_OUT: 1  # nuclei
  GRAD_CKPT: True
  TIMM_MODEL: none
  NORM: INSTANCE
  CONFOCAL_PATH: none
  IMC_PATH: none
TRAIN:
  LR_G: 0.0002
  LR_D: 0.0002
  DECAY: 0.0
  BETA1: 0.5
  EARLY_STAGE: 0
  BURN_IN: 0
  BURN: 500
  RAMPUP: 1000
  EPOCHS: 1000
  BATCHSIZE: 16
  CROP_SAMPLE_NUM: 16
  RATIO: 0.2
  SEED: 42
  PERTURB_PROB: 0.1
  IMC_RATIO: 100.0
  CON_RATIO: 100.0
  SIM_RATIO: 50.0
  EDGE_RATIO: 100.0
  ADV_RATIO: 1.0
  CLR_RATIO: 0.0
  FREQ_RATIO: 0.00001
TEST:
  BATCHSIZE: 32
configs/confocal_marker.cfg
ADDED
@@ -0,0 +1,26 @@
MODEL:
  IMG_SIZE: 1024
  CROP_SIZE: 320
  IMC_IN: 1
  IMC_OUT: 1
  GRAD_CKPT: True
  PRETRAIN: none
TRAIN:
  LR_G: 0.002
  LR_D: 0.002
  DECAY: 0.0
  BETA1: 0.5
  EARLY_STAGE: 0
  BURN_IN: 0
  BURN: 100
  RAMPUP: 100
  EPOCHS: 100
  BATCHSIZE: 8
  CROP_SAMPLE_NUM: 8
  RATIO: 0.2
  SEED: 42
  IMC_RATIO: 100.0
  EDGE_RATIO: 10.0
  ADV_RATIO: 1.0
TEST:
  BATCHSIZE: 16
configs/convertion.cfg
ADDED
@@ -0,0 +1,25 @@
MODEL:
  IMG_SIZE: 1024
  CROP_SIZE: 320
  IMC_IN: 1
  IMC_OUT: 1
  GRAD_CKPT: True
TRAIN:
  LR_G: 0.002
  LR_D: 0.002
  DECAY: 0.0
  BETA1: 0.5
  EARLY_STAGE: 0
  BURN_IN: 0
  BURN: 100
  RAMPUP: 100
  EPOCHS: 100
  BATCHSIZE: 16
  CROP_SAMPLE_NUM: 8
  RATIO: 0.2
  SEED: 42
  IMC_RATIO: 100.0
  EDGE_RATIO: 10.0
  ADV_RATIO: 1.0
TEST:
  BATCHSIZE: 64
configs/extend_1.cfg
ADDED
@@ -0,0 +1,36 @@
MODEL:
  IMG_SIZE: 1024
  CROP_SIZE: 320
  CONFOCAL_IN: 3  # 3-channel microscope file
  CONFOCAL_OUT: 1  # nuclei
  IMC_IN: 4  # 3-channel microscope file + 1 channel confocal
  IMC_OUT: 1  # nuclei
  GRAD_CKPT: True
  TIMM_MODEL: none
  NORM: INSTANCE
  CONFOCAL_PATH: none
  IMC_PATH: none
TRAIN:
  LR_G: 0.0002
  LR_D: 0.0002
  DECAY: 0.0
  BETA1: 0.5
  EARLY_STAGE: 0
  BURN_IN: 0
  BURN: 500
  RAMPUP: 1000
  EPOCHS: 1000
  BATCHSIZE: 16
  CROP_SAMPLE_NUM: 16
  RATIO: 0.2
  SEED: 42
  PERTURB_PROB: 0.1
  IMC_RATIO: 100.0
  CON_RATIO: 100.0
  SIM_RATIO: 50.0
  EDGE_RATIO: 100.0
  ADV_RATIO: 1.0
  CLR_RATIO: 0.0
  FREQ_RATIO: 0.00001
TEST:
  BATCHSIZE: 32
configs/extend_2.cfg
ADDED
@@ -0,0 +1,26 @@
MODEL:
  IMG_SIZE: 1024
  CROP_SIZE: 320
  IMC_IN: 1
  IMC_OUT: 1
  GRAD_CKPT: True
  PRETRAIN: /mnt/shared_storage/zhaoxiangyu/experiments/IMC_translation_v2/checkpoints/convertion/convertion_0918-task_convertion-ratio_0.2/Epoch_39.pkl
TRAIN:
  LR_G: 0.002
  LR_D: 0.002
  DECAY: 0.0
  BETA1: 0.5
  EARLY_STAGE: 0
  BURN_IN: 0
  BURN: 100
  RAMPUP: 100
  EPOCHS: 100
  BATCHSIZE: 8
  CROP_SAMPLE_NUM: 8
  RATIO: 0.2
  SEED: 42
  IMC_RATIO: 100.0
  EDGE_RATIO: 10.0
  ADV_RATIO: 1.0
TEST:
  BATCHSIZE: 16
configs/full.cfg
ADDED
@@ -0,0 +1,39 @@
MODEL:
  IMG_SIZE: 1024
  CROP_SIZE: 320
  CONFOCAL_IN: 3  # 3-channel microscope file
  CONFOCAL_OUT: 1  # nuclei
  IMC_IN: 4  # 3-channel microscope file + 1 channel confocal
  IMC_OUT: 1  # nuclei
  CONVERTION_IN: 1
  CONVERTION_OUT: 1
  GRAD_CKPT: True
  TIMM_MODEL: none
  NORM: INSTANCE
  CONFOCAL_PATH: none
  IMC_PATH: none
  CONVERTION_PATH: /mnt/shared_storage/zhaoxiangyu/experiments/IMC_translation_v2/checkpoints/convertion/convertion_0918-task_convertion-ratio_0.2/Epoch_39.pkl
TRAIN:
  LR_G: 0.0002
  LR_D: 0.0002
  DECAY: 0.0
  BETA1: 0.5
  EARLY_STAGE: 0
  BURN_IN: 0
  BURN: 500
  RAMPUP: 1000
  EPOCHS: 1000
  BATCHSIZE: 16
  CROP_SAMPLE_NUM: 16
  RATIO: 0.2
  SEED: 42
  PERTURB_PROB: 0.1
  IMC_RATIO: 100.0
  CON_RATIO: 100.0
  SIM_RATIO: 50.0
  EDGE_RATIO: 100.0
  ADV_RATIO: 1.0
  CLR_RATIO: 0.0
  FREQ_RATIO: 0.00001
TEST:
  BATCHSIZE: 32
configs/imc.cfg
ADDED
@@ -0,0 +1,36 @@
MODEL:
  IMG_SIZE: 1024
  CROP_SIZE: 320
  CONFOCAL_IN: 3  # 3-channel microscope file
  CONFOCAL_OUT: 1  # nuclei
  IMC_IN: 4  # 3-channel microscope file + 1 channel confocal
  IMC_OUT: 1  # nuclei
  GRAD_CKPT: True
  TIMM_MODEL: none
  NORM: INSTANCE
  CONFOCAL_PATH: none
  IMC_PATH: none
TRAIN:
  LR_G: 0.0002
  LR_D: 0.0002
  DECAY: 0.0
  BETA1: 0.5
  EARLY_STAGE: 0
  BURN_IN: 0
  BURN: 500
  RAMPUP: 1000
  EPOCHS: 1000
  BATCHSIZE: 16
  CROP_SAMPLE_NUM: 16
  RATIO: 0.2
  SEED: 42
  PERTURB_PROB: 0.1
  IMC_RATIO: 100.0
  CON_RATIO: 100.0
  SIM_RATIO: 50.0
  EDGE_RATIO: 100.0
  ADV_RATIO: 1.0
  CLR_RATIO: 0.0
  FREQ_RATIO: 0.00001
TEST:
  BATCHSIZE: 32
configs/translation.cfg
ADDED
@@ -0,0 +1,36 @@
MODEL:
  IMG_SIZE: 1024
  CROP_SIZE: 320
  CONFOCAL_IN: 3  # 3-channel microscope file
  CONFOCAL_OUT: 1  # nuclei
  IMC_IN: 4  # 3-channel microscope file + 1 channel confocal
  IMC_OUT: 1  # nuclei
  GRAD_CKPT: True
  TIMM_MODEL: none
  NORM: INSTANCE
  CONFOCAL_PATH: none
  IMC_PATH: none
TRAIN:
  LR_G: 0.0002
  LR_D: 0.0002
  DECAY: 0.0
  BETA1: 0.5
  EARLY_STAGE: 0
  BURN_IN: 0
  BURN: 500
  RAMPUP: 1000
  EPOCHS: 1000
  BATCHSIZE: 16
  CROP_SAMPLE_NUM: 16
  RATIO: 0.2
  SEED: 42
  PERTURB_PROB: 0.1
  IMC_RATIO: 100.0
  CON_RATIO: 100.0
  SIM_RATIO: 50.0
  EDGE_RATIO: 100.0
  ADV_RATIO: 1.0
  CLR_RATIO: 0.0
  FREQ_RATIO: 0.00001
TEST:
  BATCHSIZE: 32
markers.py
ADDED
@@ -0,0 +1,136 @@
breast_markers = ['HER2', 'TAPAN8', 'CD15', 'CD206', 'CD11b', 'HLA_DR', 'H3', 'CD8a', 'ISG15', 'CD14',
                  'ZC3HV1', 'Collagen1', 'CD4', 'CD66b', 'ALDH1', 'FOXP3', 'SMA', 'CD24', 'CD44', 'CD54',
                  'PPARG', 'CD31', 'PD1', 'CD19', 'CD69', 'PKCD', 'Ki67', 'ER', 'CD11c', 'CD27',
                  'LPS', 'CD11a', 'PR', 'CD3', 'CD68', 'CD83', 'LTA', 'IFI6', 'CD45', 'CDH1', 'CD62L']

pancreatic_markers = ['PGAM1', 'CD44', 'Amy2A', 'PGK1', 'PGAM5', 'CD99', 'CoL1', 'TALDO', 'ALDOB', 'ALDO',
                      'HK2', 'HK3', 'TPI', 'PKM', 'LDH', 'CK7', 'PDPN', 'HK1', 'NSE', 'AMF',
                      'PFKM', 'CD45', 'PGAM4', 'GAPDH', 'CD31', 'ECAD', 'PGAM2', 'aSMA', 'LDHB']

prostatic_markers = ['CXCR4', 'EGFR', 'LAG-3', 'CD278', 'PSMA', 'CD15', 'CD134', 'CTLA4', 'Nestin', 'CD16',
                     'CD56', 'PD-1', 'CD11b', 'CD66a', 'CXCL12', 'CCR7', 'IDO', 'CD73', 'CD33', 'VEGF',
                     'CD8a', 'aSMA', 'CD14', 'AMACR', 'CD20', 'Ki-67', 'CD4', 'SOX-9', 'B7-H4', 'CD11C',
                     'IFNgamma', 'CD25', 'Pan-Keratin', 'Pan-Actin', 'CD45AR', 'CD74', 'CD276', 'HLA-DR', 'CD31', 'CD45RO',
                     'TGFbeta', 'CD366', 'CD19', 'PSA', 'Foxp3', 'EpCAM', 'GranzymeB', 'BCL-2', 'ARG1', 'CD27',
                     'hFAP', 'PDL-2', 'Keratin8', 'PDL-1', 'CD127', 'CD304', 'CD3', 'CD68', 'AR', 'CD45',
                     'Vista', 'CD62L', 'CD163', 'pan-actin']
models/modules/biomedclip.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch.nn as nn
|
3 |
+
from open_clip.factory import *
|
4 |
+
|
5 |
+
|
6 |
+
# def create_model_and_transforms(
|
7 |
+
# model_name: str,
|
8 |
+
# config: str,
|
9 |
+
#         device: Union[str, torch.device] = 'cpu',
#         cache_dir: Optional[str] = None,
#         force_preprocess_cfg: Optional[Dict[str, Any]] = None,
# ):
#     force_preprocess_cfg = force_preprocess_cfg or {}
#     preprocess_cfg = asdict(PreprocessCfg())
#     with open(config, 'r') as f:
#         config = json.load(f)

#     checkpoint_path = os.path.join(cache_dir, 'open_clip_pytorch_model.bin')
#     preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
#     model_cfg = config['model_cfg']

#     if isinstance(device, str):
#         device = torch.device(device)
#     print(f'Loaded {model_name} model config.')

#     # load pretrained weights for HF text model IFF no CLIP weights being loaded
#     model_cfg['text_cfg']['hf_model_pretrained'] = False

#     model = CustomTextCLIP(**model_cfg)
#     model.to(device=device)

#     print(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
#     load_checkpoint(model, checkpoint_path)

#     # set image preprocessing configuration in model attributes for convenience
#     if getattr(model.visual, 'image_size', None) is not None:
#         # use image_size set on model creation (via config or force_image_size arg)
#         force_preprocess_cfg['size'] = model.visual.image_size
#     set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))

#     pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)

#     preprocess_train = image_transform_v2(
#         pp_cfg,
#         is_train=True,
#         aug_cfg=None,
#     )
#     preprocess_val = image_transform_v2(
#         pp_cfg,
#         is_train=False,
#     )

#     return model, preprocess_train, preprocess_val


def get_my_tokenizer(
        config: str,
        context_length: Optional[int] = None,
        **kwargs,
):
    with open(config, 'r') as f:
        config = json.load(f)

    text_config = config['model_cfg']['text_cfg']
    if 'tokenizer_kwargs' in text_config:
        tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
    else:
        tokenizer_kwargs = kwargs

    if context_length is None:
        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)

    if 'hf_tokenizer_name' in text_config:
        tokenizer = HFTokenizer(
            text_config['hf_tokenizer_name'],
            context_length=context_length,
            **tokenizer_kwargs,
        )
    else:
        tokenizer = SimpleTokenizer(
            context_length=context_length,
            **tokenizer_kwargs,
        )

    return tokenizer


class BiomedCLIPTextEncoder(nn.Module):
    def __init__(self, device: torch.device) -> None:
        super().__init__()
        # self.model, _, _ = create_model_and_transforms(
        #     model_name='hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224',
        #     # config='./ckpt/BiomedCLIP/open_clip_config.json',
        #     cache_dir='./ckpt/BiomedCLIP/'
        # )
        self.model, _, _ = create_model_and_transforms('hf-hub:hsiangyualex/biomedclip4imc')
        self.model.eval()
        self.model.to(device)
        for param in self.model.parameters():
            param.requires_grad = False
        # self.tokenizer = get_my_tokenizer(config='./ckpt/BiomedCLIP/open_clip_config.json')
        self.tokenizer = get_tokenizer('hf-hub:hsiangyualex/biomedclip4imc')
        self.device = device

    @torch.no_grad()
    def forward(self, prompts):
        """
        Args:
            prompts: a series of protein names
        """
        prompts = [f"An imaging mass cytometry staining image of {prompt} protein." for prompt in prompts]
        prompts = self.tokenizer(prompts).to(self.device)
        text_features = self.model.encode_text(prompts).detach()
        return text_features
models/modules/dct.py
ADDED
@@ -0,0 +1,305 @@
import numpy as np
import torch
import torch.nn as nn

try:
    # PyTorch 1.7.0 and newer versions
    import torch.fft

    def dct1_rfft_impl(x):
        return torch.view_as_real(torch.fft.rfft(x, dim=1))

    def dct_fft_impl(v):
        return torch.view_as_real(torch.fft.fft(v, dim=1))

    def idct_irfft_impl(V):
        return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
except ImportError:
    # PyTorch 1.6.0 and older versions
    def dct1_rfft_impl(x):
        return torch.rfft(x, 1)

    def dct_fft_impl(v):
        return torch.rfft(v, 1, onesided=False)

    def idct_irfft_impl(V):
        return torch.irfft(V, 1, onesided=False)


def dct1(x):
    """
    Discrete Cosine Transform, Type I

    :param x: the input signal
    :return: the DCT-I of the signal over the last dimension
    """
    x_shape = x.shape
    x = x.view(-1, x_shape[-1])
    x = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)

    return dct1_rfft_impl(x)[:, :, 0].view(*x_shape)


def idct1(X):
    """
    The inverse of DCT-I, which is just a scaled DCT-I

    Our definition of idct1 is such that idct1(dct1(x)) == x

    :param X: the input signal
    :return: the inverse DCT-I of the signal over the last dimension
    """
    n = X.shape[-1]
    return dct1(X) / (2 * (n - 1))


def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    Vc = dct_fft_impl(v)

    k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)

    return V


def idct(X, norm=None):
    """
    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III

    Our definition of idct is that idct(dct(x)) == x

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last dimension
    """

    x_shape = X.shape
    N = x_shape[-1]

    X_v = X.contiguous().view(-1, x_shape[-1]) / 2

    if norm == 'ortho':
        X_v[:, 0] *= np.sqrt(N) * 2
        X_v[:, 1:] *= np.sqrt(N / 2) * 2

    k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V_t_r = X_v
    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)

    V_r = V_t_r * W_r - V_t_i * W_i
    V_i = V_t_r * W_i + V_t_i * W_r

    V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)

    v = idct_irfft_impl(V)
    x = v.new_zeros(v.shape)
    x[:, ::2] += v[:, :N - (N // 2)]
    x[:, 1::2] += v.flip([1])[:, :N // 2]

    return x.view(*x_shape)


def dct_2d(x, norm=None):
    """
    2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last 2 dimensions
    """
    X1 = dct(x, norm=norm)
    X2 = dct(X1.transpose(-1, -2), norm=norm)
    return X2.transpose(-1, -2)


def idct_2d(X, norm=None):
    """
    The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III

    Our definition of idct is that idct_2d(dct_2d(x)) == x

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last 2 dimensions
    """
    x1 = idct(X, norm=norm)
    x2 = idct(x1.transpose(-1, -2), norm=norm)
    return x2.transpose(-1, -2)


def dct_3d(x, norm=None):
    """
    3-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last 3 dimensions
    """
    X1 = dct(x, norm=norm)
    X2 = dct(X1.transpose(-1, -2), norm=norm)
    X3 = dct(X2.transpose(-1, -3), norm=norm)
    return X3.transpose(-1, -3).transpose(-1, -2)


def idct_3d(X, norm=None):
    """
    The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III

    Our definition of idct is that idct_3d(dct_3d(x)) == x

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last 3 dimensions
    """
    x1 = idct(X, norm=norm)
    x2 = idct(x1.transpose(-1, -2), norm=norm)
    x3 = idct(x2.transpose(-1, -3), norm=norm)
    return x3.transpose(-1, -3).transpose(-1, -2)


class LinearDCT(nn.Linear):
    """Implement any DCT as a linear layer; in practice this executes around
    50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
    increase memory usage.
    :param in_features: size of expected input
    :param type: which dct function in this file to use"""
    def __init__(self, in_features, type, norm=None, bias=False):
        self.type = type
        self.N = in_features
        self.norm = norm
        super(LinearDCT, self).__init__(in_features, in_features, bias=bias)

    def reset_parameters(self):
        # initialise using dct function
        I = torch.eye(self.N)
        if self.type == 'dct1':
            self.weight.data = dct1(I).data.t()
        elif self.type == 'idct1':
            self.weight.data = idct1(I).data.t()
        elif self.type == 'dct':
            self.weight.data = dct(I, norm=self.norm).data.t()
        elif self.type == 'idct':
            self.weight.data = idct(I, norm=self.norm).data.t()
        self.weight.requires_grad = False  # don't learn this!


def apply_linear_2d(x, linear_layer):
    """Can be used with a LinearDCT layer to do a 2D DCT.
    :param x: the input signal
    :param linear_layer: any PyTorch Linear layer
    :return: result of linear layer applied to last 2 dimensions
    """
    X1 = linear_layer(x)
    X2 = linear_layer(X1.transpose(-1, -2))
    return X2.transpose(-1, -2)


def apply_linear_3d(x, linear_layer):
    """Can be used with a LinearDCT layer to do a 3D DCT.
    :param x: the input signal
    :param linear_layer: any PyTorch Linear layer
    :return: result of linear layer applied to last 3 dimensions
    """
    X1 = linear_layer(x)
    X2 = linear_layer(X1.transpose(-1, -2))
    X3 = linear_layer(X2.transpose(-1, -3))
    return X3.transpose(-1, -3).transpose(-1, -2)


class DCTHelper(nn.Module):
    """
    Implement DCT operations and corresponding masking.
    """
    def __init__(self, side_length: int, norm: str = None, cutoff: float = 0.8, data_range: tuple = (-1.0, 1.0)):
        """
        Args:
            side_length: the side length of the image
            norm: the normalization, None or 'ortho'
            cutoff: the cutoff frequency ratio for low-pass filtering
        """
        super().__init__()
        self.dct = LinearDCT(side_length, 'dct')
        self.idct = LinearDCT(side_length, 'idct')
        mask = self.create_circular_mask(side_length, side_length, radius=side_length * cutoff, center=(0, 0))
        self.register_buffer('mask', torch.from_numpy(mask).float()[None, None, ...])
        self.data_range = data_range

    @staticmethod
    def create_circular_mask(h, w, center=None, radius=None):
        if center is None:  # use the middle of the image
            center = (int(w / 2), int(h / 2))
        if radius is None:  # use the smallest distance between the center and image walls
            radius = min(center[0], center[1], w - center[0], h - center[1])

        Y, X = np.ogrid[:h, :w]
        dist_from_center = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)

        mask = dist_from_center <= radius
        return mask

    def run_dct(self, x):
        return apply_linear_2d(x, self.dct)

    def run_idct(self, x):
        return apply_linear_2d(x, self.idct)

    def forward(self, x, mode: str = 'dct'):
        if mode == 'dct':
            return self.run_dct(x)
        elif mode == 'idct':
            return self.run_idct(x)
        else:
            raise ValueError(f"Invalid mode: {mode}")


if __name__ == '__main__':
    x = torch.Tensor(1000, 4096)
    x.normal_(0, 1)
    linear_dct = LinearDCT(4096, 'dct')
    error = torch.abs(dct(x) - linear_dct(x))
    assert error.max() < 1e-3, (error, error.max())
    linear_idct = LinearDCT(4096, 'idct')
    error = torch.abs(idct(x) - linear_idct(x))
    assert error.max() < 1e-3, (error, error.max())
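
A short, hedged example (not part of the uploaded file) showing how the 2D DCT utilities above round-trip an image-sized tensor, and that the LinearDCT matrix form matches the FFT-based transform; the import path assumes the repository root is on PYTHONPATH.

# hedged usage sketch for the DCT utilities
import torch
from models.modules.dct import dct_2d, idct_2d, LinearDCT, apply_linear_2d

x = torch.randn(2, 1, 64, 64)                       # dummy image batch
rec = idct_2d(dct_2d(x, norm='ortho'), norm='ortho')
assert torch.abs(x - rec).max() < 1e-3              # DCT-II / DCT-III round trip

lin = LinearDCT(64, 'dct', norm='ortho')            # same transform as a matrix multiply
assert torch.abs(dct_2d(x, norm='ortho') - apply_linear_2d(x, lin)).max() < 1e-3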
models/modules/networks.py
ADDED
@@ -0,0 +1,714 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from base.base_modules import *
from timm.models import create_model
from functools import partial


class Backbone(nn.Module):
    """
    Model backbone to extract features
    """
    def __init__(self,
                 input_channels: int = 3,
                 channels: tuple = (32, 64, 128, 256, 512),
                 strides: tuple = (2, 2, 2, 2),
                 use_dropout: bool = False,
                 norm: str = 'BATCH',
                 leaky: bool = True):
        """
        Args:
            input_channels: the number of input channels
            channels: length-5 tuple, define the number of channels in each stage
            strides: tuple, define the stride in each stage
            use_dropout: bool, whether to use dropout
            norm: str, normalization type
            leaky: bool, whether to use leaky relu
        """
        super().__init__()
        self.nb_filter = channels
        self.strides = strides + (5 - len(strides)) * (1,)
        res_unit = ResBlock if channels[-1] <= 320 else ResBottleneck

        self.conv0_0 = nn.Sequential(
            nn.Conv2d(input_channels, channels[0], kernel_size=7, stride=self.strides[0], padding=3),
            nn.GroupNorm(1, channels[0]) if norm == 'GROUP' else nn.BatchNorm2d(channels[0]) if norm == 'BATCH' else nn.InstanceNorm2d(channels[0]),
            nn.LeakyReLU() if leaky else nn.ReLU(),
        )
        self.conv1_0 = res_unit(self.nb_filter[0], self.nb_filter[1], self.strides[1], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv2_0 = res_unit(self.nb_filter[1], self.nb_filter[2], self.strides[2], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv3_0 = res_unit(self.nb_filter[2], self.nb_filter[3], self.strides[3], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv4_0 = res_unit(self.nb_filter[3], self.nb_filter[4], self.strides[4], use_dropout=use_dropout, norm=norm, leaky=leaky)

    def forward(self, x):
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(x0_0)
        x2_0 = self.conv2_0(x1_0)
        x3_0 = self.conv3_0(x2_0)
        x4_0 = self.conv4_0(x3_0)
        return x0_0, x1_0, x2_0, x3_0, x4_0


class TimmBackbone(nn.Module):
    """
    Timm backbone to extract features, utilizing pretrained weights
    """
    def __init__(self, model_name) -> None:
        super().__init__()
        self.backbone = create_model(model_name, pretrained=True, features_only=True)
        self.determine_nb_filters()

    def determine_nb_filters(self):
        dummy = torch.randn(1, 3, 256, 256)
        out = self.backbone(dummy)
        nb_filters = []
        for o in out:
            nb_filters.append(o.size(1))
        self.nb_filter = nb_filters

    def forward(self, inputs):
        return self.backbone(inputs)


class UNet(nn.Module):
    def __init__(self,
                 model_name: str = None,
                 in_channels: int = 1,
                 out_channels: int = None,
                 channels: tuple = (64, 128, 256, 320, 512),
                 strides: tuple = (2, 2, 2, 2, 2),
                 use_dropout: bool = False,
                 norm: str = 'INSTANCE',
                 leaky: bool = True,
                 use_dilated_bottleneck: bool = False):
        """
        Args:
            model_name: timm model name
            in_channels: the number of input channels
            out_channels: the number of output channels
            channels: length-5 tuple, define the number of channels in each stage
            strides: tuple, define the stride in each stage
            use_dropout: bool, whether to use dropout
            norm: str, normalization type
            leaky: bool, whether to use leaky relu
        """
        super().__init__()
        if model_name not in [None, 'none', 'None']:
            # use a timm backbone; this overrides the other backbone arguments
            self.backbone = TimmBackbone(model_name)
        else:
            self.backbone = Backbone(input_channels=in_channels, channels=channels, strides=strides,
                                     use_dropout=use_dropout, norm=norm, leaky=leaky)
        nb_filter = self.backbone.nb_filter
        res_unit = ResBlock if nb_filter[-1] <= 512 else ResBottleneck

        # decoder
        self.conv3_1 = res_unit(nb_filter[3] + nb_filter[4], nb_filter[3], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv2_2 = res_unit(nb_filter[2] + nb_filter[3], nb_filter[2], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv1_3 = res_unit(nb_filter[1] + nb_filter[2], nb_filter[1], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv0_4 = res_unit(nb_filter[0] + nb_filter[1], nb_filter[0], use_dropout=use_dropout, norm=norm, leaky=leaky)

        # dilated bottleneck: optional
        if use_dilated_bottleneck:
            self.dilation = nn.Sequential(
                nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=1),
                nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
                nn.LeakyReLU() if leaky else nn.ReLU(),
                nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=2, dilation=2),
                nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
                nn.LeakyReLU() if leaky else nn.ReLU(),
                nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=5, dilation=5),
                nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
                nn.LeakyReLU() if leaky else nn.ReLU(),
                nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=1),
                nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
                nn.LeakyReLU() if leaky else nn.ReLU(),
                nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=2, dilation=2),
                nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
                nn.LeakyReLU() if leaky else nn.ReLU(),
                nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=5, dilation=5),
                nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
                nn.LeakyReLU() if leaky else nn.ReLU(),
            )
        else:
            self.dilation = nn.Identity()

        if out_channels is not None:
            self.convds0 = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1, bias=False)
        else:
            self.convds0 = None

    def upsample(self, inputs, target):
        return F.interpolate(inputs, size=target.shape[2:], mode='bilinear', align_corners=False)

    def extract_features(self, x):
        x0, x1, x2, x3, x4 = self.backbone(x)

        x4 = self.dilation(x4)

        x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1))
        x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1))
        x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1))
        x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1))
        return x4, x0_4

    def forward(self, x):
        size = x.shape[2:]
        x0, x1, x2, x3, x4 = self.backbone(x)

        x4 = self.dilation(x4)

        x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1))
        x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1))
        x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1))
        x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1))
        if self.convds0 is not None:
            x_out = self.convds0(x0_4)
            out = F.interpolate(x_out, size=size, mode='bilinear', align_corners=False)
        else:
            out = x0_4
        return out

    def freeze(self):
        # freeze the network
        for p in self.parameters():
            p.requires_grad = False

    def unfreeze(self):
        # unfreeze the network to allow parameter update
        for p in self.parameters():
            p.requires_grad = True


class PromptAttentionUNet(nn.Module):
    def __init__(self,
                 model_name: str = None,
                 in_channels: int = 1,
                 out_channels: int = None,
                 channels: tuple = (64, 128, 256, 320, 512),
                 strides: tuple = (2, 2, 2, 2, 2),
                 use_dropout: bool = False,
                 norm: str = 'INSTANCE',
                 leaky: bool = True,
                 use_dilated_bottleneck: bool = False):
        """
        Args:
            model_name: timm model name
            in_channels: the number of input channels
            out_channels: the number of output channels
            channels: length-5 tuple, define the number of channels in each stage
            strides: tuple, define the stride in each stage
            use_dropout: bool, whether to use dropout
            norm: str, normalization type
            leaky: bool, whether to use leaky relu
        """
        super().__init__()
        if model_name not in [None, 'none', 'None']:
            # use a timm backbone; this overrides the other backbone arguments
            self.backbone = TimmBackbone(model_name)
        else:
            self.backbone = Backbone(input_channels=in_channels, channels=channels, strides=strides,
                                     use_dropout=use_dropout, norm=norm, leaky=leaky)
        nb_filter = self.backbone.nb_filter
        res_unit = PromptResBlock if nb_filter[-1] <= 512 else PromptResBottleneck

        # decoder
        self.conv3_1 = res_unit(nb_filter[3] + nb_filter[4], nb_filter[3], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv2_2 = res_unit(nb_filter[2] + nb_filter[3], nb_filter[2], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv1_3 = res_unit(nb_filter[1] + nb_filter[2], nb_filter[1], use_dropout=use_dropout, norm=norm, leaky=leaky)
        self.conv0_4 = res_unit(nb_filter[0] + nb_filter[1], nb_filter[0], use_dropout=use_dropout, norm=norm, leaky=leaky)

        # dilated bottleneck: optional
        if use_dilated_bottleneck:
            self.dilation = nn.Sequential(
                nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=1),
                nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
                nn.LeakyReLU() if leaky else nn.ReLU(),
                nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=2),
                nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
                nn.LeakyReLU() if leaky else nn.ReLU(),
                nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=5),
                nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
                nn.LeakyReLU() if leaky else nn.ReLU(),
                nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=1),
                nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
                nn.LeakyReLU() if leaky else nn.ReLU(),
                nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=2),
                nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
                nn.LeakyReLU() if leaky else nn.ReLU(),
                nn.Conv2d(nb_filter[4], nb_filter[4], kernel_size=3, stride=1, padding=1, dilation=5),
                nn.GroupNorm(16, nb_filter[4]) if norm == 'GROUP' else nn.BatchNorm2d(nb_filter[4]) if norm == 'BATCH' else nn.InstanceNorm2d(nb_filter[4]),
                nn.LeakyReLU() if leaky else nn.ReLU(),
            )
        else:
            self.dilation = nn.Identity()

        if out_channels is not None:
            self.convds0 = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1, bias=False)

    def upsample(self, inputs, target):
        return F.interpolate(inputs, size=target.shape[2:], mode='bilinear', align_corners=False)

    def extract_features(self, x):
        x0, x1, x2, x3, x4 = self.backbone(x)

        x4 = self.dilation(x4)

        x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1))
        x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1))
        x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1))
        x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1))
        return x4, x0_4

    def forward(self, x, prompt_in):
        size = x.shape[2:]
        x0, x1, x2, x3, x4 = self.backbone(x)

        x4 = self.dilation(x4)

        x3_1 = self.conv3_1(torch.cat([x3, self.upsample(x4, x3)], dim=1), prompt_in)
        x2_2 = self.conv2_2(torch.cat([x2, self.upsample(x3_1, x2)], dim=1), prompt_in)
        x1_3 = self.conv1_3(torch.cat([x1, self.upsample(x2_2, x1)], dim=1), prompt_in)
        x0_4 = self.conv0_4(torch.cat([x0, self.upsample(x1_3, x0)], dim=1), prompt_in)
        x_out = self.convds0(x0_4)
        out = F.interpolate(x_out, size=size, mode='bilinear', align_corners=False)
        return out

    def freeze(self):
        # freeze the network
        for p in self.parameters():
            p.requires_grad = False

    def unfreeze(self):
        # unfreeze the network to allow parameter update
        for p in self.parameters():
            p.requires_grad = True


class CLIPDrivenUNet(nn.Module):
    def __init__(self, encoding: str, model_name: str = None, in_channels: int = 1, out_channels: int = 1, channels: tuple = (32, 64, 128, 256, 512),
                 strides: tuple = (2, 2, 2, 2, 2), norm: str = 'INSTANCE', leaky: bool = True) -> None:
        super().__init__()
        self.encoding = encoding
        self.num_classes = out_channels
        self.backbone = UNet(model_name=model_name, in_channels=in_channels, out_channels=None, channels=channels,
                             strides=strides, use_dropout=False, norm=norm, leaky=leaky)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.precls_conv = nn.Sequential(
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(),
            nn.Conv2d(32, 8, kernel_size=1)
        )

        self.weight_nums = [8*8, 8*8, 8*1]
        self.bias_nums = [8, 8, 1]
        self.controller = nn.Conv2d(256 + channels[-1], sum(self.weight_nums + self.bias_nums), kernel_size=1, stride=1, padding=0)
        if encoding == 'CLIP':
            self.register_buffer('protein_embedding', torch.randn(self.num_classes, 512))
            self.text_to_vision = nn.Linear(512, 256)
        elif encoding == 'RAND':
            self.register_buffer('protein_embedding', torch.randn(self.num_classes, 256))

    def parse_dynamic_params(self, params, channels, weight_nums, bias_nums):
        assert params.dim() == 2
        assert len(weight_nums) == len(bias_nums)
        assert params.size(1) == sum(weight_nums) + sum(bias_nums)

        num_insts = params.size(0)
        num_layers = len(weight_nums)

        params_splits = list(torch.split_with_sizes(
            params, weight_nums + bias_nums, dim=1
        ))

        weight_splits = params_splits[:num_layers]
        bias_splits = params_splits[num_layers:]

        for l in range(num_layers):
            if l < num_layers - 1:
                weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1)
                bias_splits[l] = bias_splits[l].reshape(num_insts * channels)
            else:
                weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1)
                bias_splits[l] = bias_splits[l].reshape(num_insts * 1)
            # print(weight_splits[l].shape, bias_splits[l].shape)

        return weight_splits, bias_splits

    def heads_forward(self, features, weights, biases, num_insts):
        n_layers = len(weights)
        x = features
        for i, (w, b) in enumerate(zip(weights, biases)):
            x = F.conv2d(
                x, w, bias=b,
                stride=1, padding=0,
                groups=num_insts
            )
            if i < n_layers - 1:
                x = F.leaky_relu(x)
        return x

    def forward(self, x_in):
        out_shape = x_in.shape[2:]
        dec4, out = self.backbone.extract_features(x_in)  # dec4: (B, channels[-1], H, W), out: (B, channels[0], H, W)

        if self.encoding == 'RAND':
            task_encoding = self.protein_embedding[..., None, None]  # (num_classes, 256, 1, 1)
        elif self.encoding == 'CLIP':
            task_encoding = F.leaky_relu(self.text_to_vision(self.protein_embedding))[..., None, None]  # (num_classes, 256, 1, 1)
        else:
            raise NotImplementedError
        x_feat = self.gap(dec4)
        b = x_feat.shape[0]
        logits_array = []
        for i in range(b):
            x_cond = torch.cat([x_feat[i].unsqueeze(0).repeat(self.num_classes, 1, 1, 1), task_encoding], 1)
            params = self.controller(x_cond)  # (num_classes, num_params, 1, 1)
            params.squeeze_(-1).squeeze_(-1)  # (num_classes, num_params)

            head_inputs = self.precls_conv(out[i].unsqueeze(0))
            head_inputs = head_inputs.repeat(self.num_classes, 1, 1, 1)  # (num_classes, 8, H, W)
            N, _, H, W = head_inputs.size()
            head_inputs = head_inputs.reshape(1, -1, H, W)
            # print(head_inputs.shape, params.shape)
            weights, biases = self.parse_dynamic_params(params, 8, self.weight_nums, self.bias_nums)

            logits = self.heads_forward(head_inputs, weights, biases, N)
            logits_array.append(logits.reshape(1, -1, H, W))

        out = torch.cat(logits_array, dim=0)
        out = F.interpolate(out, size=out_shape, mode='bilinear', align_corners=False)
        # print(out.shape)
        return out


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, norm='INSTANCE', ndf=64, n_layers=3):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        norm_layer = norm_dict[norm]
        use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)


class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels, norm_type='INSTANCE'):
        super().__init__()
        nb_filters = [32, 64, 128, 256, 512]
        strides = [2, 2, 2, 2, 2]

        self.layer1 = ConvNorm(in_channels=in_channels, out_channels=nb_filters[0], kernel_size=5, stride=strides[0], norm='NONE', leaky=True)
        self.layer2 = ConvNorm(in_channels=nb_filters[0], out_channels=nb_filters[1], kernel_size=5, stride=strides[1], norm=norm_type, leaky=True)
        self.layer3 = ConvNorm(in_channels=nb_filters[1], out_channels=nb_filters[2], kernel_size=5, stride=strides[2], norm=norm_type, leaky=True)
        self.layer4 = ConvNorm(in_channels=nb_filters[2], out_channels=nb_filters[3], kernel_size=5, stride=strides[3], norm=norm_type, leaky=True)
        self.layer5 = ConvNorm(in_channels=nb_filters[3], out_channels=nb_filters[4], kernel_size=5, stride=strides[4], norm=norm_type, leaky=True)

        self.dense_pred = ConvNorm(in_channels=nb_filters[4], out_channels=1, kernel_size=3, stride=1, norm='NONE', activation=False)

    def forward(self, inputs):
        x1 = self.layer1(inputs)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x5 = self.layer5(x4)
        output = self.dense_pred(x5)
        output_list = [x1, x2, x3, x4, x5, output]
        return output_list


class PromptPatchDiscriminator(nn.Module):
    def __init__(self, in_channels, norm_type='INSTANCE'):
        super().__init__()
        nb_filters = [32, 64, 128, 256, 512]
        strides = [2, 2, 2, 2, 2]

        self.layer1 = ConvNorm(in_channels=in_channels, out_channels=nb_filters[0], kernel_size=5, stride=strides[0], norm='NONE', leaky=True)
        self.layer2 = ConvNorm(in_channels=nb_filters[0], out_channels=nb_filters[1], kernel_size=5, stride=strides[1], norm=norm_type, leaky=True)
        self.layer3 = ConvNorm(in_channels=nb_filters[1], out_channels=nb_filters[2], kernel_size=5, stride=strides[2], norm=norm_type, leaky=True)
        self.layer4 = ConvNorm(in_channels=nb_filters[2], out_channels=nb_filters[3], kernel_size=5, stride=strides[3], norm=norm_type, leaky=True)
        self.layer5 = ConvNorm(in_channels=nb_filters[3], out_channels=nb_filters[4], kernel_size=5, stride=strides[4], norm=norm_type, leaky=True)

        self.attn4 = PromptAttentionModule(in_channels=nb_filters[3], prompt_channels=512, mid_channels=nb_filters[3] // 4)
        self.attn5 = PromptAttentionModule(in_channels=nb_filters[4], prompt_channels=512, mid_channels=nb_filters[4] // 4)

        self.dense_pred = ConvNorm(in_channels=nb_filters[4], out_channels=1, kernel_size=3, stride=1, norm='NONE', activation=False)

    def forward(self, inputs, prompt_in):
        x1 = self.layer1(inputs)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x4 = self.attn4(x4, prompt_in)
        x5 = self.layer5(x4)
        x5 = self.attn5(x5, prompt_in)
        output = self.dense_pred(x5)
        output_list = [x1, x2, x3, x4, x5, output]
        return output_list


class MultiScaleDiscriminator(nn.Module):
    def __init__(self, in_channels, norm='INSTANCE', num_D=3):
        super(MultiScaleDiscriminator, self).__init__()
        self.num_D = num_D
        module = PatchDiscriminator

        for i in range(num_D):
            netD = module(in_channels, norm)
            setattr(self, 'layer' + str(i), netD)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        return model(input)

    def forward(self, input):
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            model = getattr(self, 'layer' + str(num_D - 1 - i))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result


class PromptMultiScaleDiscriminator(nn.Module):
    def __init__(self, in_channels, norm='INSTANCE', num_D=3):
        super(PromptMultiScaleDiscriminator, self).__init__()
        self.num_D = num_D
        module = PromptPatchDiscriminator

        for i in range(num_D):
            netD = module(in_channels, norm)
            setattr(self, 'layer' + str(i), netD)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input, prompt_in):
        return model(input, prompt_in)

    def forward(self, input, prompt_in):
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            model = getattr(self, 'layer' + str(num_D - 1 - i))
            result.append(self.singleD_forward(model, input_downsampled, prompt_in))
            if i != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result


class HighResEnhancer(nn.Module):
    """
    A global-local network for high-resolution generation that enhances boundary information.
    """
    def __init__(self,
                 model_name: str = None,
                 in_channels: int = 1,
                 out_channels: int = None,
                 coarse_channels: tuple = (16, 32, 64, 128, 256),
                 channels: tuple = (32, 64, 128, 256, 512),
                 use_dropout: bool = False,
                 norm: str = 'INSTANCE',
                 leaky: bool = True,
                 use_dilated_bottleneck: bool = False):
        super().__init__()
        # define basic blocks
        self.norm = norm
        self.leaky = leaky
        norm_layer = self.get_norm_layer()
        act_layer = self.get_act_layer()
        res_unit = ResBlock if channels[-1] <= 512 else ResBottleneck

        # check input channels
        assert channels[1] == coarse_channels[2], 'The number of channel-2 for coarse and number of channel-1 for fine branch should be the same.'

        # downsample and edge information extraction:
        # the downsample operation provides the input for the coarse branch
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1)
        # the Sobel filter extracts edge information from the input image
        self.sobel = SobelEdge(input_dim=2, channels=in_channels)
        self.sobel_conv = nn.Sequential(
            nn.Conv2d(in_channels, channels[0], kernel_size=3, stride=2, padding=1),
            norm_layer(channels[0]),
            act_layer()
        )

        # coarse generator: in_channels -> coarse_channels[2]
        # input stride: 0
        # output stride: 4 (as input is already 2x downsampled)
        self.coarse = nn.Sequential(
            nn.Conv2d(in_channels, coarse_channels[0], kernel_size=3, stride=2, padding=1),
            norm_layer(coarse_channels[0]),
            act_layer(),
            res_unit(coarse_channels[0], coarse_channels[1], stride=2),
            res_unit(coarse_channels[1], coarse_channels[2], stride=2),
            res_unit(coarse_channels[2], coarse_channels[3], stride=2),
            res_unit(coarse_channels[3], coarse_channels[4], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(coarse_channels[4], coarse_channels[3], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(coarse_channels[3], coarse_channels[2], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(coarse_channels[2], coarse_channels[2], stride=1),
        )

        # fine generator: used to enhance the generation for better details
        # 1. simple encoder: channels[0] -> channels[1]
        # input stride: 0
        # output stride: 4
        self.fine_encoder = nn.Sequential(
            nn.Conv2d(in_channels, channels[0], kernel_size=3, stride=2, padding=1),
            norm_layer(channels[0]),
            act_layer(),
            nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=2, padding=1),
            norm_layer(channels[1]),
            act_layer()
        )
        # 2. bottleneck: channels[1] -> channels[4]
        # input stride: 4
        # output stride: 16
        self.bottleneck = nn.Sequential(
            res_unit(channels[1], channels[2], stride=2),
            res_unit(channels[2], channels[3], stride=2),
            res_unit(channels[3], channels[4], stride=1),
            res_unit(channels[4], channels[4], stride=1),
        )
        if use_dilated_bottleneck:
            self.bottleneck.add_module('dilated_block_1',
                                       nn.Sequential(
                                           nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=1, dilation=1),
                                           norm_layer(channels[4]),
                                           act_layer()
                                       ))
            self.bottleneck.add_module('dilated_block_2',
                                       nn.Sequential(
                                           nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=2, dilation=2),
                                           norm_layer(channels[4]),
                                           act_layer()
                                       ))
            self.bottleneck.add_module('dilated_block_3',
                                       nn.Sequential(
                                           nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=5, dilation=5),
                                           norm_layer(channels[4]),
                                           act_layer()
                                       ))
            self.bottleneck.add_module('dilated_block_4',
                                       nn.Sequential(
                                           nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=1, dilation=1),
                                           norm_layer(channels[4]),
                                           act_layer()
                                       ))
            self.bottleneck.add_module('dilated_block_5',
                                       nn.Sequential(
                                           nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=2, dilation=2),
                                           norm_layer(channels[4]),
                                           act_layer()
                                       ))
            self.bottleneck.add_module('dilated_block_6',
                                       nn.Sequential(
                                           nn.Conv2d(channels[4], channels[4], kernel_size=3, stride=1, padding=5, dilation=5),
                                           norm_layer(channels[4]),
                                           act_layer()
                                       ))

        # 3. simple decoder: channels[4] -> channels[0]
        # input stride: 16
        # output stride: 2
        self.decoder = nn.Sequential(
            res_unit(channels[4], channels[3], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(channels[3], channels[2], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(channels[2], channels[1], stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            res_unit(channels[1], channels[0], stride=1),
        )

        # output operation that combines both feature branch and edge branch
        # input stride: 2
        # output stride: 0
        self.output = nn.Sequential(
            nn.Conv2d(2 * channels[0], channels[0], kernel_size=3, stride=1, padding=1),
            norm_layer(channels[0]),
            act_layer(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(channels[0], out_channels, kernel_size=1, stride=1, bias=False)
        )

    def get_norm_layer(self):
        if self.norm == 'INSTANCE':
            return partial(nn.InstanceNorm2d, affine=False)
        elif self.norm == 'BATCH':
            return partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
        elif self.norm == 'GROUP':
            return partial(nn.GroupNorm, num_groups=8)
        else:
            raise NotImplementedError(f'Normalization layer {self.norm} is not implemented.')

    def get_act_layer(self):
        if self.leaky:
            return partial(nn.LeakyReLU, inplace=False)
        else:
            return partial(nn.ReLU, inplace=False)

    def forward(self, inputs):
        """
        Args:
            inputs: (B, C, H, W), input IMC image
        """
        # downsample and edge information extraction
        downsampled = self.downsample(inputs)  # 0 -> 2x stride
        edge = self.sobel(inputs)
        edge = self.sobel_conv(edge)

        # coarse generator
        coarse = self.coarse(downsampled)  # 2x stride -> 4x stride
        # fine generator
        fine = self.fine_encoder(inputs)  # 0x stride -> 4x stride
        # add coarse and fine information together
        fine = self.bottleneck(fine + coarse)  # 4x stride -> 16x stride
        fine = self.decoder(fine)  # 16x stride -> 2x stride
        # output operation
        output = self.output(torch.cat([edge, fine], dim=1))
        return output
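
A hedged usage sketch (not part of the uploaded file) of the prompt-conditioned generator above. It assumes the PromptResBlock units from base.base_modules accept a (B, 512) prompt tensor, matching the 512-d prompt_channels used by PromptAttentionModule in the discriminators; shapes and the dummy prompt are placeholders.

# hedged usage sketch for PromptAttentionUNet
import torch
from models.modules.networks import PromptAttentionUNet

net = PromptAttentionUNet(in_channels=1, out_channels=1)
x = torch.randn(1, 1, 256, 256)   # dummy single-channel IMC patch
prompt = torch.randn(1, 512)      # placeholder text embedding (e.g. from BiomedCLIPTextEncoder)
with torch.no_grad():
    y = net(x, prompt)            # output restored to the input spatial size
print(y.shape)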
test_data/1.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d97725edcd248d40aca15fc720f6caf46e55e5f2eab28fa7a28a0e8a1448dc80
size 1890089
test_data/10.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e43c1e80c83dc7898163b54338485fb092c3470326914cd697d700970ba247a
size 1935806
test_data/11.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:144aac1b1e0d4566133eeb62d65e26fe29d430f082e9fcb0b4fd1794df43a406
size 1920270
test_data/12.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8bb067a4445326aa36775a728d4d0bdf8ea622f3dc1683b4d1d14e84b31b4e98
size 1286013
test_data/13.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba481f70e969558838d17e608013cd838d858700fa628b857766ea44060cb96c
size 1858792
test_data/14.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bd5f4c6da4a8092095f9749869480281835a572b11086b82c1c1a6e230792071
size 1851990
test_data/15.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7518a1c2aa2432375794330262570d23631d9a2ebaa4ce924a9ad49df87218b1
size 1905786
test_data/16.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4d452f027d2d05088142810a0e3ab9d5692898685182b6bfd0a64ebc1d033ee
size 1894100
test_data/17.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:43d29ce1f026a7b7f19746dba61e69e6682514164d02dae2f43575eb8f779b77
size 1966934
test_data/18.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a8cec8db943c7f24a6cb63e0d935db729f883ff7ea2fafe72859bcbc9371711
size 1894208
test_data/19.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:84795f8d8f41e27e06c4fa1fa0e1a46e753ff8359b84f6fcc334d50ce28bf144
size 1901645
test_data/2.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6327b7e06cfc10eb075d45715dfb2a1807a7899bafe6d52c5eb5422332121f51
size 1918917
test_data/20.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:66ab36131f7630743676ca7d60c4c52e518f296c17b613449b1a45a7c565bfdd
size 1834266
test_data/21.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2fbb4bad0e82ac5fe66f8e56c5a3c45eefe46c9c96274f258d81f5a8da4f196a
size 1898715
test_data/22.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:15ec9e41c7ec036284dea853035976a8957f75d2974e9821f0f59e082adce622
size 1898663
test_data/23.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:173f0e81a90fe67458947d317e60c5d5227760a30f84bd172913f29c51604bfe
size 1772117
test_data/24.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f4bc4ddfe353eabc21369b1615b6b1800ace7fce3f052b15e2fe5a04e897a9cf
size 1933801
test_data/25.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0db0ddefd47d481b373aceb9aa9f6e9fdb671a256435e5e1e6cb78c1f5a650c7
size 1971978
test_data/26.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9a2b0f85da6d5c34c205a54017b18b0c9dbeea04b061f45070cbc4b1dca36e70
size 1802038
test_data/27.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e67efcfe7eec68477e088104f31732719a0f8b3ca86c92cc5e87f1ab1b465370
size 1633565
test_data/28.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:deb5dbffd27146b0c65378c82707219376c28d40de9d267acdb3f941fb8f3f87
size 1462921
test_data/29.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0f4e36c7e5203a3097929ffc987bd802ba0d4b7e2d4641a22623938bea0e4a94
size 1919319
test_data/3.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fce25bb09db8ed1d7d1537573ea86b614c095fe0398227a2cfbbaac70ac190f2
size 1987452