eksemyashkina committed on
Commit 3ac18a8 · verified · 1 Parent(s): efa4253

Upload 26 files

.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/ARAD_1K_0001.mat filter=lfs diff=lfs merge=lfs -text
+ assets/ARAD_1K_0002.mat filter=lfs diff=lfs merge=lfs -text
+ assets/ARAD_1K_0003.mat filter=lfs diff=lfs merge=lfs -text
+ assets/ARAD_1K_0004.mat filter=lfs diff=lfs merge=lfs -text
+ assets/ARAD_1K_0005.mat filter=lfs diff=lfs merge=lfs -text
+ assets/ARAD_1K_0006.mat filter=lfs diff=lfs merge=lfs -text
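(Rules like these are what `git lfs track` writes into .gitattributes: each matching path, here the ~21-23 MB ARAD_1K ground-truth cubes added below, is stored as a Git LFS pointer rather than a regular blob.)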
app.py ADDED
@@ -0,0 +1,125 @@
import torch
import numpy as np
import gradio as gr
import cv2
import h5py
from test_develop_code.architecture import model_generator
from PIL import Image


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
model = model_generator("mst_plus_plus", "mst_plus_plus.pth").to(device)
model.eval()
wavelengths = np.linspace(400, 700, 31)


def wavelength_to_rgb(wl: float) -> tuple:
    if 380 <= wl <= 440:
        R = -(wl - 440) / (440 - 380)
        G = 0.0
        B = 1.0
    elif 440 < wl <= 490:
        R = 0.0
        G = (wl - 440) / (490 - 440)
        B = 1.0
    elif 490 < wl <= 510:
        R = 0.0
        G = 1.0
        B = -(wl - 510) / (510 - 490)
    elif 510 < wl <= 580:
        R = (wl - 510) / (580 - 510)
        G = 1.0
        B = 0.0
    elif 580 < wl <= 645:
        R = 1.0
        G = -(wl - 645) / (645 - 580)
        B = 0.0
    elif 645 < wl <= 700:
        R = 1.0
        G = 0.0
        B = 0.0
    else:
        R = G = B = 0.0
    return (max(R, 0.0), max(G, 0.0), max(B, 0.0))


def predict(img):
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32) / 255.0
    img = np.transpose(img, (2, 0, 1))
    img_tensor = torch.from_numpy(img).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(img_tensor)
    pred = pred.squeeze(0).cpu().numpy()
    pred = np.clip(pred, 0, 1)
    return pred


def visualize_channel(cube: np.ndarray, index: int) -> Image.Image:
    if cube is None:
        return None
    band = cube[index]
    band = (band - band.min()) / (band.max() - band.min() + 1e-8)
    color = wavelength_to_rgb(wavelengths[index])
    rgb = np.stack([band * c for c in color], axis=-1)
    rgb = (rgb * 255).astype(np.uint8)
    return Image.fromarray(rgb)


def load_mat(mat_file) -> np.ndarray:
    with h5py.File(mat_file.name, "r") as f:
        cube = np.array(f["cube"])
    cube = np.transpose(cube, (0, 2, 1))
    cube = np.clip(cube, 0, 1)
    return cube


def reset_all():
    return None, None, None, None, 0


with gr.Blocks() as demo:
    gr.Markdown("## Spectral Reconstruction")

    with gr.Row():
        with gr.Column():
            rgb_input = gr.Image(type="numpy", label="Upload RGB Image")
            pred_state = gr.State()
        with gr.Column():
            pred_output = gr.Image(label="Prediction Visualization")
            pred_slider = gr.Slider(minimum=0, maximum=30, step=1, label="Channel (Prediction)", value=0)

    with gr.Row():
        with gr.Column():
            mat_input = gr.File(label="Upload .mat file (Ground Truth)")
            gt_state = gr.State()
        with gr.Column():
            gt_output = gr.Image(label="Ground Truth Visualization")
            gt_slider = gr.Slider(minimum=0, maximum=30, step=1, label="Channel (Ground Truth)", value=0)

    clear_btn = gr.Button("Clear")
    rgb_input.change(fn=predict, inputs=rgb_input, outputs=pred_state)
    pred_slider.change(fn=visualize_channel, inputs=[pred_state, pred_slider], outputs=pred_output)

    mat_input.change(fn=load_mat, inputs=mat_input, outputs=gt_state)
    gt_slider.change(fn=visualize_channel, inputs=[gt_state, gt_slider], outputs=gt_output)

    clear_btn.click(fn=reset_all, outputs=[rgb_input, pred_output, pred_state, gt_state, mat_input])

    gr.Examples(
        examples=[
            ["assets/ARAD_1K_0001.jpg", 0, "assets/ARAD_1K_0001.mat", 0],
            ["assets/ARAD_1K_0002.jpg", 0, "assets/ARAD_1K_0002.mat", 0],
            ["assets/ARAD_1K_0003.jpg", 0, "assets/ARAD_1K_0003.mat", 0],
            ["assets/ARAD_1K_0004.jpg", 0, "assets/ARAD_1K_0004.mat", 0],
            ["assets/ARAD_1K_0005.jpg", 0, "assets/ARAD_1K_0005.mat", 0],
        ],
        inputs=[rgb_input, pred_slider, mat_input, gt_slider],
        outputs=[pred_output, gt_output],
        label="Try Examples"
    )


if __name__ == "__main__":
    demo.launch()
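A minimal headless sketch of how the helpers above could be exercised outside the Gradio UI, assuming the repository root as the working directory and the bundled example image; the snippet is illustrative and not part of the commit:

    import cv2
    from app import predict, visualize_channel  # importing app builds the Blocks UI but does not call launch()

    bgr = cv2.imread("assets/ARAD_1K_0001.jpg")          # cv2 loads BGR; predict() converts to RGB internally
    cube = predict(bgr)                                   # (31, H, W) reconstructed cube, clipped to [0, 1]
    visualize_channel(cube, 15).save("band_550nm.png")    # index 15 of the 400-700 nm grid is ~550 nm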
assets/ARAD_1K_0001.jpg ADDED
assets/ARAD_1K_0001.mat ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d75d3689827d69fdb600ef49eceec4692615db5f0bb8882b41cc3cc0f29a139f
size 21896837
assets/ARAD_1K_0002.jpg ADDED
assets/ARAD_1K_0002.mat ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5ee379c0574ac7a225cdd6cd39f942441f48d3401f4d1abe0e26a664ce8f1af3
size 22920548
assets/ARAD_1K_0003.jpg ADDED
assets/ARAD_1K_0003.mat ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0729ab86549ca910803e2d0dcdc76c10087f9a0f4efbb252e08b564bd3a9a741
size 21218912
assets/ARAD_1K_0004.jpg ADDED
assets/ARAD_1K_0004.mat ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:abe2de20dc3b76e1a1701292e7ad1774aef643a88366084eef58d92f7e280a6d
size 22377073
assets/ARAD_1K_0005.jpg ADDED
assets/ARAD_1K_0005.mat ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6fd2f928230e076be0d2c707b657a5e6cdf84108cd53d86aab91450d8ce4222b
size 21042006
assets/ARAD_1K_0006.jpg ADDED
assets/ARAD_1K_0006.mat ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32d2e4ff2732cd16bf3a9103592c9af304598307dac6967990bd25f00b47d9f7
size 22078636
mst_plus_plus.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d285430cc688d08434582eee71bae2d82661be7997af1a68d6636ec25f7a3421
size 6580074
requirements.txt ADDED
@@ -0,0 +1,12 @@
opencv-python==4.11.0.86
einops==0.8.1
torchvision==0.22.0
torch==2.7.0
scipy==1.15.2
h5py==3.13.0
hdf5storage==0.1.19
tqdm==4.67.1
gdown==5.2.0
matplotlib==3.10.1
gradio==5.29.0
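With these pins the Space should also run locally: install with `pip install -r requirements.txt`, fetch the LFS-tracked weights and .mat assets (`git lfs pull`), then start the demo with `python app.py`.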
test_challenge_code/architecture/HDNet.py ADDED
@@ -0,0 +1,397 @@
1
+ import math
+ import torch
2
+ import torch.nn as nn
3
+ def default_conv(in_channels, out_channels, kernel_size, bias=True):
4
+ return nn.Conv2d(
5
+ in_channels, out_channels, kernel_size,
6
+ padding=(kernel_size//2), bias=bias)
7
+
8
+ class MeanShift(nn.Conv2d):
9
+ def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
10
+ super(MeanShift, self).__init__(3, 3, kernel_size=1)
11
+ std = torch.Tensor(rgb_std)
12
+ self.weight.data = torch.eye(3).view(3, 3, 1, 1)
13
+ self.weight.data.div_(std.view(3, 1, 1, 1))
14
+ self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
15
+ self.bias.data.div_(std)
16
+ self.requires_grad = False
17
+
18
+ class BasicBlock(nn.Sequential):
19
+ def __init__(
20
+ self, in_channels, out_channels, kernel_size, stride=1, bias=False,
21
+ bn=True, act=nn.ReLU(True)):
22
+
23
+ m = [nn.Conv2d(
24
+ in_channels, out_channels, kernel_size,
25
+ padding=(kernel_size//2), stride=stride, bias=bias)
26
+ ]
27
+ if bn: m.append(nn.BatchNorm2d(out_channels))
28
+ if act is not None: m.append(act)
29
+ super(BasicBlock, self).__init__(*m)
30
+
31
+ class ResBlock(nn.Module):
32
+ def __init__(
33
+ self, conv=default_conv, n_feat=31, kernel_size=3,
34
+ bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
35
+
36
+ super(ResBlock, self).__init__()
37
+ m = []
38
+ for i in range(2):
39
+ m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
40
+ if bn: m.append(nn.BatchNorm2d(n_feat))
41
+ if i == 0: m.append(act)
42
+
43
+ self.body = nn.Sequential(*m)
44
+ self.res_scale = res_scale
45
+
46
+ def forward(self, x):
47
+ res = self.body(x).mul(self.res_scale)
48
+ res += x
49
+
50
+ return res
51
+
52
+ class Upsampler(nn.Sequential):
53
+ def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
54
+
55
+ m = []
56
+ if (scale & (scale - 1)) == 0: # Is scale = 2^n?
57
+ for _ in range(int(math.log(scale, 2))):
58
+ m.append(conv(n_feat, 4 * n_feat, 3, bias))
59
+ m.append(nn.PixelShuffle(2))
60
+ if bn: m.append(nn.BatchNorm2d(n_feat))
61
+ if act: m.append(act())
62
+ elif scale == 3:
63
+ m.append(conv(n_feat, 9 * n_feat, 3, bias))
64
+ m.append(nn.PixelShuffle(3))
65
+ if bn: m.append(nn.BatchNorm2d(n_feat))
66
+ if act: m.append(act())
67
+ else:
68
+ raise NotImplementedError
69
+
70
+ super(Upsampler, self).__init__(*m)
71
+
72
+ ## add SELayer
73
+ class SELayer(nn.Module):
74
+ def __init__(self, channel, reduction=16):
75
+ super(SELayer, self).__init__()
76
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
77
+ self.conv_du = nn.Sequential(
78
+ nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
79
+ nn.ReLU(inplace=True),
80
+ nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
81
+ nn.Sigmoid()
82
+ )
83
+
84
+ def forward(self, x):
85
+ y = self.avg_pool(x)
86
+ y = self.conv_du(y)
87
+ return x * y
88
+
89
+ ## add SEResBlock
90
+ class SEResBlock(nn.Module):
91
+ def __init__(
92
+ self, conv, n_feat, kernel_size, reduction,
93
+ bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
94
+
95
+ super(SEResBlock, self).__init__()
96
+ modules_body = []
97
+ for i in range(2):
98
+ modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
99
+ if bn: modules_body.append(nn.BatchNorm2d(n_feat))
100
+ if i == 0: modules_body.append(act)
101
+ modules_body.append(SELayer(n_feat, reduction))
102
+ self.body = nn.Sequential(*modules_body)
103
+ self.res_scale = res_scale
104
+
105
+ def forward(self, x):
106
+ res = self.body(x)
107
+ #res = self.body(x).mul(self.res_scale)
108
+ res += x
109
+
110
+ return res
111
+
112
+
113
+ _NORM_BONE = False
114
+
115
+ def constant_init(module, val, bias=0):
116
+ if hasattr(module, 'weight') and module.weight is not None:
117
+ nn.init.constant_(module.weight, val)
118
+ if hasattr(module, 'bias') and module.bias is not None:
119
+ nn.init.constant_(module.bias, bias)
120
+
121
+
122
+ def kaiming_init(module,
123
+ a=0,
124
+ mode='fan_out',
125
+ nonlinearity='relu',
126
+ bias=0,
127
+ distribution='normal'):
128
+ assert distribution in ['uniform', 'normal']
129
+ if distribution == 'uniform':
130
+ nn.init.kaiming_uniform_(
131
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
132
+ else:
133
+ nn.init.kaiming_normal_(
134
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
135
+ if hasattr(module, 'bias') and module.bias is not None:
136
+ nn.init.constant_(module.bias, bias)
137
+
138
+ # depthwise-separable convolution (DSC)
139
+ class DSC(nn.Module):
140
+
141
+ def __init__(self, nin: int) -> None:
142
+ super(DSC, self).__init__()
143
+ self.conv_dws = nn.Conv2d(
144
+ nin, nin, kernel_size=1, stride=1, padding=0, groups=nin
145
+ )
146
+ self.bn_dws = nn.BatchNorm2d(nin, momentum=0.9)
147
+ self.relu_dws = nn.ReLU(inplace=False)
148
+
149
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
150
+
151
+ self.conv_point = nn.Conv2d(
152
+ nin, 1, kernel_size=1, stride=1, padding=0, groups=1
153
+ )
154
+ self.bn_point = nn.BatchNorm2d(1, momentum=0.9)
155
+ self.relu_point = nn.ReLU(inplace=False)
156
+
157
+ self.softmax = nn.Softmax(dim=2)
158
+
159
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
160
+ out = self.conv_dws(x)
161
+ out = self.bn_dws(out)
162
+ out = self.relu_dws(out)
163
+
164
+ out = self.maxpool(out)
165
+
166
+ out = self.conv_point(out)
167
+ out = self.bn_point(out)
168
+ out = self.relu_point(out)
169
+
170
+ m, n, p, q = out.shape
171
+ out = self.softmax(out.view(m, n, -1))
172
+ out = out.view(m, n, p, q)
173
+
174
+ out = out.expand(x.shape[0], x.shape[1], x.shape[2], x.shape[3])
175
+
176
+ out = torch.mul(out, x)
177
+
178
+ out = out + x
179
+
180
+ return out
181
+
182
+ # Efficient Feature Fusion(EFF)
183
+ class EFF(nn.Module):
184
+ def __init__(self, nin: int, nout: int, num_splits: int) -> None:
185
+ super(EFF, self).__init__()
186
+
187
+ assert nin % num_splits == 0
188
+
189
+ self.nin = nin
190
+ self.nout = nout
191
+ self.num_splits = num_splits
192
+ self.subspaces = nn.ModuleList(
193
+ [DSC(int(self.nin / self.num_splits)) for i in range(self.num_splits)]
194
+ )
195
+
196
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
197
+ sub_feat = torch.chunk(x, self.num_splits, dim=1)
198
+ out = []
199
+ for idx, l in enumerate(self.subspaces):
200
+ out.append(self.subspaces[idx](sub_feat[idx]))
201
+ out = torch.cat(out, dim=1)
202
+
203
+ return out
204
+
205
+
206
+ # spatial-spectral domain attention learning(SDL)
207
+ class SDL_attention(nn.Module):
208
+ def __init__(self, inplanes, planes, kernel_size=1, stride=1):
209
+ super(SDL_attention, self).__init__()
210
+
211
+ self.inplanes = inplanes
212
+ self.inter_planes = planes // 2
213
+ self.planes = planes
214
+ self.kernel_size = kernel_size
215
+ self.stride = stride
216
+ self.padding = (kernel_size-1)//2
217
+
218
+ self.conv_q_right = nn.Conv2d(self.inplanes, 1, kernel_size=1, stride=stride, padding=0, bias=False)
219
+ self.conv_v_right = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, bias=False)
220
+ self.conv_up = nn.Conv2d(self.inter_planes, self.planes, kernel_size=1, stride=1, padding=0, bias=False)
221
+ self.softmax_right = nn.Softmax(dim=2)
222
+ self.sigmoid = nn.Sigmoid()
223
+
224
+ self.conv_q_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, bias=False) #g
225
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
226
+ self.conv_v_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, bias=False) #theta
227
+ self.softmax_left = nn.Softmax(dim=2)
228
+
229
+ self.reset_parameters()
230
+
231
+ def reset_parameters(self):
232
+ kaiming_init(self.conv_q_right, mode='fan_in')
233
+ kaiming_init(self.conv_v_right, mode='fan_in')
234
+ kaiming_init(self.conv_q_left, mode='fan_in')
235
+ kaiming_init(self.conv_v_left, mode='fan_in')
236
+
237
+ self.conv_q_right.inited = True
238
+ self.conv_v_right.inited = True
239
+ self.conv_q_left.inited = True
240
+ self.conv_v_left.inited = True
241
+ # HR spatial attention
242
+ def spatial_attention(self, x):
243
+ input_x = self.conv_v_right(x)
244
+ batch, channel, height, width = input_x.size()
245
+
246
+ input_x = input_x.view(batch, channel, height * width)
247
+ context_mask = self.conv_q_right(x)
248
+ context_mask = context_mask.view(batch, 1, height * width)
249
+ context_mask = self.softmax_right(context_mask)
250
+
251
+ context = torch.matmul(input_x, context_mask.transpose(1,2))
252
+ context = context.unsqueeze(-1)
253
+ context = self.conv_up(context)
254
+
255
+ mask_ch = self.sigmoid(context)
256
+
257
+ out = x * mask_ch
258
+
259
+ return out
260
+ # HR spectral attention
261
+ def spectral_attention(self, x):
262
+
263
+ g_x = self.conv_q_left(x)
264
+ batch, channel, height, width = g_x.size()
265
+
266
+ avg_x = self.avg_pool(g_x)
267
+ batch, channel, avg_x_h, avg_x_w = avg_x.size()
268
+
269
+ avg_x = avg_x.view(batch, channel, avg_x_h * avg_x_w).permute(0, 2, 1)
270
+ theta_x = self.conv_v_left(x).view(batch, self.inter_planes, height * width)
271
+ context = torch.matmul(avg_x, theta_x)
272
+ context = self.softmax_left(context)
273
+ context = context.view(batch, 1, height, width)
274
+
275
+ mask_sp = self.sigmoid(context)
276
+
277
+ out = x * mask_sp
278
+
279
+ return out
280
+
281
+ def forward(self, x):
282
+ context_spectral = self.spectral_attention(x)
283
+ context_spatial = self.spatial_attention(x)
284
+ out = context_spatial + context_spectral
285
+ return out
286
+
287
+
288
+ class HDNet(nn.Module):
289
+
290
+ def __init__(self, in_ch=3, out_ch=31, conv=default_conv):
291
+ super(HDNet, self).__init__()
292
+
293
+ n_resblocks = 32
294
+ n_feats = 48
295
+ kernel_size = 3
296
+ act = nn.ReLU(True)
297
+
298
+ # define head module
299
+ m_head = [conv(in_ch, n_feats, kernel_size)]
300
+
301
+ # define body module
302
+ m_body = [
303
+ ResBlock(
304
+ conv, n_feats, kernel_size, act=act, res_scale= 1
305
+ ) for _ in range(n_resblocks)
306
+ ]
307
+ m_body.append(SDL_attention(inplanes = n_feats, planes = n_feats))
308
+ m_body.append(EFF(nin=n_feats, nout=n_feats, num_splits=4))
309
+
310
+ for i in range(1, n_resblocks):
311
+ m_body.append(ResBlock(
312
+ conv, n_feats, kernel_size, act=act, res_scale= 1
313
+ ))
314
+
315
+ m_body.append(conv(n_feats, n_feats, kernel_size))
316
+
317
+ m_tail = [conv(n_feats, out_ch, kernel_size)]
318
+
319
+ self.head = nn.Sequential(*m_head)
320
+ self.body = nn.Sequential(*m_body)
321
+ self.tail = nn.Sequential(*m_tail)
322
+
323
+ def forward(self, x):
324
+ x = self.head(x)
325
+
326
+ res = self.body(x)
327
+ res += x
328
+
329
+ x = self.tail(res)
330
+
331
+ return x
332
+
333
+ # frequency domain learning(FDL)
334
+ class FDL(nn.Module):
335
+ def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False):
336
+ super(FDL, self).__init__()
337
+ self.loss_weight = loss_weight
338
+ self.alpha = alpha
339
+ self.patch_factor = patch_factor
340
+ self.ave_spectrum = ave_spectrum
341
+ self.log_matrix = log_matrix
342
+ self.batch_matrix = batch_matrix
343
+
344
+ def tensor2freq(self, x):
345
+ patch_factor = self.patch_factor
346
+ _, _, h, w = x.shape
347
+ assert h % patch_factor == 0 and w % patch_factor == 0, (
348
+ 'Patch factor should be divisible by image height and width')
349
+ patch_list = []
350
+ patch_h = h // patch_factor
351
+ patch_w = w // patch_factor
352
+ for i in range(patch_factor):
353
+ for j in range(patch_factor):
354
+ patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])
355
+
356
+ y = torch.stack(patch_list, 1)
357
+
358
+ # torch.rfft was removed in PyTorch 1.8+; torch.fft.fft2 with ortho norm plus view_as_real is the documented equivalent
+ return torch.view_as_real(torch.fft.fft2(y, norm='ortho'))
359
+
360
+ def loss_formulation(self, recon_freq, real_freq, matrix=None):
361
+ if matrix is not None:
362
+ weight_matrix = matrix.detach()
363
+ else:
364
+ matrix_tmp = (recon_freq - real_freq) ** 2
365
+ matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha
366
+ if self.log_matrix:
367
+ matrix_tmp = torch.log(matrix_tmp + 1.0)
368
+
369
+ if self.batch_matrix:
370
+ matrix_tmp = matrix_tmp / matrix_tmp.max()
371
+ else:
372
+ matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]
373
+
374
+ matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
375
+ matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
376
+ weight_matrix = matrix_tmp.clone().detach()
377
+
378
+ assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
379
+ 'The values of spectrum weight matrix should be in the range [0, 1], '
380
+ 'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))
381
+
382
+ tmp = (recon_freq - real_freq) ** 2
383
+ freq_distance = tmp[..., 0] + tmp[..., 1]
384
+
385
+ loss = weight_matrix * freq_distance
386
+ return torch.mean(loss)
387
+
388
+ def forward(self, pred, target, matrix=None, **kwargs):
389
+
390
+ pred_freq = self.tensor2freq(pred)
391
+ target_freq = self.tensor2freq(target)
392
+
393
+ if self.ave_spectrum:
394
+ pred_freq = torch.mean(pred_freq, 0, keepdim=True)
395
+ target_freq = torch.mean(target_freq, 0, keepdim=True)
396
+
397
+ return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight
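A quick sanity check for the HDNet generator above (a sketch, assuming the package is importable from the repo root; the FDL loss is training-only and not exercised here):

    import torch
    from test_challenge_code.architecture.HDNet import HDNet

    net = HDNet(in_ch=3, out_ch=31).eval()
    with torch.no_grad():
        out = net(torch.randn(1, 3, 64, 64))
    print(out.shape)  # torch.Size([1, 31, 64, 64]): RGB in, 31-band cube out, spatial size preserved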
test_challenge_code/architecture/HSCNN_Plus.py ADDED
@@ -0,0 +1,77 @@
import torch.nn as nn
import torch


class dfus_block(nn.Module):
    def __init__(self, dim):
        super(dfus_block, self).__init__()
        self.conv1 = nn.Conv2d(dim, 128, 1, 1, 0, bias=False)

        self.conv_up1 = nn.Conv2d(128, 32, 3, 1, 1, bias=False)
        self.conv_up2 = nn.Conv2d(32, 16, 1, 1, 0, bias=False)

        self.conv_down1 = nn.Conv2d(128, 32, 3, 1, 1, bias=False)
        self.conv_down2 = nn.Conv2d(32, 16, 1, 1, 0, bias=False)

        self.conv_fution = nn.Conv2d(96, 32, 1, 1, 0, bias=False)

        #### activation function
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out:[b,c,h,w]
        """
        feat = self.relu(self.conv1(x))
        feat_up1 = self.relu(self.conv_up1(feat))
        feat_up2 = self.relu(self.conv_up2(feat_up1))
        feat_down1 = self.relu(self.conv_down1(feat))
        feat_down2 = self.relu(self.conv_down2(feat_down1))
        feat_fution = torch.cat([feat_up1, feat_up2, feat_down1, feat_down2], dim=1)
        feat_fution = self.relu(self.conv_fution(feat_fution))
        out = torch.cat([x, feat_fution], dim=1)
        return out


class ddfn(nn.Module):
    def __init__(self, dim, num_blocks=78):
        super(ddfn, self).__init__()

        self.conv_up1 = nn.Conv2d(dim, 32, 3, 1, 1, bias=False)
        self.conv_up2 = nn.Conv2d(32, 32, 1, 1, 0, bias=False)

        self.conv_down1 = nn.Conv2d(dim, 32, 3, 1, 1, bias=False)
        self.conv_down2 = nn.Conv2d(32, 32, 1, 1, 0, bias=False)

        dfus_blocks = [dfus_block(dim=128 + 32 * i) for i in range(num_blocks)]
        self.dfus_blocks = nn.Sequential(*dfus_blocks)

        #### activation function
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out:[b,c,h,w]
        """
        feat_up1 = self.relu(self.conv_up1(x))
        feat_up2 = self.relu(self.conv_up2(feat_up1))
        feat_down1 = self.relu(self.conv_down1(x))
        feat_down2 = self.relu(self.conv_down2(feat_down1))
        feat_fution = torch.cat([feat_up1, feat_up2, feat_down1, feat_down2], dim=1)
        out = self.dfus_blocks(feat_fution)
        return out


class HSCNN_Plus(nn.Module):
    def __init__(self, in_channels=3, out_channels=31, num_blocks=30):
        super(HSCNN_Plus, self).__init__()

        self.ddfn = ddfn(dim=in_channels, num_blocks=num_blocks)
        self.conv_out = nn.Conv2d(128 + 32 * num_blocks, out_channels, 1, 1, 0, bias=False)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out:[b,c,h,w]
        """
        fea = self.ddfn(x)
        out = self.conv_out(fea)
        return out
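As with HDNet, HSCNN_Plus maps a 3-channel input to out_channels=31 with the spatial size unchanged; with the default num_blocks=30 the fused feature tensor grows to 128 + 32*30 = 1088 channels before the final 1x1 conv_out.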
test_challenge_code/architecture/MIRNet.py ADDED
@@ -0,0 +1,416 @@
1
+ # --- Imports --- #
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ # from pdb import set_trace as stx
7
+
8
+ def get_pad_layer(pad_type):
9
+ if(pad_type in ['refl','reflect']):
10
+ PadLayer = nn.ReflectionPad2d
11
+ elif(pad_type in ['repl','replicate']):
12
+ PadLayer = nn.ReplicationPad2d
13
+ elif(pad_type=='zero'):
14
+ PadLayer = nn.ZeroPad2d
15
+ else:
16
+ print('Pad type [%s] not recognized'%pad_type)
17
+ return PadLayer
18
+
19
+ class downsamp(nn.Module):
20
+ def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
21
+ super(downsamp, self).__init__()
22
+ self.filt_size = filt_size
23
+ self.pad_off = pad_off
24
+ self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))]
25
+ self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes]
26
+ self.stride = stride
27
+ self.off = int((self.stride-1)/2.)
28
+ self.channels = channels
29
+
30
+ # print('Filter size [%i]'%filt_size)
31
+ if(self.filt_size==1):
32
+ a = np.array([1.,])
33
+ elif(self.filt_size==2):
34
+ a = np.array([1., 1.])
35
+ elif(self.filt_size==3):
36
+ a = np.array([1., 2., 1.])
37
+ elif(self.filt_size==4):
38
+ a = np.array([1., 3., 3., 1.])
39
+ elif(self.filt_size==5):
40
+ a = np.array([1., 4., 6., 4., 1.])
41
+ elif(self.filt_size==6):
42
+ a = np.array([1., 5., 10., 10., 5., 1.])
43
+ elif(self.filt_size==7):
44
+ a = np.array([1., 6., 15., 20., 15., 6., 1.])
45
+
46
+ filt = torch.Tensor(a[:,None]*a[None,:])
47
+ filt = filt/torch.sum(filt)
48
+ self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1)))
49
+
50
+ self.pad = get_pad_layer(pad_type)(self.pad_sizes)
51
+
52
+ def forward(self, inp):
53
+ if(self.filt_size==1):
54
+ if(self.pad_off==0):
55
+ return inp[:,:,::self.stride,::self.stride]
56
+ else:
57
+ return self.pad(inp)[:,:,::self.stride,::self.stride]
58
+ else:
59
+ return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
60
+
61
+ ##########################################################################
62
+
63
+ def conv(in_channels, out_channels, kernel_size, bias=False, padding=1, stride=1):
64
+ return nn.Conv2d(
65
+ in_channels, out_channels, kernel_size,
66
+ padding=(kernel_size // 2), bias=bias, stride=stride)
67
+
68
+
69
+ ##########################################################################
70
+ ##---------- Selective Kernel Feature Fusion (SKFF) ----------
71
+ class SKFF(nn.Module):
72
+ def __init__(self, in_channels, height=3, reduction=8, bias=False):
73
+ super(SKFF, self).__init__()
74
+
75
+ self.height = height
76
+ d = max(int(in_channels / reduction), 4)
77
+
78
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
79
+ self.conv_du = nn.Sequential(nn.Conv2d(in_channels, d, 1, padding=0, bias=bias), nn.PReLU())
80
+
81
+ self.fcs = nn.ModuleList([])
82
+ for i in range(self.height):
83
+ self.fcs.append(nn.Conv2d(d, in_channels, kernel_size=1, stride=1, bias=bias))
84
+
85
+ self.softmax = nn.Softmax(dim=1)
86
+
87
+ def forward(self, inp_feats):
88
+ batch_size = inp_feats[0].shape[0]
89
+ n_feats = inp_feats[0].shape[1]
90
+
91
+ inp_feats = torch.cat(inp_feats, dim=1)
92
+ inp_feats = inp_feats.view(batch_size, self.height, n_feats, inp_feats.shape[2], inp_feats.shape[3])
93
+
94
+ feats_U = torch.sum(inp_feats, dim=1)
95
+ feats_S = self.avg_pool(feats_U)
96
+ feats_Z = self.conv_du(feats_S)
97
+
98
+ attention_vectors = [fc(feats_Z) for fc in self.fcs]
99
+ attention_vectors = torch.cat(attention_vectors, dim=1)
100
+ attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)
101
+ # stx()
102
+ attention_vectors = self.softmax(attention_vectors)
103
+
104
+ feats_V = torch.sum(inp_feats * attention_vectors, dim=1)
105
+
106
+ return feats_V
107
+
108
+ ##########################################################################
109
+
110
+
111
+ ##---------- Spatial Attention ----------
112
+ class BasicConv(nn.Module):
113
+ def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
114
+ bn=False, bias=False):
115
+ super(BasicConv, self).__init__()
116
+ self.out_channels = out_planes
117
+ self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
118
+ dilation=dilation, groups=groups, bias=bias)
119
+ self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
120
+ self.relu = nn.ReLU() if relu else None
121
+
122
+ def forward(self, x):
123
+ x = self.conv(x)
124
+ if self.bn is not None:
125
+ x = self.bn(x)
126
+ if self.relu is not None:
127
+ x = self.relu(x)
128
+ return x
129
+
130
+
131
+ class ChannelPool(nn.Module):
132
+ def forward(self, x):
133
+ return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
134
+
135
+
136
+ class spatial_attn_layer(nn.Module):
137
+ def __init__(self, kernel_size=5):
138
+ super(spatial_attn_layer, self).__init__()
139
+ self.compress = ChannelPool()
140
+ self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)
141
+
142
+ def forward(self, x):
143
+ # import pdb;pdb.set_trace()
144
+ x_compress = self.compress(x)
145
+ x_out = self.spatial(x_compress)
146
+ scale = torch.sigmoid(x_out) # broadcasting
147
+ return x * scale
148
+
149
+
150
+ ##########################################################################
151
+ ## ------ Channel Attention --------------
152
+ class ca_layer(nn.Module):
153
+ def __init__(self, channel, reduction=8, bias=True):
154
+ super(ca_layer, self).__init__()
155
+ # global average pooling: feature --> point
156
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
157
+ # feature channel downscale and upscale --> channel weight
158
+ self.conv_du = nn.Sequential(
159
+ nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
160
+ nn.ReLU(inplace=True),
161
+ nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
162
+ nn.Sigmoid()
163
+ )
164
+
165
+ def forward(self, x):
166
+ y = self.avg_pool(x)
167
+ y = self.conv_du(y)
168
+ return x * y
169
+
170
+
171
+ ##########################################################################
172
+ ##---------- Dual Attention Unit (DAU) ----------
173
+ class DAU(nn.Module):
174
+ def __init__(
175
+ self, n_feat, kernel_size=3, reduction=8,
176
+ bias=False, bn=False, act=nn.PReLU(), res_scale=1):
177
+ super(DAU, self).__init__()
178
+ modules_body = [conv(n_feat, n_feat, kernel_size, bias=bias), act, conv(n_feat, n_feat, kernel_size, bias=bias)]
179
+ self.body = nn.Sequential(*modules_body)
180
+
181
+ ## Spatial Attention
182
+ self.SA = spatial_attn_layer()
183
+
184
+ ## Channel Attention
185
+ self.CA = ca_layer(n_feat, reduction, bias=bias)
186
+
187
+ self.conv1x1 = nn.Conv2d(n_feat * 2, n_feat, kernel_size=1, bias=bias)
188
+
189
+ def forward(self, x):
190
+ res = self.body(x)
191
+ sa_branch = self.SA(res)
192
+ ca_branch = self.CA(res)
193
+ res = torch.cat([sa_branch, ca_branch], dim=1)
194
+ res = self.conv1x1(res)
195
+ res += x
196
+ return res
197
+
198
+
199
+ ##########################################################################
200
+ ##---------- Resizing Modules ----------
201
+ class ResidualDownSample(nn.Module):
202
+ def __init__(self, in_channels, bias=False):
203
+ super(ResidualDownSample, self).__init__()
204
+
205
+ self.top = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0, bias=bias),
206
+ nn.PReLU(),
207
+ nn.Conv2d(in_channels, in_channels, 3, stride=1, padding=1, bias=bias),
208
+ nn.PReLU(),
209
+ downsamp(channels=in_channels, filt_size=3, stride=2),
210
+ nn.Conv2d(in_channels, in_channels * 2, 1, stride=1, padding=0, bias=bias))
211
+
212
+ self.bot = nn.Sequential(downsamp(channels=in_channels, filt_size=3, stride=2),
213
+ nn.Conv2d(in_channels, in_channels * 2, 1, stride=1, padding=0, bias=bias))
214
+
215
+ def forward(self, x):
216
+ top = self.top(x)
217
+ bot = self.bot(x)
218
+ out = top + bot
219
+ return out
220
+
221
+
222
+ class DownSample(nn.Module):
223
+ def __init__(self, in_channels, scale_factor, stride=2, kernel_size=3):
224
+ super(DownSample, self).__init__()
225
+ self.scale_factor = int(np.log2(scale_factor))
226
+
227
+ modules_body = []
228
+ for i in range(self.scale_factor):
229
+ modules_body.append(ResidualDownSample(in_channels))
230
+ in_channels = int(in_channels * stride)
231
+
232
+ self.body = nn.Sequential(*modules_body)
233
+
234
+ def forward(self, x):
235
+ x = self.body(x)
236
+ return x
237
+
238
+
239
+ class ResidualUpSample(nn.Module):
240
+ def __init__(self, in_channels, bias=False):
241
+ super(ResidualUpSample, self).__init__()
242
+
243
+ self.top = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0, bias=bias),
244
+ nn.PReLU(),
245
+ nn.ConvTranspose2d(in_channels, in_channels, 3, stride=2, padding=1, output_padding=1,
246
+ bias=bias),
247
+ nn.PReLU(),
248
+ nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=bias))
249
+
250
+ self.bot = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=bias),
251
+ nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=bias))
252
+
253
+ def forward(self, x):
254
+ top = self.top(x)
255
+ bot = self.bot(x)
256
+ out = top + bot
257
+ return out
258
+
259
+
260
+ class UpSample(nn.Module):
261
+ def __init__(self, in_channels, scale_factor, stride=2, kernel_size=3):
262
+ super(UpSample, self).__init__()
263
+ self.scale_factor = int(np.log2(scale_factor))
264
+
265
+ modules_body = []
266
+ for i in range(self.scale_factor):
267
+ modules_body.append(ResidualUpSample(in_channels))
268
+ in_channels = int(in_channels // stride)
269
+
270
+ self.body = nn.Sequential(*modules_body)
271
+
272
+ def forward(self, x):
273
+ x = self.body(x)
274
+ return x
275
+
276
+
277
+ ##########################################################################
278
+ ##---------- Multi-Scale Resiudal Block (MSRB) ----------
279
+ class MSRB(nn.Module):
280
+ def __init__(self, n_feat, height, width, stride, bias):
281
+ super(MSRB, self).__init__()
282
+
283
+ self.n_feat, self.height, self.width = n_feat, height, width
284
+ self.blocks = nn.ModuleList([nn.ModuleList([DAU(int(n_feat * stride ** i))] * width) for i in range(height)])
285
+
286
+ INDEX = np.arange(0, width, 2)
287
+ FEATS = [int((stride ** i) * n_feat) for i in range(height)]
288
+ SCALE = [2 ** i for i in range(1, height)]
289
+
290
+ self.last_up = nn.ModuleDict()
291
+ for i in range(1, height):
292
+ self.last_up.update({f'{i}': UpSample(int(n_feat * stride ** i), 2 ** i, stride)})
293
+
294
+ self.down = nn.ModuleDict()
295
+ self.up = nn.ModuleDict()
296
+
297
+ i = 0
298
+ SCALE.reverse()
299
+ for feat in FEATS:
300
+ for scale in SCALE[i:]:
301
+ self.down.update({f'{feat}_{scale}': DownSample(feat, scale, stride)})
302
+ i += 1
303
+
304
+ i = 0
305
+ FEATS.reverse()
306
+ for feat in FEATS:
307
+ for scale in SCALE[i:]:
308
+ self.up.update({f'{feat}_{scale}': UpSample(feat, scale, stride)})
309
+ i += 1
310
+
311
+ self.conv_out = nn.Conv2d(n_feat, n_feat, kernel_size=3, padding=1, bias=bias)
312
+
313
+ self.selective_kernel = nn.ModuleList([SKFF(n_feat * stride ** i, height) for i in range(height)])
314
+
315
+ def forward(self, x):
316
+ inp = x.clone()
317
+ # col 1 only
318
+ blocks_out = []
319
+ for j in range(self.height):
320
+ if j == 0:
321
+ inp = self.blocks[j][0](inp)
322
+ else:
323
+ inp = self.blocks[j][0](self.down[f'{inp.size(1)}_{2}'](inp))
324
+ blocks_out.append(inp)
325
+
326
+ # rest of grid
327
+ for i in range(1, self.width):
328
+ # Mesh
329
+ # Replace condition(i%2!=0) with True(Mesh) or False(Plain)
330
+ # if i%2!=0:
331
+ if True:
332
+ tmp = []
333
+ for j in range(self.height):
334
+ TENSOR = []
335
+ nfeats = (2 ** j) * self.n_feat
336
+ for k in range(self.height):
337
+ TENSOR.append(self.select_up_down(blocks_out[k], j, k))
338
+
339
+ selective_kernel_fusion = self.selective_kernel[j](TENSOR)
340
+ tmp.append(selective_kernel_fusion)
341
+ # Plain
342
+ else:
343
+ tmp = blocks_out
344
+ # Forward through either mesh or plain
345
+ for j in range(self.height):
346
+ blocks_out[j] = self.blocks[j][i](tmp[j])
347
+
348
+ # Sum after grid
349
+ out = []
350
+ for k in range(self.height):
351
+ out.append(self.select_last_up(blocks_out[k], k))
352
+
353
+ out = self.selective_kernel[0](out)
354
+
355
+ out = self.conv_out(out)
356
+ out = out + x
357
+
358
+ return out
359
+
360
+ def select_up_down(self, tensor, j, k):
361
+ if j == k:
362
+ return tensor
363
+ else:
364
+ diff = 2 ** np.abs(j - k)
365
+ if j < k:
366
+ return self.up[f'{tensor.size(1)}_{diff}'](tensor)
367
+ else:
368
+ return self.down[f'{tensor.size(1)}_{diff}'](tensor)
369
+
370
+ def select_last_up(self, tensor, k):
371
+ if k == 0:
372
+ return tensor
373
+ else:
374
+ return self.last_up[f'{k}'](tensor)
375
+
376
+
377
+ ##########################################################################
378
+ ##---------- Recursive Residual Group (RRG) ----------
379
+ class RRG(nn.Module):
380
+ def __init__(self, n_feat, n_MSRB, height, width, stride, bias=False):
381
+ super(RRG, self).__init__()
382
+ modules_body = [MSRB(n_feat, height, width, stride, bias) for _ in range(n_MSRB)]
383
+ modules_body.append(conv(n_feat, n_feat, kernel_size=3))
384
+ self.body = nn.Sequential(*modules_body)
385
+
386
+ def forward(self, x):
387
+ res = self.body(x)
388
+ res += x
389
+ return res
390
+
391
+ ##########################################################################
392
+ ##---------- MIRNet -----------------------
393
+ class MIRNet(nn.Module):
394
+ def __init__(self, in_channels=3, out_channels=31, n_feat=31, kernel_size=3, stride=2, n_RRG=2, n_MSRB=1, height=3,
395
+ width=1, bias=False):
396
+ super(MIRNet, self).__init__()
397
+
398
+ self.conv_in = nn.Conv2d(in_channels, n_feat, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
399
+ bias=bias)
400
+
401
+ modules_body = [RRG(n_feat, n_MSRB, height, width, stride, bias) for _ in range(n_RRG)]
402
+ self.body = nn.Sequential(*modules_body)
403
+
404
+ self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
405
+ bias=bias)
406
+ def forward(self, x):
407
+ b, c, h_inp, w_inp = x.shape
408
+ hb, wb = 8, 8
409
+ pad_h = (hb - h_inp % hb) % hb
410
+ pad_w = (wb - w_inp % wb) % wb
411
+ x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
412
+ x = self.conv_in(x)
413
+ h = self.body(x)
414
+ h = self.conv_out(h)
415
+ h += x
416
+ return h[:, :, :h_inp, :w_inp]
test_challenge_code/architecture/MPRNet.py ADDED
@@ -0,0 +1,350 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ ##########################################################################
6
+ def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
7
+ return nn.Conv2d(
8
+ in_channels, out_channels, kernel_size,
9
+ padding=(kernel_size//2), bias=bias, stride = stride)
10
+
11
+
12
+ ##########################################################################
13
+ ## Channel Attention Layer
14
+ class CALayer(nn.Module):
15
+ def __init__(self, channel, reduction=16, bias=False):
16
+ super(CALayer, self).__init__()
17
+ # global average pooling: feature --> point
18
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
19
+ # feature channel downscale and upscale --> channel weight
20
+ self.conv_du = nn.Sequential(
21
+ nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
22
+ nn.ReLU(inplace=True),
23
+ nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
24
+ nn.Sigmoid()
25
+ )
26
+
27
+ def forward(self, x):
28
+ y = self.avg_pool(x)
29
+ y = self.conv_du(y)
30
+ return x * y
31
+
32
+
33
+ ##########################################################################
34
+ ## Channel Attention Block (CAB)
35
+ class CAB(nn.Module):
36
+ def __init__(self, n_feat, kernel_size, reduction, bias, act):
37
+ super(CAB, self).__init__()
38
+ modules_body = []
39
+ modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
40
+ modules_body.append(act)
41
+ modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
42
+
43
+ self.CA = CALayer(n_feat, reduction, bias=bias)
44
+ self.body = nn.Sequential(*modules_body)
45
+
46
+ def forward(self, x):
47
+ res = self.body(x)
48
+ res = self.CA(res)
49
+ res += x
50
+ return res
51
+
52
+ ##########################################################################
53
+ ## Supervised Attention Module
54
+ class SAM(nn.Module):
55
+ def __init__(self, n_feat, kernel_size, bias):
56
+ super(SAM, self).__init__()
57
+ self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
58
+ self.conv2 = conv(n_feat, 31, kernel_size, bias=bias)
59
+ self.conv3 = conv(31, n_feat, kernel_size, bias=bias)
60
+
61
+ def forward(self, x, x_img):
62
+ x1 = self.conv1(x)
63
+ img = self.conv2(x) + x_img
64
+ x2 = torch.sigmoid(self.conv3(img))
65
+ x1 = x1*x2
66
+ x1 = x1+x
67
+ return x1, img
68
+
69
+ ##########################################################################
70
+ ## U-Net
71
+
72
+ class Encoder(nn.Module):
73
+ def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff):
74
+ super(Encoder, self).__init__()
75
+
76
+ self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
77
+ self.encoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
78
+ self.encoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
79
+
80
+ self.encoder_level1 = nn.Sequential(*self.encoder_level1)
81
+ self.encoder_level2 = nn.Sequential(*self.encoder_level2)
82
+ self.encoder_level3 = nn.Sequential(*self.encoder_level3)
83
+
84
+ self.down12 = DownSample(n_feat, scale_unetfeats)
85
+ self.down23 = DownSample(n_feat+scale_unetfeats, scale_unetfeats)
86
+
87
+ # Cross Stage Feature Fusion (CSFF)
88
+ if csff:
89
+ self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
90
+ self.csff_enc2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias)
91
+ self.csff_enc3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias)
92
+
93
+ self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
94
+ self.csff_dec2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias)
95
+ self.csff_dec3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias)
96
+
97
+ def forward(self, x, encoder_outs=None, decoder_outs=None):
98
+ enc1 = self.encoder_level1(x)
99
+ if (encoder_outs is not None) and (decoder_outs is not None):
100
+ enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0])
101
+
102
+ x = self.down12(enc1)
103
+
104
+ enc2 = self.encoder_level2(x)
105
+ if (encoder_outs is not None) and (decoder_outs is not None):
106
+ enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1])
107
+
108
+ x = self.down23(enc2)
109
+
110
+ enc3 = self.encoder_level3(x)
111
+ if (encoder_outs is not None) and (decoder_outs is not None):
112
+ enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2])
113
+
114
+ return [enc1, enc2, enc3]
115
+
116
+ class Decoder(nn.Module):
117
+ def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats):
118
+ super(Decoder, self).__init__()
119
+
120
+ self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
121
+ self.decoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
122
+ self.decoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
123
+
124
+ self.decoder_level1 = nn.Sequential(*self.decoder_level1)
125
+ self.decoder_level2 = nn.Sequential(*self.decoder_level2)
126
+ self.decoder_level3 = nn.Sequential(*self.decoder_level3)
127
+
128
+ self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act)
129
+ self.skip_attn2 = CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act)
130
+
131
+ self.up21 = SkipUpSample(n_feat, scale_unetfeats)
132
+ self.up32 = SkipUpSample(n_feat+scale_unetfeats, scale_unetfeats)
133
+
134
+ def forward(self, outs):
135
+ enc1, enc2, enc3 = outs
136
+ dec3 = self.decoder_level3(enc3)
137
+
138
+ x = self.up32(dec3, self.skip_attn2(enc2))
139
+ dec2 = self.decoder_level2(x)
140
+
141
+ x = self.up21(dec2, self.skip_attn1(enc1))
142
+ dec1 = self.decoder_level1(x)
143
+
144
+ return [dec1,dec2,dec3]
145
+
146
+ ##########################################################################
147
+ ##---------- Resizing Modules ----------
148
+ class DownSample(nn.Module):
149
+ def __init__(self, in_channels,s_factor):
150
+ super(DownSample, self).__init__()
151
+ self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
152
+ nn.Conv2d(in_channels, in_channels+s_factor, 1, stride=1, padding=0, bias=False))
153
+
154
+ def forward(self, x):
155
+ x = self.down(x)
156
+ return x
157
+
158
+ class UpSample(nn.Module):
159
+ def __init__(self, in_channels,s_factor):
160
+ super(UpSample, self).__init__()
161
+ self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
162
+ nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False))
163
+
164
+ def forward(self, x):
165
+ x = self.up(x)
166
+ return x
167
+
168
+ class SkipUpSample(nn.Module):
169
+ def __init__(self, in_channels,s_factor):
170
+ super(SkipUpSample, self).__init__()
171
+ self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
172
+ nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False))
173
+
174
+ def forward(self, x, y):
175
+ x = self.up(x)
176
+ x = x + y
177
+ return x
178
+
179
+ ##########################################################################
180
+ ## Original Resolution Block (ORB)
181
+ class ORB(nn.Module):
182
+ def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab):
183
+ super(ORB, self).__init__()
184
+ modules_body = []
185
+ modules_body = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(num_cab)]
186
+ modules_body.append(conv(n_feat, n_feat, kernel_size))
187
+ self.body = nn.Sequential(*modules_body)
188
+
189
+ def forward(self, x):
190
+ res = self.body(x)
191
+ res += x
192
+ return res
193
+
194
+ ##########################################################################
195
+ class ORSNet(nn.Module):
196
+ def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab):
197
+ super(ORSNet, self).__init__()
198
+
199
+ self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
200
+ self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
201
+ self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
202
+
203
+ self.up_enc1 = UpSample(n_feat, scale_unetfeats)
204
+ self.up_dec1 = UpSample(n_feat, scale_unetfeats)
205
+
206
+ self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))
207
+ self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))
208
+
209
+ self.conv_enc1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
210
+ self.conv_enc2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
211
+ self.conv_enc3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
212
+
213
+ self.conv_dec1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
214
+ self.conv_dec2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
215
+ self.conv_dec3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
216
+
217
+ def forward(self, x, encoder_outs, decoder_outs):
218
+ x = self.orb1(x)
219
+ x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0])
220
+
221
+ x = self.orb2(x)
222
+ x = x + self.conv_enc2(self.up_enc1(encoder_outs[1])) + self.conv_dec2(self.up_dec1(decoder_outs[1]))
223
+
224
+ x = self.orb3(x)
225
+ x = x + self.conv_enc3(self.up_enc2(encoder_outs[2])) + self.conv_dec3(self.up_dec2(decoder_outs[2]))
226
+
227
+ return x
228
+
229
+
230
+ ##########################################################################
231
+ class MPRNet(nn.Module):
232
+ def __init__(self, in_c=31, out_c=31, n_feat=31, scale_unetfeats=31, scale_orsnetfeats=31, num_cab=4, kernel_size=3, reduction=1, bias=False):
233
+ super(MPRNet, self).__init__()
234
+
235
+ self.conv_in = nn.Conv2d(3, in_c, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
236
+ bias=bias)
237
+
238
+ act=nn.PReLU()
239
+ self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act))
240
+ self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act))
241
+ self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act))
242
+
243
+ # Cross Stage Feature Fusion (CSFF)
244
+ self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False)
245
+ self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
246
+
247
+ self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
248
+ self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
249
+
250
+ self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab)
251
+
252
+ self.sam12 = SAM(n_feat, kernel_size=1, bias=bias)
253
+ self.sam23 = SAM(n_feat, kernel_size=1, bias=bias)
254
+
255
+ self.concat12 = conv(n_feat*2, n_feat, kernel_size, bias=bias)
256
+ self.concat23 = conv(n_feat*2, n_feat+scale_orsnetfeats, kernel_size, bias=bias)
257
+ self.tail = conv(n_feat+scale_orsnetfeats, out_c, kernel_size, bias=bias)
258
+
259
+ def forward(self, x3_img):
260
+ b, c, h_inp, w_inp = x3_img.shape
261
+ hb, wb = 8, 8
262
+ pad_h = (hb - h_inp % hb) % hb
263
+ pad_w = (wb - w_inp % wb) % wb
264
+ x3_img = F.pad(x3_img, [0, pad_w, 0, pad_h], mode='reflect')
265
+ x3_img = self.conv_in(x3_img)
266
+
267
+ # Original-resolution Image for Stage 3
268
+ H = x3_img.size(2)
269
+ W = x3_img.size(3)
270
+
271
+ # Multi-Patch Hierarchy: Split Image into four non-overlapping patches
272
+
273
+ # Two Patches for Stage 2
274
+ x2top_img = x3_img[:,:,0:int(H/2),:]
275
+ x2bot_img = x3_img[:,:,int(H/2):H,:]
276
+
277
+ # Four Patches for Stage 1
278
+ x1ltop_img = x2top_img[:,:,:,0:int(W/2)]
279
+ x1rtop_img = x2top_img[:,:,:,int(W/2):W]
280
+ x1lbot_img = x2bot_img[:,:,:,0:int(W/2)]
281
+ x1rbot_img = x2bot_img[:,:,:,int(W/2):W]
282
+
283
+ ##-------------------------------------------
284
+ ##-------------- Stage 1---------------------
285
+ ##-------------------------------------------
286
+ ## Compute Shallow Features
287
+ x1ltop = self.shallow_feat1(x1ltop_img)
288
+ x1rtop = self.shallow_feat1(x1rtop_img)
289
+ x1lbot = self.shallow_feat1(x1lbot_img)
290
+ x1rbot = self.shallow_feat1(x1rbot_img)
291
+
292
+ ## Process features of all 4 patches with Encoder of Stage 1
293
+ feat1_ltop = self.stage1_encoder(x1ltop)
294
+ feat1_rtop = self.stage1_encoder(x1rtop)
295
+ feat1_lbot = self.stage1_encoder(x1lbot)
296
+ feat1_rbot = self.stage1_encoder(x1rbot)
297
+
298
+ ## Concat deep features
299
+ feat1_top = [torch.cat((k,v), 3) for k,v in zip(feat1_ltop,feat1_rtop)]
300
+ feat1_bot = [torch.cat((k,v), 3) for k,v in zip(feat1_lbot,feat1_rbot)]
301
+
302
+ ## Pass features through Decoder of Stage 1
303
+ res1_top = self.stage1_decoder(feat1_top)
304
+ res1_bot = self.stage1_decoder(feat1_bot)
305
+
306
+ ## Apply Supervised Attention Module (SAM)
307
+ x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img)
308
+ x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img)
309
+
310
+ ## Output image at Stage 1
311
+ stage1_img = torch.cat([stage1_img_top, stage1_img_bot],2)
312
+ ##-------------------------------------------
313
+ ##-------------- Stage 2---------------------
314
+ ##-------------------------------------------
315
+ ## Compute Shallow Features
316
+ x2top = self.shallow_feat2(x2top_img)
317
+ x2bot = self.shallow_feat2(x2bot_img)
318
+
319
+ ## Concatenate SAM features of Stage 1 with shallow features of Stage 2
320
+ x2top_cat = self.concat12(torch.cat([x2top, x2top_samfeats], 1))
321
+ x2bot_cat = self.concat12(torch.cat([x2bot, x2bot_samfeats], 1))
322
+
323
+ ## Process features of both patches with Encoder of Stage 2
324
+ feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top)
325
+ feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot)
326
+
327
+ ## Concat deep features
328
+ feat2 = [torch.cat((k,v), 2) for k,v in zip(feat2_top,feat2_bot)]
329
+
330
+ ## Pass features through Decoder of Stage 2
331
+ res2 = self.stage2_decoder(feat2)
332
+
333
+ ## Apply SAM
334
+ x3_samfeats, stage2_img = self.sam23(res2[0], x3_img)
335
+
336
+
337
+ ##-------------------------------------------
338
+ ##-------------- Stage 3---------------------
339
+ ##-------------------------------------------
340
+ ## Compute Shallow Features
341
+ x3 = self.shallow_feat3(x3_img)
342
+
343
+ ## Concatenate SAM features of Stage 2 with shallow features of Stage 3
344
+ x3_cat = self.concat23(torch.cat([x3, x3_samfeats], 1))
345
+
346
+ x3_cat = self.stage3_orsnet(x3_cat, feat2, res2)
347
+
348
+ stage3_img = self.tail(x3_cat)
349
+
350
+ return (stage3_img + x3_img)[:, :, :h_inp, :w_inp]
test_challenge_code/architecture/MST.py ADDED
@@ -0,0 +1,313 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ import math
6
+ import warnings
7
+ from torch.nn.init import _calculate_fan_in_and_fan_out
8
+
9
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
10
+ def norm_cdf(x):
11
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
12
+
13
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
14
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
15
+ "The distribution of values may be incorrect.",
16
+ stacklevel=2)
17
+ with torch.no_grad():
18
+ l = norm_cdf((a - mean) / std)
19
+ u = norm_cdf((b - mean) / std)
20
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
21
+ tensor.erfinv_()
22
+ tensor.mul_(std * math.sqrt(2.))
23
+ tensor.add_(mean)
24
+ tensor.clamp_(min=a, max=b)
25
+ return tensor
26
+
27
+
28
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
29
+ # type: (Tensor, float, float, float, float) -> Tensor
30
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
31
+
32
+
33
+ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
34
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
35
+ if mode == 'fan_in':
36
+ denom = fan_in
37
+ elif mode == 'fan_out':
38
+ denom = fan_out
39
+ elif mode == 'fan_avg':
40
+ denom = (fan_in + fan_out) / 2
41
+ variance = scale / denom
42
+ if distribution == "truncated_normal":
43
+ trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
44
+ elif distribution == "normal":
45
+ tensor.normal_(std=math.sqrt(variance))
46
+ elif distribution == "uniform":
47
+ bound = math.sqrt(3 * variance)
48
+ tensor.uniform_(-bound, bound)
49
+ else:
50
+ raise ValueError(f"invalid distribution {distribution}")
51
+
52
+
53
+ def lecun_normal_(tensor):
54
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
55
+
56
+
57
+ class PreNorm(nn.Module):
58
+ def __init__(self, dim, fn):
59
+ super().__init__()
60
+ self.fn = fn
61
+ self.norm = nn.LayerNorm(dim)
62
+
63
+ def forward(self, x, *args, **kwargs):
64
+ x = self.norm(x)
65
+ return self.fn(x, *args, **kwargs)
66
+
67
+
68
+ class GELU(nn.Module):
69
+ def forward(self, x):
70
+ return F.gelu(x)
71
+
72
+ def conv(in_channels, out_channels, kernel_size, bias=False, padding = 1, stride = 1):
73
+ return nn.Conv2d(
74
+ in_channels, out_channels, kernel_size,
75
+ padding=(kernel_size//2), bias=bias, stride=stride)
76
+
77
+
78
+ def shift_back(inputs,step=2): # input [bs,28,256,310] output [bs, 28, 256, 256]
79
+ [bs, nC, row, col] = inputs.shape
80
+ down_sample = 256//row
81
+ step = float(step)/float(down_sample*down_sample)
82
+ out_col = row
83
+ for i in range(nC):
84
+ inputs[:,i,:,:out_col] = \
85
+ inputs[:,i,:,int(step*i):int(step*i)+out_col]
86
+ return inputs[:, :, :, :out_col]
87
+
88
+ class MaskGuidedMechanism(nn.Module):
89
+ def __init__(
90
+ self, n_feat):
91
+ super(MaskGuidedMechanism, self).__init__()
92
+
93
+ self.conv1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=True)
94
+ self.conv2 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=True)
95
+ self.depth_conv = nn.Conv2d(n_feat, n_feat, kernel_size=5, padding=2, bias=True, groups=n_feat)
96
+
97
+ def forward(self, mask_shift):
98
+ # x: b,c,h,w
99
+ [bs, nC, row, col] = mask_shift.shape
100
+ mask_shift = self.conv1(mask_shift)
101
+ attn_map = torch.sigmoid(self.depth_conv(self.conv2(mask_shift)))
102
+ res = mask_shift * attn_map
103
+ mask_emb = res + mask_shift
104
+ return mask_emb
105
+
106
+ class MS_MSA(nn.Module):
107
+ def __init__(
108
+ self,
109
+ dim,
110
+ dim_head,
111
+ heads,
112
+ ):
113
+ super().__init__()
114
+ self.num_heads = heads
115
+ self.dim_head = dim_head
116
+ self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
117
+ self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
118
+ self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
119
+ self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
120
+ self.proj = nn.Linear(dim_head * heads, dim, bias=True)
121
+ self.pos_emb = nn.Sequential(
122
+ nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
123
+ GELU(),
124
+ nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
125
+ )
126
+ self.mm = MaskGuidedMechanism(dim)
127
+ self.dim = dim
128
+
129
+ def forward(self, x_in, mask=None):
130
+ """
131
+ x_in: [b,h,w,c]
132
+ mask: [1,h,w,c]
133
+ return out: [b,h,w,c]
134
+ """
135
+ b, h, w, c = x_in.shape
136
+ x = x_in.reshape(b,h*w,c)
137
+ q_inp = self.to_q(x)
138
+ k_inp = self.to_k(x)
139
+ v_inp = self.to_v(x)
140
+ mask_attn = self.mm(mask.permute(0,3,1,2)).permute(0,2,3,1)
141
+ if b != 0:
142
+ mask_attn = (mask_attn[0, :, :, :]).expand([b, h, w, c])
143
+ q, k, v, mask_attn = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
144
+ (q_inp, k_inp, v_inp, mask_attn.flatten(1, 2)))
145
+ v = v * mask_attn
146
+ # q: b,heads,hw,c
147
+ q = q.transpose(-2, -1)
148
+ k = k.transpose(-2, -1)
149
+ v = v.transpose(-2, -1)
150
+ q = F.normalize(q, dim=-1, p=2)
151
+ k = F.normalize(k, dim=-1, p=2)
152
+ attn = (k @ q.transpose(-2, -1)) # A = K^T*Q
153
+ attn = attn * self.rescale
154
+ attn = attn.softmax(dim=-1)
155
+ x = attn @ v # b,heads,d,hw
156
+ x = x.permute(0, 3, 1, 2) # Transpose
157
+ x = x.reshape(b, h * w, self.num_heads * self.dim_head)
158
+ out_c = self.proj(x).view(b, h, w, c)
159
+ out_p = self.pos_emb(v_inp.reshape(b,h,w,c).permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
160
+ out = out_c + out_p
161
+
162
+ return out
163
+
164
+ class FeedForward(nn.Module):
165
+ def __init__(self, dim, mult=4):
166
+ super().__init__()
167
+ self.net = nn.Sequential(
168
+ nn.Conv2d(dim, dim * mult, 1, 1, bias=False),
169
+ GELU(),
170
+ nn.Conv2d(dim * mult, dim * mult, 3, 1, 1, bias=False, groups=dim * mult),
171
+ GELU(),
172
+ nn.Conv2d(dim * mult, dim, 1, 1, bias=False),
173
+ )
174
+
175
+ def forward(self, x):
176
+ """
177
+ x: [b,h,w,c]
178
+ return out: [b,h,w,c]
179
+ """
180
+ out = self.net(x.permute(0, 3, 1, 2))
181
+ return out.permute(0, 2, 3, 1)
182
+
183
+ class MSAB(nn.Module):
184
+ def __init__(
185
+ self,
186
+ dim,
187
+ dim_head,
188
+ heads,
189
+ num_blocks,
190
+ ):
191
+ super().__init__()
192
+ self.blocks = nn.ModuleList([])
193
+ for _ in range(num_blocks):
194
+ self.blocks.append(nn.ModuleList([
195
+ MS_MSA(dim=dim, dim_head=dim_head, heads=heads),
196
+ PreNorm(dim, FeedForward(dim=dim))
197
+ ]))
198
+
199
+ def forward(self, x, mask):
200
+ """
201
+ x: [b,c,h,w]
202
+ return out: [b,c,h,w]
203
+ """
204
+ x = x.permute(0, 2, 3, 1)
205
+ for (attn, ff) in self.blocks:
206
+ x = attn(x, mask=mask.permute(0, 2, 3, 1)) + x
207
+ x = ff(x) + x
208
+ out = x.permute(0, 3, 1, 2)
209
+ return out
210
+
211
+ class MST(nn.Module):
212
+ def __init__(self, dim, stage, num_blocks):
213
+ super(MST, self).__init__()
214
+ self.dim = dim
215
+ self.stage = stage
216
+
217
+ # Input projection
218
+ self.embedding_1 = nn.Conv2d(3, self.dim, 3, 1, 1, bias=False)
219
+ self.embedding_2 = nn.Conv2d(3, self.dim, 3, 1, 1, bias=False)
220
+
221
+ # Encoder
222
+ self.encoder_layers = nn.ModuleList([])
223
+ dim_stage = dim
224
+ for i in range(stage):
225
+ self.encoder_layers.append(nn.ModuleList([
226
+ MSAB(
227
+ dim=dim_stage, num_blocks=num_blocks[i], dim_head=dim, heads=dim_stage // dim),
228
+ nn.Conv2d(dim_stage, dim_stage * 2, 4, 2, 1, bias=False),
229
+ nn.Conv2d(dim_stage, dim_stage * 2, 4, 2, 1, bias=False)
230
+ ]))
231
+ dim_stage *= 2
232
+
233
+ # Bottleneck
234
+ self.bottleneck = MSAB(
235
+ dim=dim_stage, dim_head=dim, heads=dim_stage // dim, num_blocks=num_blocks[-1])
236
+
237
+ # Decoder
238
+ self.decoder_layers = nn.ModuleList([])
239
+ for i in range(stage):
240
+ self.decoder_layers.append(nn.ModuleList([
241
+ nn.ConvTranspose2d(dim_stage, dim_stage // 2, stride=2, kernel_size=2, padding=0, output_padding=0),
242
+ nn.Conv2d(dim_stage, dim_stage // 2, 1, 1, bias=False),
243
+ MSAB(
244
+ dim=dim_stage // 2, num_blocks=num_blocks[stage - 1 - i], dim_head=dim,
245
+ heads=(dim_stage // 2) // dim),
246
+ ]))
247
+ dim_stage //= 2
248
+
249
+ # Output projection
250
+ self.mapping = nn.Conv2d(self.dim, 31, 3, 1, 1, bias=False)
251
+
252
+ #### activation function
253
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
254
+
255
+ def forward(self, x):
256
+ """
257
+ x: [b,c,h,w]
258
+ return out:[b,c,h,w]
259
+ """
260
+ b, c, h_inp, w_inp = x.shape
261
+ hb, wb = 8, 8
262
+ pad_h = (hb - h_inp % hb) % hb
263
+ pad_w = (wb - w_inp % wb) % wb
264
+ x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
265
+
266
+ # Embedding
267
+ mask = self.lrelu(self.embedding_1(x))
268
+ x = self.lrelu(self.embedding_2(x))
269
+ fea = x
270
+
271
+ # Encoder
272
+ fea_encoder = []
273
+ masks = []
274
+ for (MSAB, FeaDownSample, MaskDownSample) in self.encoder_layers:
275
+ fea = MSAB(fea, mask)
276
+ masks.append(mask)
277
+ fea_encoder.append(fea)
278
+ fea = FeaDownSample(fea)
279
+ mask = MaskDownSample(mask)
280
+
281
+ # Bottleneck
282
+ fea = self.bottleneck(fea, mask)
283
+
284
+ # Decoder
285
+ for i, (FeaUpSample, Fusion, LeWinBlock) in enumerate(self.decoder_layers):
286
+ fea = FeaUpSample(fea)
287
+ fea = Fusion(torch.cat([fea, fea_encoder[self.stage-1-i]], dim=1))
288
+ mask = masks[self.stage - 1 - i]
289
+ fea = LeWinBlock(fea, mask)
290
+
291
+ # Mapping
292
+ out = self.mapping(fea) + x
293
+
294
+ return out[:, :, :h_inp, :w_inp]
295
+
296
+
297
+
298
+
299
+
300
+
301
+
302
+
303
+
304
+
305
+
306
+
307
+
308
+
309
+
310
+
311
+
312
+
313
+
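The MST defined above embeds an RGB input, runs a mask-guided spectral transformer U-Net, and maps back to 31 bands, padding height and width to multiples of 8 internally. A hedged usage sketch with the constructor arguments that model_generator (later in this commit) passes for the 'mst' method, assuming the repository root is on PYTHONPATH:

import torch
from test_challenge_code.architecture.MST import MST

model = MST(dim=31, stage=2, num_blocks=[4, 7, 5]).eval()
rgb = torch.rand(1, 3, 130, 130)   # sizes need not be multiples of 8; forward pads, then crops back
with torch.no_grad():
    hsi = model(rgb)
print(hsi.shape)                   # expected: torch.Size([1, 31, 130, 130])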
test_challenge_code/architecture/MST_Plus_Plus.py ADDED
@@ -0,0 +1,307 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ import math
6
+ import warnings
7
+ from torch.nn.init import _calculate_fan_in_and_fan_out
8
+
9
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
10
+ def norm_cdf(x):
11
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
12
+
13
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
14
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
15
+ "The distribution of values may be incorrect.",
16
+ stacklevel=2)
17
+ with torch.no_grad():
18
+ l = norm_cdf((a - mean) / std)
19
+ u = norm_cdf((b - mean) / std)
20
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
21
+ tensor.erfinv_()
22
+ tensor.mul_(std * math.sqrt(2.))
23
+ tensor.add_(mean)
24
+ tensor.clamp_(min=a, max=b)
25
+ return tensor
26
+
27
+
28
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
29
+ # type: (Tensor, float, float, float, float) -> Tensor
30
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
31
+
32
+
33
+ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
34
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
35
+ if mode == 'fan_in':
36
+ denom = fan_in
37
+ elif mode == 'fan_out':
38
+ denom = fan_out
39
+ elif mode == 'fan_avg':
40
+ denom = (fan_in + fan_out) / 2
41
+ variance = scale / denom
42
+ if distribution == "truncated_normal":
43
+ trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
44
+ elif distribution == "normal":
45
+ tensor.normal_(std=math.sqrt(variance))
46
+ elif distribution == "uniform":
47
+ bound = math.sqrt(3 * variance)
48
+ tensor.uniform_(-bound, bound)
49
+ else:
50
+ raise ValueError(f"invalid distribution {distribution}")
51
+
52
+
53
+ def lecun_normal_(tensor):
54
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
55
+
56
+
57
+ class PreNorm(nn.Module):
58
+ def __init__(self, dim, fn):
59
+ super().__init__()
60
+ self.fn = fn
61
+ self.norm = nn.LayerNorm(dim)
62
+
63
+ def forward(self, x, *args, **kwargs):
64
+ x = self.norm(x)
65
+ return self.fn(x, *args, **kwargs)
66
+
67
+
68
+ class GELU(nn.Module):
69
+ def forward(self, x):
70
+ return F.gelu(x)
71
+
72
+ def conv(in_channels, out_channels, kernel_size, bias=False, padding = 1, stride = 1):
73
+ return nn.Conv2d(
74
+ in_channels, out_channels, kernel_size,
75
+ padding=(kernel_size//2), bias=bias, stride=stride)
76
+
77
+
78
+ def shift_back(inputs,step=2): # input [bs,28,256,310] output [bs, 28, 256, 256]
79
+ [bs, nC, row, col] = inputs.shape
80
+ down_sample = 256//row
81
+ step = float(step)/float(down_sample*down_sample)
82
+ out_col = row
83
+ for i in range(nC):
84
+ inputs[:,i,:,:out_col] = \
85
+ inputs[:,i,:,int(step*i):int(step*i)+out_col]
86
+ return inputs[:, :, :, :out_col]
87
+
88
+ class MS_MSA(nn.Module):
89
+ def __init__(
90
+ self,
91
+ dim,
92
+ dim_head,
93
+ heads,
94
+ ):
95
+ super().__init__()
96
+ self.num_heads = heads
97
+ self.dim_head = dim_head
98
+ self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
99
+ self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
100
+ self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
101
+ self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
102
+ self.proj = nn.Linear(dim_head * heads, dim, bias=True)
103
+ self.pos_emb = nn.Sequential(
104
+ nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
105
+ GELU(),
106
+ nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
107
+ )
108
+ self.dim = dim
109
+
110
+ def forward(self, x_in):
111
+ """
112
+ x_in: [b,h,w,c]
113
+ return out: [b,h,w,c]
114
+ """
115
+ b, h, w, c = x_in.shape
116
+ x = x_in.reshape(b,h*w,c)
117
+ q_inp = self.to_q(x)
118
+ k_inp = self.to_k(x)
119
+ v_inp = self.to_v(x)
120
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
121
+ (q_inp, k_inp, v_inp))
122
+ # no mask-guided modulation of v here, unlike the masked MS_MSA in MST.py
123
+ # q: b,heads,hw,c
124
+ q = q.transpose(-2, -1)
125
+ k = k.transpose(-2, -1)
126
+ v = v.transpose(-2, -1)
127
+ q = F.normalize(q, dim=-1, p=2)
128
+ k = F.normalize(k, dim=-1, p=2)
129
+ attn = (k @ q.transpose(-2, -1)) # A = K^T*Q
130
+ attn = attn * self.rescale
131
+ attn = attn.softmax(dim=-1)
132
+ x = attn @ v # b,heads,d,hw
133
+ x = x.permute(0, 3, 1, 2) # Transpose
134
+ x = x.reshape(b, h * w, self.num_heads * self.dim_head)
135
+ out_c = self.proj(x).view(b, h, w, c)
136
+ out_p = self.pos_emb(v_inp.reshape(b,h,w,c).permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
137
+ out = out_c + out_p
138
+
139
+ return out
140
+
141
+ class FeedForward(nn.Module):
142
+ def __init__(self, dim, mult=4):
143
+ super().__init__()
144
+ self.net = nn.Sequential(
145
+ nn.Conv2d(dim, dim * mult, 1, 1, bias=False),
146
+ GELU(),
147
+ nn.Conv2d(dim * mult, dim * mult, 3, 1, 1, bias=False, groups=dim * mult),
148
+ GELU(),
149
+ nn.Conv2d(dim * mult, dim, 1, 1, bias=False),
150
+ )
151
+
152
+ def forward(self, x):
153
+ """
154
+ x: [b,h,w,c]
155
+ return out: [b,h,w,c]
156
+ """
157
+ out = self.net(x.permute(0, 3, 1, 2))
158
+ return out.permute(0, 2, 3, 1)
159
+
160
+ class MSAB(nn.Module):
161
+ def __init__(
162
+ self,
163
+ dim,
164
+ dim_head,
165
+ heads,
166
+ num_blocks,
167
+ ):
168
+ super().__init__()
169
+ self.blocks = nn.ModuleList([])
170
+ for _ in range(num_blocks):
171
+ self.blocks.append(nn.ModuleList([
172
+ MS_MSA(dim=dim, dim_head=dim_head, heads=heads),
173
+ PreNorm(dim, FeedForward(dim=dim))
174
+ ]))
175
+
176
+ def forward(self, x):
177
+ """
178
+ x: [b,c,h,w]
179
+ return out: [b,c,h,w]
180
+ """
181
+ x = x.permute(0, 2, 3, 1)
182
+ for (attn, ff) in self.blocks:
183
+ x = attn(x) + x
184
+ x = ff(x) + x
185
+ out = x.permute(0, 3, 1, 2)
186
+ return out
187
+
188
+ class MST(nn.Module):
189
+ def __init__(self, in_dim=31, out_dim=31, dim=31, stage=2, num_blocks=[2,4,4]):
190
+ super(MST, self).__init__()
191
+ self.dim = dim
192
+ self.stage = stage
193
+
194
+ # Input projection
195
+ self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)
196
+
197
+ # Encoder
198
+ self.encoder_layers = nn.ModuleList([])
199
+ dim_stage = dim
200
+ for i in range(stage):
201
+ self.encoder_layers.append(nn.ModuleList([
202
+ MSAB(
203
+ dim=dim_stage, num_blocks=num_blocks[i], dim_head=dim, heads=dim_stage // dim),
204
+ nn.Conv2d(dim_stage, dim_stage * 2, 4, 2, 1, bias=False),
205
+ ]))
206
+ dim_stage *= 2
207
+
208
+ # Bottleneck
209
+ self.bottleneck = MSAB(
210
+ dim=dim_stage, dim_head=dim, heads=dim_stage // dim, num_blocks=num_blocks[-1])
211
+
212
+ # Decoder
213
+ self.decoder_layers = nn.ModuleList([])
214
+ for i in range(stage):
215
+ self.decoder_layers.append(nn.ModuleList([
216
+ nn.ConvTranspose2d(dim_stage, dim_stage // 2, stride=2, kernel_size=2, padding=0, output_padding=0),
217
+ nn.Conv2d(dim_stage, dim_stage // 2, 1, 1, bias=False),
218
+ MSAB(
219
+ dim=dim_stage // 2, num_blocks=num_blocks[stage - 1 - i], dim_head=dim,
220
+ heads=(dim_stage // 2) // dim),
221
+ ]))
222
+ dim_stage //= 2
223
+
224
+ # Output projection
225
+ self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)
226
+
227
+ #### activation function
228
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
229
+ self.apply(self._init_weights)
230
+
231
+ def _init_weights(self, m):
232
+ if isinstance(m, nn.Linear):
233
+ trunc_normal_(m.weight, std=.02)
234
+ if isinstance(m, nn.Linear) and m.bias is not None:
235
+ nn.init.constant_(m.bias, 0)
236
+ elif isinstance(m, nn.LayerNorm):
237
+ nn.init.constant_(m.bias, 0)
238
+ nn.init.constant_(m.weight, 1.0)
239
+
240
+ def forward(self, x):
241
+ """
242
+ x: [b,c,h,w]
243
+ return out:[b,c,h,w]
244
+ """
245
+
246
+ # Embedding
247
+ fea = self.embedding(x)
248
+
249
+ # Encoder
250
+ fea_encoder = []
251
+ for (MSAB, FeaDownSample) in self.encoder_layers:
252
+ fea = MSAB(fea)
253
+ fea_encoder.append(fea)
254
+ fea = FeaDownSample(fea)
255
+
256
+ # Bottleneck
257
+ fea = self.bottleneck(fea)
258
+
259
+ # Decoder
260
+ for i, (FeaUpSample, Fusion, LeWinBlock) in enumerate(self.decoder_layers):
261
+ fea = FeaUpSample(fea)
262
+ fea = Fusion(torch.cat([fea, fea_encoder[self.stage-1-i]], dim=1))
263
+ fea = LeWinBlock(fea)
264
+
265
+ # Mapping
266
+ out = self.mapping(fea) + x
267
+
268
+ return out
269
+
270
+ class MST_Plus_Plus(nn.Module):
271
+ def __init__(self, in_channels=3, out_channels=31, n_feat=31, stage=3):
272
+ super(MST_Plus_Plus, self).__init__()
273
+ self.stage = stage
274
+ self.conv_in = nn.Conv2d(in_channels, n_feat, kernel_size=3, padding=(3 - 1) // 2,bias=False)
275
+ modules_body = [MST(dim=31, stage=2, num_blocks=[1,1,1]) for _ in range(stage)]
276
+ self.body = nn.Sequential(*modules_body)
277
+ self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=3, padding=(3 - 1) // 2,bias=False)
278
+
279
+ def forward(self, x):
280
+ """
281
+ x: [b,c,h,w]
282
+ return out:[b,c,h,w]
283
+ """
284
+ b, c, h_inp, w_inp = x.shape
285
+ hb, wb = 8, 8
286
+ pad_h = (hb - h_inp % hb) % hb
287
+ pad_w = (wb - w_inp % wb) % wb
288
+ x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
289
+ x = self.conv_in(x)
290
+ h = self.body(x)
291
+ h = self.conv_out(h)
292
+ h += x
293
+ return h[:, :, :h_inp, :w_inp]
294
+
295
+
296
+
297
+
298
+
299
+
300
+
301
+
302
+
303
+
304
+
305
+
306
+
307
+
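MST_Plus_Plus chains stage copies of the lightweight MST above (dim=31, stage=2, num_blocks=[1, 1, 1]) between conv_in (3 to 31 channels) and conv_out, with a global residual and the same pad-to-multiple-of-8-then-crop behaviour. A minimal usage sketch, again assuming the repository root is importable:

import torch
from test_challenge_code.architecture.MST_Plus_Plus import MST_Plus_Plus

model = MST_Plus_Plus().eval()     # defaults: in_channels=3, out_channels=31, n_feat=31, stage=3
rgb = torch.rand(1, 3, 482, 512)   # 482 is not a multiple of 8; forward pads to 488, then crops
with torch.no_grad():
    hsi = model(rgb)
print(hsi.shape)                   # expected: torch.Size([1, 31, 482, 512])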
test_challenge_code/architecture/Restormer.py ADDED
@@ -0,0 +1,320 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numbers
5
+ from einops import rearrange
6
+
7
+
8
+ ##########################################################################
9
+ ## Layer Norm
10
+
11
+ def to_3d(x):
12
+ return rearrange(x, 'b c h w -> b (h w) c')
13
+
14
+
15
+ def to_4d(x, h, w):
16
+ return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
17
+
18
+
19
+ class BiasFree_LayerNorm(nn.Module):
20
+ def __init__(self, normalized_shape):
21
+ super(BiasFree_LayerNorm, self).__init__()
22
+ if isinstance(normalized_shape, numbers.Integral):
23
+ normalized_shape = (normalized_shape,)
24
+ normalized_shape = torch.Size(normalized_shape)
25
+
26
+ assert len(normalized_shape) == 1
27
+
28
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
29
+ self.normalized_shape = normalized_shape
30
+
31
+ def forward(self, x):
32
+ sigma = x.var(-1, keepdim=True, unbiased=False)
33
+ return x / torch.sqrt(sigma + 1e-5) * self.weight
34
+
35
+
36
+ class WithBias_LayerNorm(nn.Module):
37
+ def __init__(self, normalized_shape):
38
+ super(WithBias_LayerNorm, self).__init__()
39
+ if isinstance(normalized_shape, numbers.Integral):
40
+ normalized_shape = (normalized_shape,)
41
+ normalized_shape = torch.Size(normalized_shape)
42
+
43
+ assert len(normalized_shape) == 1
44
+
45
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
46
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
47
+ self.normalized_shape = normalized_shape
48
+
49
+ def forward(self, x):
50
+ mu = x.mean(-1, keepdim=True)
51
+ sigma = x.var(-1, keepdim=True, unbiased=False)
52
+ return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
53
+
54
+
55
+ class LayerNorm(nn.Module):
56
+ def __init__(self, dim, LayerNorm_type):
57
+ super(LayerNorm, self).__init__()
58
+ if LayerNorm_type == 'BiasFree':
59
+ self.body = BiasFree_LayerNorm(dim)
60
+ else:
61
+ self.body = WithBias_LayerNorm(dim)
62
+
63
+ def forward(self, x):
64
+ h, w = x.shape[-2:]
65
+ return to_4d(self.body(to_3d(x)), h, w)
66
+
67
+
68
+ ##########################################################################
69
+ ## Gated-Dconv Feed-Forward Network (GDFN)
70
+ class FeedForward(nn.Module):
71
+ def __init__(self, dim, ffn_expansion_factor, bias):
72
+ super(FeedForward, self).__init__()
73
+
74
+ hidden_features = int(dim * ffn_expansion_factor)
75
+
76
+ self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
77
+
78
+ self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
79
+ groups=hidden_features * 2, bias=bias)
80
+
81
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
82
+
83
+ def forward(self, x):
84
+ x = self.project_in(x)
85
+ x1, x2 = self.dwconv(x).chunk(2, dim=1)
86
+ x = F.gelu(x1) * x2
87
+ x = self.project_out(x)
88
+ return x
89
+
90
+
91
+ ##########################################################################
92
+ ## Multi-DConv Head Transposed Self-Attention (MDTA)
93
+ class Attention(nn.Module):
94
+ def __init__(self, dim, num_heads, bias):
95
+ super(Attention, self).__init__()
96
+ self.num_heads = num_heads
97
+ self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
98
+
99
+ self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
100
+ self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
101
+ self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
102
+
103
+ def forward(self, x):
104
+ b, c, h, w = x.shape
105
+
106
+ qkv = self.qkv_dwconv(self.qkv(x))
107
+ q, k, v = qkv.chunk(3, dim=1)
108
+
109
+ q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
110
+ k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
111
+ v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
112
+
113
+ q = torch.nn.functional.normalize(q, dim=-1)
114
+ k = torch.nn.functional.normalize(k, dim=-1)
115
+
116
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
117
+ attn = attn.softmax(dim=-1)
118
+
119
+ out = (attn @ v)
120
+
121
+ out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
122
+
123
+ out = self.project_out(out)
124
+ return out
125
+
126
+
127
+ ##########################################################################
128
+ class TransformerBlock(nn.Module):
129
+ def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
130
+ super(TransformerBlock, self).__init__()
131
+
132
+ self.norm1 = LayerNorm(dim, LayerNorm_type)
133
+ self.attn = Attention(dim, num_heads, bias)
134
+ self.norm2 = LayerNorm(dim, LayerNorm_type)
135
+ self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
136
+
137
+ def forward(self, x):
138
+ x = x + self.attn(self.norm1(x))
139
+ x = x + self.ffn(self.norm2(x))
140
+
141
+ return x
142
+
143
+
144
+ ##########################################################################
145
+ ## Overlapped image patch embedding with 3x3 Conv
146
+ class OverlapPatchEmbed(nn.Module):
147
+ def __init__(self, in_c=3, embed_dim=48, bias=False):
148
+ super(OverlapPatchEmbed, self).__init__()
149
+
150
+ self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
151
+
152
+ def forward(self, x):
153
+ x = self.proj(x)
154
+
155
+ return x
156
+
157
+ def pixel_unshuffle(input, downscale_factor):
158
+ '''
159
+ input: batchSize * c * k*w * k*h
160
+ downscale_factor: k
161
+ batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h
162
+ '''
163
+ c = input.shape[1]
164
+ kernel = torch.zeros(size = [downscale_factor * downscale_factor * c, 1, downscale_factor, downscale_factor],
165
+ device = input.device)
166
+ for y in range(downscale_factor):
167
+ for x in range(downscale_factor):
168
+ kernel[x + y * downscale_factor::downscale_factor * downscale_factor, 0, y, x] = 1
169
+ return F.conv2d(input, kernel, stride = downscale_factor, groups = c)
170
+
171
+ class PixelUnShuffle(nn.Module):
172
+ def __init__(self, downscale_factor):
173
+ super(PixelUnShuffle, self).__init__()
174
+ self.downscale_factor = downscale_factor
175
+
176
+ def forward(self, input):
177
+ '''
178
+ input: batchSize * c * k*w * k*h
179
+ downscale_factor: k
180
+ batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h
181
+ '''
182
+ return pixel_unshuffle(input, self.downscale_factor)
183
+
184
+ ##########################################################################
185
+ ## Resizing modules
186
+ class Downsample(nn.Module):
187
+ def __init__(self, n_feat):
188
+ super(Downsample, self).__init__()
189
+
190
+ self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
191
+ PixelUnShuffle(2))
192
+
193
+ def forward(self, x):
194
+ return self.body(x)
195
+
196
+
197
+ class Upsample(nn.Module):
198
+ def __init__(self, n_feat):
199
+ super(Upsample, self).__init__()
200
+
201
+ self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
202
+ nn.PixelShuffle(2))
203
+
204
+ def forward(self, x):
205
+ return self.body(x)
206
+
207
+
208
+ ##########################################################################
209
+ ##---------- Restormer -----------------------
210
+ class Restormer(nn.Module):
211
+ def __init__(self,
212
+ inp_channels=3,
213
+ out_channels=31,
214
+ dim=48,
215
+ num_blocks=[2, 3, 3, 4],
216
+ num_refinement_blocks=3,
217
+ heads=[1, 2, 4, 8],
218
+ ffn_expansion_factor=2.66,
219
+ bias=False,
220
+ LayerNorm_type='WithBias', ## Other option 'BiasFree'
221
+ dual_pixel_task=True ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
222
+ ):
223
+
224
+ super(Restormer, self).__init__()
225
+
226
+ self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
227
+
228
+ self.encoder_level1 = nn.Sequential(*[
229
+ TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias,
230
+ LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
231
+
232
+ self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
233
+ self.encoder_level2 = nn.Sequential(*[
234
+ TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
235
+ bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
236
+
237
+ self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3
238
+ self.encoder_level3 = nn.Sequential(*[
239
+ TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,
240
+ bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
241
+
242
+ self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4
243
+ self.latent = nn.Sequential(*[
244
+ TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor,
245
+ bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
246
+
247
+ self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3
248
+ self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias)
249
+ self.decoder_level3 = nn.Sequential(*[
250
+ TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,
251
+ bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
252
+
253
+ self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2
254
+ self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
255
+ self.decoder_level2 = nn.Sequential(*[
256
+ TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
257
+ bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
258
+
259
+ self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
260
+
261
+ self.decoder_level1 = nn.Sequential(*[
262
+ TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
263
+ bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
264
+
265
+ self.refinement = nn.Sequential(*[
266
+ TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
267
+ bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
268
+
269
+ #### For Dual-Pixel Defocus Deblurring Task ####
270
+ self.dual_pixel_task = dual_pixel_task
271
+ if self.dual_pixel_task:
272
+ self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias)
273
+ ###########################
274
+
275
+ self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
276
+
277
+ def forward(self, inp_img):
278
+ b, c, h_inp, w_inp = inp_img.shape
279
+ hb, wb = 8, 8
280
+ pad_h = (hb - h_inp % hb) % hb
281
+ pad_w = (wb - w_inp % wb) % wb
282
+ inp_img = F.pad(inp_img, [0, pad_w, 0, pad_h], mode='reflect')
283
+
284
+ inp_enc_level1 = self.patch_embed(inp_img)
285
+ out_enc_level1 = self.encoder_level1(inp_enc_level1)
286
+
287
+ inp_enc_level2 = self.down1_2(out_enc_level1)
288
+ out_enc_level2 = self.encoder_level2(inp_enc_level2)
289
+
290
+ inp_enc_level3 = self.down2_3(out_enc_level2)
291
+ out_enc_level3 = self.encoder_level3(inp_enc_level3)
292
+
293
+ inp_enc_level4 = self.down3_4(out_enc_level3)
294
+ latent = self.latent(inp_enc_level4)
295
+
296
+ inp_dec_level3 = self.up4_3(latent)
297
+ inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
298
+ inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
299
+ out_dec_level3 = self.decoder_level3(inp_dec_level3)
300
+
301
+ inp_dec_level2 = self.up3_2(out_dec_level3)
302
+ inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
303
+ inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
304
+ out_dec_level2 = self.decoder_level2(inp_dec_level2)
305
+
306
+ inp_dec_level1 = self.up2_1(out_dec_level2)
307
+ inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
308
+ out_dec_level1 = self.decoder_level1(inp_dec_level1)
309
+
310
+ out_dec_level1 = self.refinement(out_dec_level1)
311
+
312
+ #### For Dual-Pixel Defocus Deblurring Task ####
313
+ if self.dual_pixel_task:
314
+ out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
315
+ out_dec_level1 = self.output(out_dec_level1)
316
+ ###########################
317
+ else:
318
+ out_dec_level1 = self.output(out_dec_level1) + inp_img
319
+
320
+ return out_dec_level1[:, :, :h_inp, :w_inp]
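With the defaults used here (inp_channels=3, out_channels=31, dual_pixel_task=True), the dual-pixel branch is what keeps the final residual in feature space via skip_conv; the else branch would try to add the 3-channel input image to a 31-channel output and fail. A hedged usage sketch:

import torch
from test_challenge_code.architecture.Restormer import Restormer

model = Restormer().eval()
rgb = torch.rand(1, 3, 128, 128)
with torch.no_grad():
    hsi = model(rgb)
print(hsi.shape)                   # expected: torch.Size([1, 31, 128, 128])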
test_challenge_code/architecture/__init__.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+ from .edsr import EDSR
3
+ from .HDNet import HDNet
4
+ from .hinet import HINet
5
+ from .hrnet import SGN
6
+ from .HSCNN_Plus import HSCNN_Plus
7
+ from .MIRNet import MIRNet
8
+ from .MPRNet import MPRNet
9
+ from .MST import MST
10
+ from .MST_Plus_Plus import MST_Plus_Plus
11
+ from .Restormer import Restormer
12
+
13
+ def model_generator(method, pretrained_model_path=None):
14
+ if method == 'mirnet':
15
+ model = MIRNet(n_RRG=3, n_MSRB=1, height=3, width=1).cuda()
16
+ elif method == 'mst_plus_plus':
17
+ model = MST_Plus_Plus().cuda()
18
+ elif method == 'mst':
19
+ model = MST(dim=31, stage=2, num_blocks=[4, 7, 5]).cuda()
20
+ elif method == 'hinet':
21
+ model = HINet(depth=4).cuda()
22
+ elif method == 'mprnet':
23
+ model = MPRNet(num_cab=4).cuda()
24
+ elif method == 'restormer':
25
+ model = Restormer().cuda()
26
+ elif method == 'edsr':
27
+ model = EDSR().cuda()
28
+ elif method == 'hdnet':
29
+ model = HDNet().cuda()
30
+ elif method == 'hrnet':
31
+ model = SGN().cuda()
32
+ elif method == 'hscnn_plus':
33
+ model = HSCNN_Plus().cuda()
34
+ else:
35
+ raise ValueError(f'Method {method} is not defined !!!!')
36
+ if pretrained_model_path is not None:
37
+ print(f'load model from {pretrained_model_path}')
38
+ checkpoint = torch.load(pretrained_model_path)
39
+ model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()},
40
+ strict=True)
41
+ return model
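model_generator builds every architecture with .cuda() and loads checkpoints with torch.load's default device mapping, so it assumes a CUDA device is available. A CPU-only sketch of the same checkpoint handling; the helper name is illustrative and not part of the repository:

import torch
from test_challenge_code.architecture import MST_Plus_Plus

def load_mst_plus_plus_cpu(ckpt_path: str) -> torch.nn.Module:
    # mirror model_generator's 'module.' prefix stripping, but stay on the CPU
    model = MST_Plus_Plus()
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}
    model.load_state_dict(state_dict, strict=True)
    return model.eval()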
test_challenge_code/architecture/edsr.py ADDED
@@ -0,0 +1,87 @@
1
+ import torch.nn as nn
2
+
3
+ def default_conv(in_channels, out_channels, kernel_size, bias=True):
4
+ return nn.Conv2d(
5
+ in_channels, out_channels, kernel_size,
6
+ padding=(kernel_size//2), bias=bias)
7
+
8
+ class BasicBlock(nn.Sequential):
9
+ def __init__(
10
+ self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
11
+ bn=True, act=nn.ReLU(True)):
12
+
13
+ m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
14
+ if bn:
15
+ m.append(nn.BatchNorm2d(out_channels))
16
+ if act is not None:
17
+ m.append(act)
18
+
19
+ super(BasicBlock, self).__init__(*m)
20
+
21
+ class ResBlock(nn.Module):
22
+ def __init__(
23
+ self, conv, n_feats, kernel_size,
24
+ bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
25
+
26
+ super(ResBlock, self).__init__()
27
+ m = []
28
+ for i in range(2):
29
+ m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
30
+ if bn:
31
+ m.append(nn.BatchNorm2d(n_feats))
32
+ if i == 0:
33
+ m.append(act)
34
+
35
+ self.body = nn.Sequential(*m)
36
+ self.res_scale = res_scale
37
+
38
+ def forward(self, x):
39
+ res = self.body(x).mul(self.res_scale)
40
+ res += x
41
+
42
+ return res
43
+
44
+
45
+
46
+ class EDSR(nn.Module):
47
+ def __init__(self, conv=default_conv):
48
+ super(EDSR, self).__init__()
49
+
50
+ n_resblocks = 32
51
+ n_feats = 64
52
+ kernel_size = 3
53
+ n_colors = 3
54
+ out_channels = 31
55
+ act = nn.ReLU(True)
56
+
57
+
58
+
59
+ # define head module
60
+ m_head = [conv(n_colors, n_feats, kernel_size)]
61
+
62
+ # define body module
63
+ m_body = [
64
+ ResBlock(
65
+ conv, n_feats, kernel_size, act=act, res_scale=1
66
+ ) for _ in range(n_resblocks)
67
+ ]
68
+ m_body.append(conv(n_feats, n_feats, kernel_size))
69
+
70
+ # define tail module
71
+ m_tail = [
72
+ conv(n_feats, out_channels, kernel_size)
73
+ ]
74
+
75
+ self.head = nn.Sequential(*m_head)
76
+ self.body = nn.Sequential(*m_body)
77
+ self.tail = nn.Sequential(*m_tail)
78
+
79
+ def forward(self, x):
80
+ x = self.head(x)
81
+
82
+ res = self.body(x)
83
+ res += x
84
+
85
+ x = self.tail(res)
86
+
87
+ return x
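Unlike the transformer models above, this EDSR variant is a plain stride-1 convolutional stack (head 3 to 64 channels, 32 residual blocks, tail 64 to 31), so it needs no padding to a block-size multiple. A minimal sketch:

import torch
from test_challenge_code.architecture.edsr import EDSR

model = EDSR().eval()
rgb = torch.rand(1, 3, 100, 100)   # any spatial size works; every conv preserves H and W
with torch.no_grad():
    hsi = model(rgb)
print(hsi.shape)                   # expected: torch.Size([1, 31, 100, 100])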
test_challenge_code/architecture/hinet.py ADDED
@@ -0,0 +1,212 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ def conv3x3(in_chn, out_chn, bias=True):
6
+ layer = nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias)
7
+ return layer
8
+
9
+ def conv_down(in_chn, out_chn, bias=False):
10
+ layer = nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias)
11
+ return layer
12
+
13
+ def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
14
+ return nn.Conv2d(
15
+ in_channels, out_channels, kernel_size,
16
+ padding=(kernel_size//2), bias=bias, stride = stride)
17
+
18
+ ## Supervised Attention Module
19
+ class SAM(nn.Module):
20
+ def __init__(self, n_feat, kernel_size=3, bias=True):
21
+ super(SAM, self).__init__()
22
+ self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
23
+ self.conv2 = conv(n_feat, n_feat, kernel_size, bias=bias)
24
+ self.conv3 = conv(n_feat, n_feat, kernel_size, bias=bias)
25
+
26
+ def forward(self, x, x_img):
27
+ x1 = self.conv1(x)
28
+ img = self.conv2(x) + x_img
29
+ x2 = torch.sigmoid(self.conv3(img))
30
+ x1 = x1*x2
31
+ x1 = x1+x
32
+ return x1, img
33
+
34
+ class HINet(nn.Module):
35
+
36
+ def __init__(self, in_chn=31, out_chn=31, wf=31, depth=4, relu_slope=0.2, hin_position_left=0, hin_position_right=4):
37
+ super(HINet, self).__init__()
38
+
39
+ self.conv_in = nn.Conv2d(3, in_chn, kernel_size=3, padding=(3 - 1) // 2,
40
+ bias=False)
41
+ self.depth = depth
42
+ self.down_path_1 = nn.ModuleList()
43
+ self.down_path_2 = nn.ModuleList()
44
+ self.conv_01 = nn.Conv2d(in_chn, wf, 3, 1, 1)
45
+ self.conv_02 = nn.Conv2d(in_chn, wf, 3, 1, 1)
46
+
47
+ prev_channels = self.get_input_chn(wf)
48
+ for i in range(depth): # 0 .. depth-1
49
+ use_HIN = True if hin_position_left <= i and i <= hin_position_right else False
50
+ downsample = True if (i+1) < depth else False
51
+ self.down_path_1.append(UNetConvBlock(prev_channels, (2**i) * wf, downsample, relu_slope, use_HIN=use_HIN))
52
+ self.down_path_2.append(UNetConvBlock(prev_channels, (2**i) * wf, downsample, relu_slope, use_csff=downsample, use_HIN=use_HIN))
53
+ prev_channels = (2**i) * wf
54
+
55
+ self.up_path_1 = nn.ModuleList()
56
+ self.up_path_2 = nn.ModuleList()
57
+ self.skip_conv_1 = nn.ModuleList()
58
+ self.skip_conv_2 = nn.ModuleList()
59
+ for i in reversed(range(depth - 1)):
60
+ self.up_path_1.append(UNetUpBlock(prev_channels, (2**i)*wf, relu_slope))
61
+ self.up_path_2.append(UNetUpBlock(prev_channels, (2**i)*wf, relu_slope))
62
+ self.skip_conv_1.append(nn.Conv2d((2**i)*wf, (2**i)*wf, 3, 1, 1))
63
+ self.skip_conv_2.append(nn.Conv2d((2**i)*wf, (2**i)*wf, 3, 1, 1))
64
+ prev_channels = (2**i)*wf
65
+ self.sam12 = SAM(prev_channels)
66
+ self.cat12 = nn.Conv2d(prev_channels*2, prev_channels, 1, 1, 0)
67
+
68
+ self.last = conv3x3(prev_channels, out_chn, bias=True)
69
+
70
+ def forward(self, x):
71
+
72
+ b, c, h_inp, w_inp = x.shape
73
+ hb, wb = 16, 16
74
+ pad_h = (hb - h_inp % hb) % hb
75
+ pad_w = (wb - w_inp % wb) % wb
76
+ x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
77
+
78
+ image = self.conv_in(x)
79
+
80
+ #stage 1
81
+ x1 = self.conv_01(image)
82
+ encs = []
83
+ decs = []
84
+ for i, down in enumerate(self.down_path_1):
85
+ if (i+1) < self.depth:
86
+ x1, x1_up = down(x1)
87
+ encs.append(x1_up)
88
+ else:
89
+ x1 = down(x1)
90
+
91
+ for i, up in enumerate(self.up_path_1):
92
+ x1 = up(x1, self.skip_conv_1[i](encs[-i-1]))
93
+ decs.append(x1)
94
+
95
+ sam_feature, out_1 = self.sam12(x1, image)
96
+ #stage 2
97
+ x2 = self.conv_02(image)
98
+ x2 = self.cat12(torch.cat([x2, sam_feature], dim=1))
99
+ blocks = []
100
+ for i, down in enumerate(self.down_path_2):
101
+ if (i+1) < self.depth:
102
+ x2, x2_up = down(x2, encs[i], decs[-i-1])
103
+ blocks.append(x2_up)
104
+ else:
105
+ x2 = down(x2)
106
+
107
+ for i, up in enumerate(self.up_path_2):
108
+ x2 = up(x2, self.skip_conv_2[i](blocks[-i-1]))
109
+
110
+ out_2 = self.last(x2)
111
+ out_2 = out_2 + image
112
+ return out_2[:, :, :h_inp, :w_inp]
113
+
114
+ def get_input_chn(self, in_chn):
115
+ return in_chn
116
+
117
+ def _initialize(self):
118
+ gain = nn.init.calculate_gain('leaky_relu', 0.20)
119
+ for m in self.modules():
120
+ if isinstance(m, nn.Conv2d):
121
+ nn.init.orthogonal_(m.weight, gain=gain)
122
+ if m.bias is not None:
123
+ nn.init.constant_(m.bias, 0)
124
+
125
+
126
+ class UNetConvBlock(nn.Module):
127
+ def __init__(self, in_size, out_size, downsample, relu_slope, use_csff=False, use_HIN=False):
128
+ super(UNetConvBlock, self).__init__()
129
+ self.downsample = downsample
130
+ self.identity = nn.Conv2d(in_size, out_size, 1, 1, 0)
131
+ self.use_csff = use_csff
132
+
133
+ self.conv_1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True)
134
+ self.relu_1 = nn.LeakyReLU(relu_slope, inplace=False)
135
+ self.conv_2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True)
136
+ self.relu_2 = nn.LeakyReLU(relu_slope, inplace=False)
137
+
138
+ if downsample and use_csff:
139
+ self.csff_enc = nn.Conv2d(out_size, out_size, 3, 1, 1)
140
+ self.csff_dec = nn.Conv2d(out_size, out_size, 3, 1, 1)
141
+
142
+ if use_HIN:
143
+ self.norm = nn.InstanceNorm2d(((out_size+1)//2), affine=True)
144
+ self.use_HIN = use_HIN
145
+
146
+ if downsample:
147
+ self.downsample = conv_down(out_size, out_size, bias=False)
148
+
149
+ def forward(self, x, enc=None, dec=None):
150
+ out = self.conv_1(x)
151
+
152
+ if self.use_HIN:
153
+ out_1, out_2 = torch.chunk(out, 2, dim=1)
154
+ out = torch.cat([self.norm(out_1), out_2], dim=1)
155
+ out = self.relu_1(out)
156
+ out = self.relu_2(self.conv_2(out))
157
+
158
+ out += self.identity(x)
159
+ if enc is not None and dec is not None:
160
+ assert self.use_csff
161
+ out = out + self.csff_enc(enc) + self.csff_dec(dec)
162
+ if self.downsample:
163
+ out_down = self.downsample(out)
164
+ return out_down, out
165
+ else:
166
+ return out
167
+
168
+
169
+ class UNetUpBlock(nn.Module):
170
+ def __init__(self, in_size, out_size, relu_slope):
171
+ super(UNetUpBlock, self).__init__()
172
+ self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True)
173
+ self.conv_block = UNetConvBlock(in_size, out_size, False, relu_slope)
174
+
175
+ def forward(self, x, bridge):
176
+ up = self.up(x)
177
+ out = torch.cat([up, bridge], 1)
178
+ out = self.conv_block(out)
179
+ return out
180
+
181
+ class Subspace(nn.Module):
182
+
183
+ def __init__(self, in_size, out_size):
184
+ super(Subspace, self).__init__()
185
+ self.blocks = nn.ModuleList()
186
+ self.blocks.append(UNetConvBlock(in_size, out_size, False, 0.2))
187
+ self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True)
188
+
189
+ def forward(self, x):
190
+ sc = self.shortcut(x)
191
+ for i in range(len(self.blocks)):
192
+ x = self.blocks[i](x)
193
+ return x + sc
194
+
195
+ class skip_blocks(nn.Module):
196
+
197
+ def __init__(self, in_size, out_size, repeat_num=1):
198
+ super(skip_blocks, self).__init__()
199
+ self.blocks = nn.ModuleList()
200
+ self.re_num = repeat_num
201
+ mid_c = 128
202
+ self.blocks.append(UNetConvBlock(in_size, mid_c, False, 0.2))
203
+ for i in range(self.re_num - 2):
204
+ self.blocks.append(UNetConvBlock(mid_c, mid_c, False, 0.2))
205
+ self.blocks.append(UNetConvBlock(mid_c, out_size, False, 0.2))
206
+ self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True)
207
+
208
+ def forward(self, x):
209
+ sc = self.shortcut(x)
210
+ for m in self.blocks:
211
+ x = m(x)
212
+ return x + sc
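HINet here lifts the RGB input to in_chn=31 features, pads height and width to multiples of 16, and runs two U-Net stages linked through the SAM block; only the stage-2 prediction is returned. A hedged usage sketch with the depth=4 setting that model_generator passes:

import torch
from test_challenge_code.architecture.hinet import HINet

model = HINet(depth=4).eval()
rgb = torch.rand(1, 3, 100, 100)   # padded internally to 112x112, cropped back on return
with torch.no_grad():
    hsi = model(rgb)
print(hsi.shape)                   # expected: torch.Size([1, 31, 100, 100])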
test_challenge_code/architecture/hrnet.py ADDED
@@ -0,0 +1,484 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import Parameter
5
+ # ----------------------------------------
6
+ # Conv2d Block
7
+ # ----------------------------------------
8
+ class Conv2dLayer(nn.Module):
9
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, pad_type='zero',
10
+ activation='lrelu', norm='none', sn=False):
11
+ super(Conv2dLayer, self).__init__()
12
+ # Initialize the padding scheme
13
+ if pad_type == 'reflect':
14
+ self.pad = nn.ReflectionPad2d(padding)
15
+ elif pad_type == 'replicate':
16
+ self.pad = nn.ReplicationPad2d(padding)
17
+ elif pad_type == 'zero':
18
+ self.pad = nn.ZeroPad2d(padding)
19
+ else:
20
+ assert 0, "Unsupported padding type: {}".format(pad_type)
21
+
22
+ # Initialize the normalization type
23
+ if norm == 'bn':
24
+ self.norm = nn.BatchNorm2d(out_channels)
25
+ elif norm == 'in':
26
+ self.norm = nn.InstanceNorm2d(out_channels)
27
+ elif norm == 'ln':
28
+ self.norm = LayerNorm(out_channels)
29
+ elif norm == 'none':
30
+ self.norm = None
31
+ else:
32
+ assert 0, "Unsupported normalization: {}".format(norm)
33
+
34
+ # Initialize the activation function
35
+ if activation == 'relu':
36
+ self.activation = nn.ReLU(inplace=True)
37
+ elif activation == 'lrelu':
38
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
39
+ elif activation == 'prelu':
40
+ self.activation = nn.PReLU()
41
+ elif activation == 'selu':
42
+ self.activation = nn.SELU(inplace=True)
43
+ elif activation == 'tanh':
44
+ self.activation = nn.Tanh()
45
+ elif activation == 'sigmoid':
46
+ self.activation = nn.Sigmoid()
47
+ elif activation == 'none':
48
+ self.activation = None
49
+ else:
50
+ assert 0, "Unsupported activation: {}".format(activation)
51
+
52
+ # Initialize the convolution layers
53
+ if sn:
54
+ self.conv2d = SpectralNorm(
55
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation))
56
+ else:
57
+ self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation)
58
+
59
+ def forward(self, x):
60
+ x = self.pad(x)
61
+ x = self.conv2d(x)
62
+ if self.norm:
63
+ x = self.norm(x)
64
+ if self.activation:
65
+ x = self.activation(x)
66
+ return x
67
+
68
+
69
+ class TransposeConv2dLayer(nn.Module):
70
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, pad_type='zero',
71
+ activation='lrelu', norm='none', sn=False, scale_factor=2):
72
+ super(TransposeConv2dLayer, self).__init__()
73
+ # Initialize the conv scheme
74
+ self.scale_factor = scale_factor
75
+ self.conv2d = Conv2dLayer(in_channels, out_channels, kernel_size, stride, padding, dilation, pad_type,
76
+ activation, norm, sn)
77
+
78
+ def forward(self, x):
79
+ x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
80
+ x = self.conv2d(x)
81
+ return x
82
+
83
+
84
+ class ResConv2dLayer(nn.Module):
85
+ def __init__(self, in_channels, kernel_size=3, stride=1, padding=1, dilation=1, pad_type='zero', activation='lrelu',
86
+ norm='none', sn=False, scale_factor=2):
87
+ super(ResConv2dLayer, self).__init__()
88
+ # Initialize the conv scheme
89
+ self.conv2d = nn.Sequential(
90
+ Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm,
91
+ sn),
92
+ Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation='none',
93
+ norm=norm, sn=sn)
94
+ )
95
+
96
+ def forward(self, x):
97
+ residual = x
98
+ out = self.conv2d(x)
99
+ out = 0.1 * out + residual
100
+ return out
101
+
102
+
103
+ class DenseConv2dLayer_5C(nn.Module):
104
+ def __init__(self, in_channels, latent_channels, kernel_size=3, stride=1, padding=1, dilation=1, pad_type='zero',
105
+ activation='lrelu', norm='none', sn=False):
106
+ super(DenseConv2dLayer_5C, self).__init__()
107
+ # dense convolutions
108
+ self.conv1 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type,
109
+ activation, norm, sn)
110
+ self.conv2 = Conv2dLayer(in_channels + latent_channels, latent_channels, kernel_size, stride, padding, dilation,
111
+ pad_type, activation, norm, sn)
112
+ self.conv3 = Conv2dLayer(in_channels + latent_channels * 2, latent_channels, kernel_size, stride, padding,
113
+ dilation, pad_type, activation, norm, sn)
114
+ self.conv4 = Conv2dLayer(in_channels + latent_channels * 3, latent_channels, kernel_size, stride, padding,
115
+ dilation, pad_type, activation, norm, sn)
116
+ self.conv5 = Conv2dLayer(in_channels + latent_channels * 4, in_channels, kernel_size, stride, padding, dilation,
117
+ pad_type, activation, norm, sn)
118
+
119
+ def forward(self, x):
120
+ x1 = self.conv1(x)
121
+ x2 = self.conv2(torch.cat((x, x1), 1))
122
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
123
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
124
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
125
+ return x5
126
+
127
+
128
+ class ResidualDenseBlock_5C(nn.Module):
129
+ def __init__(self, in_channels, latent_channels, kernel_size=3, stride=1, padding=1, dilation=1, pad_type='zero',
130
+ activation='lrelu', norm='none', sn=False):
131
+ super(ResidualDenseBlock_5C, self).__init__()
132
+ # dense convolutions
133
+ self.conv1 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type,
134
+ activation, norm, sn)
135
+ self.conv2 = Conv2dLayer(in_channels + latent_channels, latent_channels, kernel_size, stride, padding, dilation,
136
+ pad_type, activation, norm, sn)
137
+ self.conv3 = Conv2dLayer(in_channels + latent_channels * 2, latent_channels, kernel_size, stride, padding,
138
+ dilation, pad_type, activation, norm, sn)
139
+ self.conv4 = Conv2dLayer(in_channels + latent_channels * 3, latent_channels, kernel_size, stride, padding,
140
+ dilation, pad_type, activation, norm, sn)
141
+ self.conv5 = Conv2dLayer(in_channels + latent_channels * 4, in_channels, kernel_size, stride, padding, dilation,
142
+ pad_type, activation, norm, sn)
143
+
144
+ def forward(self, x):
145
+ residual = x
146
+ x1 = self.conv1(x)
147
+ x2 = self.conv2(torch.cat((x, x1), 1))
148
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
149
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
150
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
151
+ x5 = 0.1 * x5 + residual
152
+ return x5
153
+
154
+
155
+ # ----------------------------------------
156
+ # Layer Norm
157
+ # ----------------------------------------
158
+ class LayerNorm(nn.Module):
159
+ def __init__(self, num_features, eps=1e-8, affine=True):
160
+ super(LayerNorm, self).__init__()
161
+ self.num_features = num_features
162
+ self.affine = affine
163
+ self.eps = eps
164
+
165
+ if self.affine:
166
+ self.gamma = Parameter(torch.Tensor(num_features).uniform_())
167
+ self.beta = Parameter(torch.zeros(num_features))
168
+
169
+ def forward(self, x):
170
+ # layer norm
171
+ shape = [-1] + [1] * (x.dim() - 1) # for 4d input: [-1, 1, 1, 1]
172
+ if x.size(0) == 1:
173
+ # These two lines run much faster in pytorch 0.4 than the two lines listed below.
174
+ mean = x.view(-1).mean().view(*shape)
175
+ std = x.view(-1).std().view(*shape)
176
+ else:
177
+ mean = x.view(x.size(0), -1).mean(1).view(*shape)
178
+ std = x.view(x.size(0), -1).std(1).view(*shape)
179
+ x = (x - mean) / (std + self.eps)
180
+ # if it is learnable
181
+ if self.affine:
182
+ shape = [1, -1] + [1] * (x.dim() - 2) # for 4d input: [1, -1, 1, 1]
183
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
184
+ return x
185
+
186
+
187
+ # ----------------------------------------
188
+ # Spectral Norm Block
189
+ # ----------------------------------------
190
+ def l2normalize(v, eps=1e-12):
191
+ return v / (v.norm() + eps)
192
+
193
+
194
+ class SpectralNorm(nn.Module):
195
+ def __init__(self, module, name='weight', power_iterations=1):
196
+ super(SpectralNorm, self).__init__()
197
+ self.module = module
198
+ self.name = name
199
+ self.power_iterations = power_iterations
200
+ if not self._made_params():
201
+ self._make_params()
202
+
203
+ def _update_u_v(self):
204
+ u = getattr(self.module, self.name + "_u")
205
+ v = getattr(self.module, self.name + "_v")
206
+ w = getattr(self.module, self.name + "_bar")
207
+
208
+ height = w.data.shape[0]
209
+ for _ in range(self.power_iterations):
210
+ v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
211
+ u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
212
+
213
+ # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
214
+ sigma = u.dot(w.view(height, -1).mv(v))
215
+ setattr(self.module, self.name, w / sigma.expand_as(w))
216
+
217
+ def _made_params(self):
218
+ try:
219
+ u = getattr(self.module, self.name + "_u")
220
+ v = getattr(self.module, self.name + "_v")
221
+ w = getattr(self.module, self.name + "_bar")
222
+ return True
223
+ except AttributeError:
224
+ return False
225
+
226
+ def _make_params(self):
227
+ w = getattr(self.module, self.name)
228
+
229
+ height = w.data.shape[0]
230
+ width = w.view(height, -1).data.shape[1]
231
+
232
+ u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
233
+ v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
234
+ u.data = l2normalize(u.data)
235
+ v.data = l2normalize(v.data)
236
+ w_bar = Parameter(w.data)
237
+
238
+ del self.module._parameters[self.name]
239
+
240
+ self.module.register_parameter(self.name + "_u", u)
241
+ self.module.register_parameter(self.name + "_v", v)
242
+ self.module.register_parameter(self.name + "_bar", w_bar)
243
+
244
+ def forward(self, *args):
245
+ self._update_u_v()
246
+ return self.module.forward(*args)
247
+
248
+
249
+ # ----------------------------------------
250
+ # Non-local Block
251
+ # ----------------------------------------
252
+ class Self_Attn(nn.Module):
253
+ """ Self attention Layer for Feature Map dimension"""
254
+
255
+ def __init__(self, in_dim, latent_dim=8):
256
+ super(Self_Attn, self).__init__()
257
+ self.channel_in = in_dim
258
+ self.channel_latent = in_dim // latent_dim
259
+ self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // latent_dim, kernel_size=1)
260
+ self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // latent_dim, kernel_size=1)
261
+ self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
262
+ self.gamma = nn.Parameter(torch.zeros(1))
263
+ self.softmax = nn.Softmax(dim=-1)
264
+
265
+ def forward(self, x):
266
+ """
267
+ inputs :
268
+ x : input feature maps(B X C X H X W)
269
+ returns :
270
+ out : self attention value + input feature
271
+ attention: B X N X N (N is Height * Width)
272
+ """
273
+ batchsize, C, height, width = x.size()
274
+ # proj_query: reshape to B x N x c, N = H x W
275
+ proj_query = self.query_conv(x).view(batchsize, -1, height * width).permute(0, 2, 1)
276
+ # proj_key: reshape to B x c x N, N = H x W
277
+ proj_key = self.key_conv(x).view(batchsize, -1, height * width)
278
+ # transpose check, energy: B x N x N, N = H x W
279
+ energy = torch.bmm(proj_query, proj_key)
280
+ # attention: B x N x N, N = H x W
281
+ attention = self.softmax(energy)
282
+ # proj_value is normal convolution, B x C x N
283
+ proj_value = self.value_conv(x).view(batchsize, -1, height * width)
284
+ # out: B x C x N
285
+ out = torch.bmm(proj_value, attention.permute(0, 2, 1))
286
+ out = out.view(batchsize, C, height, width)
287
+
288
+ out = self.gamma * out + x
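+ # gamma is initialised to zero, so the layer starts as an identity mapping and
+ # learns how much of the attention output to mix back in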
289
+ return out
290
+
291
+
292
+ # ----------------------------------------
293
+ # Global Block
294
+ # ----------------------------------------
295
+ class SELayer(nn.Module):
296
+ def __init__(self, channel, reduction=16):
297
+ super(SELayer, self).__init__()
298
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
299
+ self.fc = nn.Sequential(
300
+ nn.Linear(channel, channel // reduction, bias=False),
301
+ nn.ReLU(inplace=True),
302
+ nn.Linear(channel // reduction, channel // reduction, bias=False),
303
+ nn.ReLU(inplace=True),
304
+ nn.Linear(channel // reduction, channel, bias=False),
305
+ nn.Sigmoid()
306
+ )
307
+
308
+ def forward(self, x):
309
+ b, c, _, _ = x.size()
310
+ y = self.avg_pool(x).view(b, c)
311
+ y = self.fc(y).view(b, c, 1, 1)
312
+ return x * y.expand_as(x)
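+ # minimal shape check (hypothetical input): SELayer(64)(torch.randn(1, 64, 32, 32))
+ # returns a tensor of the same shape, with each channel rescaled by a gate in (0, 1)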
313
+
314
+
315
+ class GlobalBlock(nn.Module):
316
+ def __init__(self, in_channels, kernel_size, stride=1, padding=0, dilation=1, pad_type='zero', activation='lrelu',
317
+ norm='none', sn=False, reduction=8):
318
+ super(GlobalBlock, self).__init__()
319
+ self.conv1 = Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation,
320
+ norm, sn)
321
+ self.conv2 = Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation,
322
+ norm, sn)
323
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
324
+ self.fc = nn.Sequential(
325
+ nn.Linear(in_channels, in_channels // reduction, bias=False),
326
+ nn.ReLU(inplace=True),
327
+ nn.Linear(in_channels // reduction, in_channels // reduction, bias=False),
328
+ nn.ReLU(inplace=True),
329
+ nn.Linear(in_channels // reduction, in_channels, bias=False),
330
+ nn.Sigmoid()
331
+ )
332
+
333
+ def forward(self, x):
334
+ # residual
335
+ residual = x
336
+ # Squeeze-and-Excitation (SE)
337
+ b, c, _, _ = x.size()
338
+ x = self.conv1(x)
339
+ y = self.avg_pool(x).view(b, c)
340
+ y = self.fc(y).view(b, c, 1, 1)
341
+ y = x * y.expand_as(x)
342
+ y = self.conv2(x)
343
+ # addition
344
+ out = 0.1 * y + residual
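+ # note: conv2 is applied to the pre-gated x, so the SE-weighted tensor computed
+ # above is left unused; the conv2 output is damped by 0.1 before the residual
+ # addition, which keeps the block close to an identity mapping early in training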
345
+ return out
346
+
+ def pixel_unshuffle(input, downscale_factor):
347
+ '''
348
+ input: batchSize * c * k*w * k*h
349
+ downscale_factor: k
350
+ batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h
351
+ '''
352
+ c = input.shape[1]
353
+ kernel = torch.zeros(size = [downscale_factor * downscale_factor * c, 1, downscale_factor, downscale_factor],
354
+ device = input.device)
355
+ for y in range(downscale_factor):
356
+ for x in range(downscale_factor):
357
+ kernel[x + y * downscale_factor::downscale_factor * downscale_factor, 0, y, x] = 1
358
+ return F.conv2d(input, kernel, stride = downscale_factor, groups = c)
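+ # shape sketch (hypothetical input): pixel_unshuffle(torch.randn(1, 3, 256, 256), 2)
+ # -> (1, 12, 128, 128); this is the inverse rearrangement of F.pixel_shuffle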
359
+
360
+ class PixelUnShuffle(nn.Module):
361
+ def __init__(self, downscale_factor):
362
+ super(PixelUnShuffle, self).__init__()
363
+ self.downscale_factor = downscale_factor
364
+
365
+ def forward(self, input):
366
+ '''
367
+ input: batchSize * c * k*w * k*h
368
+ downscale_factor: k
369
+ batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h
370
+ '''
371
+ return pixel_unshuffle(input, self.downscale_factor)
372
+
373
+ # ----------------------------------------
374
+ # Initialize the networks
375
+ # ----------------------------------------
376
+ def weights_init(net, init_type = 'normal', init_gain = 0.02):
377
+ """Initialize network weights.
378
+ Parameters:
379
+ net (network) -- network to be initialized
380
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
381
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal
382
+ In our paper, we choose the default setting: zero mean Gaussian distribution with a standard deviation of 0.02
383
+ """
384
+ def init_func(m):
385
+ classname = m.__class__.__name__
386
+ if hasattr(m, 'weight') and classname.find('Conv') != -1:
387
+ if init_type == 'normal':
388
+ torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
389
+ elif init_type == 'xavier':
390
+ torch.nn.init.xavier_normal_(m.weight.data, gain = init_gain)
391
+ elif init_type == 'kaiming':
392
+ torch.nn.init.kaiming_normal_(m.weight.data, a = 0, mode = 'fan_in')
393
+ elif init_type == 'orthogonal':
394
+ torch.nn.init.orthogonal_(m.weight.data, gain = init_gain)
395
+ else:
396
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
397
+ elif classname.find('BatchNorm2d') != -1:
398
+ torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
399
+ torch.nn.init.constant_(m.bias.data, 0.0)
400
+
401
+ # apply the initialization function <init_func>
402
+ print('initialize network with %s type' % init_type)
403
+ net.apply(init_func)
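+ # typical call (matching the defaults documented above):
+ # weights_init(SGN(), init_type = 'normal', init_gain = 0.02)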
404
+
405
+ # ----------------------------------------
406
+ # Generator
407
+ # ----------------------------------------
408
+ class SGN(nn.Module):
409
+ def __init__(self, in_channels=3, out_channels=31, start_channels=64, pad='zero', activ='lrelu', norm='none'):
410
+ super(SGN, self).__init__()
411
+ # Top subnetwork, K = 3
412
+ self.top1 = Conv2dLayer(in_channels * (4 ** 3), start_channels * (2 ** 3), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
413
+ self.top21 = ResidualDenseBlock_5C(start_channels * (2 ** 3), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
414
+ self.top22 = GlobalBlock(start_channels * (2 ** 3), 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4)
415
+ self.top3 = Conv2dLayer(start_channels * (2 ** 3), start_channels * (2 ** 3), 1, 1, 0, pad_type = pad, activation = activ, norm = norm)
416
+ # Middle subnetwork, K = 2
417
+ self.mid1 = Conv2dLayer(in_channels * (4 ** 2), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
418
+ self.mid2 = Conv2dLayer(int(start_channels * (2 ** 2 + 2 ** 3 / 4)), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
419
+ self.mid31 = ResidualDenseBlock_5C(start_channels * (2 ** 2), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
420
+ self.mid32 = GlobalBlock(start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4)
421
+ self.mid4 = Conv2dLayer(start_channels * (2 ** 2), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
422
+ # Bottom subnetwork, K = 1
423
+ self.bot1 = Conv2dLayer(in_channels * (4 ** 1), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
424
+ self.bot2 = Conv2dLayer(int(start_channels * (2 ** 1 + 2 ** 2 / 4)), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
425
+ self.bot31 = ResidualDenseBlock_5C(start_channels * (2 ** 1), start_channels * (2 ** 0), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
426
+ self.bot32 = ResidualDenseBlock_5C(start_channels * (2 ** 1), start_channels * (2 ** 0), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
427
+ self.bot33 = GlobalBlock(start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4)
428
+ self.bot4 = Conv2dLayer(start_channels * (2 ** 1), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
429
+ # Mainstream
430
+ self.main1 = Conv2dLayer(in_channels, start_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
431
+ self.main2 = Conv2dLayer(int(start_channels * (2 ** 0 + 2 ** 1 / 4)), start_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
432
+ self.main31 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
433
+ self.main32 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
434
+ self.main33 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
435
+ self.main34 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
436
+ self.main35 = GlobalBlock(start_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4)
437
+ self.main4 = Conv2dLayer(start_channels, out_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
438
+
439
+ def forward(self, x):
440
+ b, c, h_inp, w_inp = x.shape
441
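+ # pad H and W up to multiples of 8 so the three pixel_unshuffle levels
+ # (factors 2, 4 and 8) divide evenly; the padding is cropped off again at the end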
+ hb, wb = 8, 8
442
+ pad_h = (hb - h_inp % hb) % hb
443
+ pad_w = (wb - w_inp % wb) % wb
444
+ x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
445
+ # PixelUnShuffle input: batch * 3 * 256 * 256
446
+ x1 = pixel_unshuffle(x, 2) # out: batch * 12 * 128 * 128
447
+ x2 = pixel_unshuffle(x, 4) # out: batch * 48 * 64 * 64
448
+ x3 = pixel_unshuffle(x, 8) # out: batch * 192 * 32 * 32
449
+ # Top subnetwork (the shape comments below assume start_channels = 32)
450
+ x3 = self.top1(x3) # out: batch * 256 * 32 * 32
451
+ x3 = self.top21(x3) # out: batch * 256 * 32 * 32
452
+ x3 = self.top22(x3) # out: batch * 256 * 32 * 32
453
+ x3 = self.top3(x3) # out: batch * 256 * 32 * 32
454
+ x3 = F.pixel_shuffle(x3, 2) # out: batch * 64 * 64 * 64, ready to be concatenated
455
+ # Middle subnetwork
456
+ x2 = self.mid1(x2) # out: batch * 128 * 64 * 64
457
+ x2 = torch.cat((x2, x3), 1) # out: batch * (128 + 64) * 64 * 64
458
+ x2 = self.mid2(x2) # out: batch * 128 * 64 * 64
459
+ x2 = self.mid31(x2) # out: batch * 128 * 64 * 64
460
+ x2 = self.mid32(x2) # out: batch * 128 * 64 * 64
461
+ x2 = self.mid4(x2) # out: batch * 128 * 64 * 64
462
+ x2 = F.pixel_shuffle(x2, 2) # out: batch * 32 * 128 * 128, ready to be concatenated
463
+ # Bottom subnetwork
464
+ x1 = self.bot1(x1) # out: batch * 64 * 128 * 128
465
+ x1 = torch.cat((x1, x2), 1) # out: batch * (64 + 32) * 128 * 128
466
+ x1 = self.bot2(x1) # out: batch * 64 * 128 * 128
467
+ x1 = self.bot31(x1) # out: batch * 64 * 128 * 128
468
+ x1 = self.bot32(x1) # out: batch * 64 * 128 * 128
469
+ x1 = self.bot33(x1) # out: batch * 64 * 128 * 128
470
+ x1 = self.bot4(x1) # out: batch * 64 * 128 * 128
471
+ x1 = F.pixel_shuffle(x1, 2) # out: batch * 16 * 256 * 256, ready to be concatenated
472
+ # Mainstream: full-resolution path with skip connections from the coarser subnetworks
473
+ x = self.main1(x) # out: batch * 32 * 256 * 256
474
+ x = torch.cat((x, x1), 1) # out: batch * (32 + 16) * 256 * 256
475
+ x = self.main2(x) # out: batch * 32 * 256 * 256
476
+ x = self.main31(x) # out: batch * 32 * 256 * 256
477
+ x = self.main32(x) # out: batch * 32 * 256 * 256
478
+ x = self.main33(x) # out: batch * 32 * 256 * 256
479
+ x = self.main34(x) # out: batch * 32 * 256 * 256
480
+ x = self.main35(x) # out: batch * 32 * 256 * 256
481
+ x = self.main4(x) # out: batch * 31 * 256 * 256
482
+
483
+ return x[:, :, :h_inp, :w_inp]
484
+
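+ # minimal usage sketch (hypothetical shapes, defaults from this file):
+ # net = SGN(in_channels=3, out_channels=31, start_channels=64)
+ # rgb = torch.randn(1, 3, 482, 512)   # any H x W; forward pads to multiples of 8
+ # hsi = net(rgb)                       # -> (1, 31, 482, 512)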