Upload 26 files
- .gitattributes +6 -0
- app.py +125 -0
- assets/ARAD_1K_0001.jpg +0 -0
- assets/ARAD_1K_0001.mat +3 -0
- assets/ARAD_1K_0002.jpg +0 -0
- assets/ARAD_1K_0002.mat +3 -0
- assets/ARAD_1K_0003.jpg +0 -0
- assets/ARAD_1K_0003.mat +3 -0
- assets/ARAD_1K_0004.jpg +0 -0
- assets/ARAD_1K_0004.mat +3 -0
- assets/ARAD_1K_0005.jpg +0 -0
- assets/ARAD_1K_0005.mat +3 -0
- assets/ARAD_1K_0006.jpg +0 -0
- assets/ARAD_1K_0006.mat +3 -0
- mst_plus_plus.pth +3 -0
- requirements.txt +12 -0
- test_challenge_code/architecture/HDNet.py +397 -0
- test_challenge_code/architecture/HSCNN_Plus.py +77 -0
- test_challenge_code/architecture/MIRNet.py +416 -0
- test_challenge_code/architecture/MPRNet.py +350 -0
- test_challenge_code/architecture/MST.py +313 -0
- test_challenge_code/architecture/MST_Plus_Plus.py +307 -0
- test_challenge_code/architecture/Restormer.py +320 -0
- test_challenge_code/architecture/__init__.py +41 -0
- test_challenge_code/architecture/edsr.py +87 -0
- test_challenge_code/architecture/hinet.py +212 -0
- test_challenge_code/architecture/hrnet.py +484 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/ARAD_1K_0001.mat filter=lfs diff=lfs merge=lfs -text
+assets/ARAD_1K_0002.mat filter=lfs diff=lfs merge=lfs -text
+assets/ARAD_1K_0003.mat filter=lfs diff=lfs merge=lfs -text
+assets/ARAD_1K_0004.mat filter=lfs diff=lfs merge=lfs -text
+assets/ARAD_1K_0005.mat filter=lfs diff=lfs merge=lfs -text
+assets/ARAD_1K_0006.mat filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,125 @@
import torch
import numpy as np
import gradio as gr
import cv2
import h5py
from test_develop_code.architecture import model_generator
from PIL import Image


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
# Load the pretrained MST++ spectral reconstruction model on CPU.
model = model_generator("mst_plus_plus", "mst_plus_plus.pth").to(device)
model.eval()
# 31 band centers from 400 nm to 700 nm in 10 nm steps.
wavelengths = np.linspace(400, 700, 31)


def wavelength_to_rgb(wl: float) -> tuple:
    if 380 <= wl <= 440:
        R = -(wl - 440) / (440 - 380)
        G = 0.0
        B = 1.0
    elif 440 < wl <= 490:
        R = 0.0
        G = (wl - 440) / (490 - 440)
        B = 1.0
    elif 490 < wl <= 510:
        R = 0.0
        G = 1.0
        B = -(wl - 510) / (510 - 490)
    elif 510 < wl <= 580:
        R = (wl - 510) / (580 - 510)
        G = 1.0
        B = 0.0
    elif 580 < wl <= 645:
        R = 1.0
        G = -(wl - 645) / (645 - 580)
        B = 0.0
    elif 645 < wl <= 700:
        R = 1.0
        G = 0.0
        B = 0.0
    else:
        R = G = B = 0.0
    return (max(R, 0.0), max(G, 0.0), max(B, 0.0))


def predict(img):
    # Reconstruct a (31, H, W) spectral cube from an RGB image.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32) / 255.0
    img = np.transpose(img, (2, 0, 1))
    img_tensor = torch.from_numpy(img).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(img_tensor)
    pred = pred.squeeze(0).cpu().numpy()
    pred = np.clip(pred, 0, 1)
    return pred


def visualize_channel(cube: np.ndarray, index: int) -> Image.Image:
    # Render one spectral band as a grayscale image tinted with its approximate display colour.
    if cube is None:
        return None
    band = cube[index]
    band = (band - band.min()) / (band.max() - band.min() + 1e-8)
    color = wavelength_to_rgb(wavelengths[index])
    rgb = np.stack([band * c for c in color], axis=-1)
    rgb = (rgb * 255).astype(np.uint8)
    return Image.fromarray(rgb)


def load_mat(mat_file) -> np.ndarray:
    # Read the ground-truth cube stored under the "cube" key of an HDF5 .mat file.
    with h5py.File(mat_file.name, "r") as f:
        cube = np.array(f["cube"])
    cube = np.transpose(cube, (0, 2, 1))
    cube = np.clip(cube, 0, 1)
    return cube


def reset_all():
    return None, None, None, None, 0


with gr.Blocks() as demo:
    gr.Markdown("## Spectral Reconstruction")

    with gr.Row():
        with gr.Column():
            rgb_input = gr.Image(type="numpy", label="Upload RGB Image")
            pred_state = gr.State()
        with gr.Column():
            pred_output = gr.Image(label="Prediction Visualization")
            pred_slider = gr.Slider(minimum=0, maximum=30, step=1, label="Channel (Prediction)", value=0)

    with gr.Row():
        with gr.Column():
            mat_input = gr.File(label="Upload .mat file (Ground Truth)")
            gt_state = gr.State()
        with gr.Column():
            gt_output = gr.Image(label="Ground Truth Visualization")
            gt_slider = gr.Slider(minimum=0, maximum=30, step=1, label="Channel (Ground Truth)", value=0)

    clear_btn = gr.Button("Clear")
    rgb_input.change(fn=predict, inputs=rgb_input, outputs=pred_state)
    pred_slider.change(fn=visualize_channel, inputs=[pred_state, pred_slider], outputs=pred_output)

    mat_input.change(fn=load_mat, inputs=mat_input, outputs=gt_state)
    gt_slider.change(fn=visualize_channel, inputs=[gt_state, gt_slider], outputs=gt_output)

    clear_btn.click(fn=reset_all, outputs=[rgb_input, pred_output, pred_state, gt_state, mat_input])

    gr.Examples(
        examples=[
            ["assets/ARAD_1K_0001.jpg", 0, "assets/ARAD_1K_0001.mat", 0],
            ["assets/ARAD_1K_0002.jpg", 0, "assets/ARAD_1K_0002.mat", 0],
            ["assets/ARAD_1K_0003.jpg", 0, "assets/ARAD_1K_0003.mat", 0],
            ["assets/ARAD_1K_0004.jpg", 0, "assets/ARAD_1K_0004.mat", 0],
            ["assets/ARAD_1K_0005.jpg", 0, "assets/ARAD_1K_0005.mat", 0],
        ],
        inputs=[rgb_input, pred_slider, mat_input, gt_slider],
        outputs=[pred_output, gt_output],
        label="Try Examples"
    )


if __name__ == "__main__":
    demo.launch()
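For reference, the prediction and visualization helpers above can also be driven without the Gradio UI. A minimal sketch (hypothetical usage, assuming app.py and the bundled assets are importable from the repository root):

# Rough offline sketch; not part of the upload.
import cv2
from app import predict, visualize_channel

# predict() applies cv2.COLOR_BGR2RGB internally, so pass a BGR array (cv2.imread default).
rgb_bgr = cv2.imread("assets/ARAD_1K_0001.jpg")
cube = predict(rgb_bgr)              # (31, H, W) cube, values clipped to [0, 1]
band = visualize_channel(cube, 15)   # channel 15 corresponds to 550 nm
band.save("band_550nm.png")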
assets/ARAD_1K_0001.jpg
ADDED
assets/ARAD_1K_0001.mat
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d75d3689827d69fdb600ef49eceec4692615db5f0bb8882b41cc3cc0f29a139f
size 21896837
assets/ARAD_1K_0002.jpg
ADDED
assets/ARAD_1K_0002.mat
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5ee379c0574ac7a225cdd6cd39f942441f48d3401f4d1abe0e26a664ce8f1af3
size 22920548
assets/ARAD_1K_0003.jpg
ADDED
assets/ARAD_1K_0003.mat
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0729ab86549ca910803e2d0dcdc76c10087f9a0f4efbb252e08b564bd3a9a741
size 21218912
assets/ARAD_1K_0004.jpg
ADDED
assets/ARAD_1K_0004.mat
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:abe2de20dc3b76e1a1701292e7ad1774aef643a88366084eef58d92f7e280a6d
size 22377073
assets/ARAD_1K_0005.jpg
ADDED
assets/ARAD_1K_0005.mat
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6fd2f928230e076be0d2c707b657a5e6cdf84108cd53d86aab91450d8ce4222b
size 21042006
assets/ARAD_1K_0006.jpg
ADDED
assets/ARAD_1K_0006.mat
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32d2e4ff2732cd16bf3a9103592c9af304598307dac6967990bd25f00b47d9f7
size 22078636
mst_plus_plus.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d285430cc688d08434582eee71bae2d82661be7997af1a68d6636ec25f7a3421
size 6580074
requirements.txt
ADDED
@@ -0,0 +1,12 @@
opencv-python==4.11.0.86
einops==0.8.1
torchvision==0.22.0
torch==2.7.0
scipy==1.15.2
h5py==3.13.0
hdf5storage==0.1.19
tqdm==4.67.1
gdown==5.2.0
matplotlib==3.10.1
gradio==5.29.0
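A quick environment check before launching the demo can confirm the pins above resolved correctly (illustrative sketch only; the list above is the authoritative one):

# Verify the core imports and versions used by app.py.
import torch, torchvision, cv2, h5py, gradio, einops, scipy
print("torch", torch.__version__, "| gradio", gradio.__version__, "| opencv", cv2.__version__)
print("CUDA available:", torch.cuda.is_available())  # app.py forces CPU regardless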
test_challenge_code/architecture/HDNet.py
ADDED
@@ -0,0 +1,397 @@
import math  # needed by Upsampler (math.log)

import torch
import torch.nn as nn


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)

class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        self.requires_grad = False

class BasicBlock(nn.Sequential):
    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, bias=False,
        bn=True, act=nn.ReLU(True)):

        m = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size//2), stride=stride, bias=bias)
        ]
        if bn: m.append(nn.BatchNorm2d(out_channels))
        if act is not None: m.append(act)
        super(BasicBlock, self).__init__(*m)

class ResBlock(nn.Module):
    def __init__(
        self, conv=default_conv, n_feat=31, kernel_size=3,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if i == 0: m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res

class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2d(n_feat))
                if act: m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if act: m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)

## add SELayer
class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y

## add SEResBlock
class SEResBlock(nn.Module):
    def __init__(
        self, conv, n_feat, kernel_size, reduction,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(SEResBlock, self).__init__()
        modules_body = []
        for i in range(2):
            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: modules_body.append(nn.BatchNorm2d(n_feat))
            if i == 0: modules_body.append(act)
        modules_body.append(SELayer(n_feat, reduction))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        #res = self.body(x).mul(self.res_scale)
        res += x

        return res


_NORM_BONE = False

def constant_init(module, val, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.kaiming_uniform_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    else:
        nn.init.kaiming_normal_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)

# depthwise-separable convolution (DSC)
class DSC(nn.Module):

    def __init__(self, nin: int) -> None:
        super(DSC, self).__init__()
        self.conv_dws = nn.Conv2d(
            nin, nin, kernel_size=1, stride=1, padding=0, groups=nin
        )
        self.bn_dws = nn.BatchNorm2d(nin, momentum=0.9)
        self.relu_dws = nn.ReLU(inplace=False)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        self.conv_point = nn.Conv2d(
            nin, 1, kernel_size=1, stride=1, padding=0, groups=1
        )
        self.bn_point = nn.BatchNorm2d(1, momentum=0.9)
        self.relu_point = nn.ReLU(inplace=False)

        self.softmax = nn.Softmax(dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv_dws(x)
        out = self.bn_dws(out)
        out = self.relu_dws(out)

        out = self.maxpool(out)

        out = self.conv_point(out)
        out = self.bn_point(out)
        out = self.relu_point(out)

        m, n, p, q = out.shape
        out = self.softmax(out.view(m, n, -1))
        out = out.view(m, n, p, q)

        out = out.expand(x.shape[0], x.shape[1], x.shape[2], x.shape[3])

        out = torch.mul(out, x)

        out = out + x

        return out

# Efficient Feature Fusion (EFF)
class EFF(nn.Module):
    def __init__(self, nin: int, nout: int, num_splits: int) -> None:
        super(EFF, self).__init__()

        assert nin % num_splits == 0

        self.nin = nin
        self.nout = nout
        self.num_splits = num_splits
        self.subspaces = nn.ModuleList(
            [DSC(int(self.nin / self.num_splits)) for i in range(self.num_splits)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        sub_feat = torch.chunk(x, self.num_splits, dim=1)
        out = []
        for idx, l in enumerate(self.subspaces):
            out.append(self.subspaces[idx](sub_feat[idx]))
        out = torch.cat(out, dim=1)

        return out


# spatial-spectral domain attention learning (SDL)
class SDL_attention(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=1, stride=1):
        super(SDL_attention, self).__init__()

        self.inplanes = inplanes
        self.inter_planes = planes // 2
        self.planes = planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = (kernel_size-1)//2

        self.conv_q_right = nn.Conv2d(self.inplanes, 1, kernel_size=1, stride=stride, padding=0, bias=False)
        self.conv_v_right = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, bias=False)
        self.conv_up = nn.Conv2d(self.inter_planes, self.planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.softmax_right = nn.Softmax(dim=2)
        self.sigmoid = nn.Sigmoid()

        self.conv_q_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, bias=False)  #g
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_v_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, bias=False)  #theta
        self.softmax_left = nn.Softmax(dim=2)

        self.reset_parameters()

    def reset_parameters(self):
        kaiming_init(self.conv_q_right, mode='fan_in')
        kaiming_init(self.conv_v_right, mode='fan_in')
        kaiming_init(self.conv_q_left, mode='fan_in')
        kaiming_init(self.conv_v_left, mode='fan_in')

        self.conv_q_right.inited = True
        self.conv_v_right.inited = True
        self.conv_q_left.inited = True
        self.conv_v_left.inited = True

    # HR spatial attention
    def spatial_attention(self, x):
        input_x = self.conv_v_right(x)
        batch, channel, height, width = input_x.size()

        input_x = input_x.view(batch, channel, height * width)
        context_mask = self.conv_q_right(x)
        context_mask = context_mask.view(batch, 1, height * width)
        context_mask = self.softmax_right(context_mask)

        context = torch.matmul(input_x, context_mask.transpose(1, 2))
        context = context.unsqueeze(-1)
        context = self.conv_up(context)

        mask_ch = self.sigmoid(context)

        out = x * mask_ch

        return out

    # HR spectral attention
    def spectral_attention(self, x):

        g_x = self.conv_q_left(x)
        batch, channel, height, width = g_x.size()

        avg_x = self.avg_pool(g_x)
        batch, channel, avg_x_h, avg_x_w = avg_x.size()

        avg_x = avg_x.view(batch, channel, avg_x_h * avg_x_w).permute(0, 2, 1)
        theta_x = self.conv_v_left(x).view(batch, self.inter_planes, height * width)
        context = torch.matmul(avg_x, theta_x)
        context = self.softmax_left(context)
        context = context.view(batch, 1, height, width)

        mask_sp = self.sigmoid(context)

        out = x * mask_sp

        return out

    def forward(self, x):
        context_spectral = self.spectral_attention(x)
        context_spatial = self.spatial_attention(x)
        out = context_spatial + context_spectral
        return out


class HDNet(nn.Module):

    def __init__(self, in_ch=3, out_ch=31, conv=default_conv):
        super(HDNet, self).__init__()

        n_resblocks = 32
        n_feats = 48
        kernel_size = 3
        act = nn.ReLU(True)

        # define head module
        m_head = [conv(in_ch, n_feats, kernel_size)]

        # define body module
        m_body = [
            ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=1
            ) for _ in range(n_resblocks)
        ]
        m_body.append(SDL_attention(inplanes=n_feats, planes=n_feats))
        m_body.append(EFF(nin=n_feats, nout=n_feats, num_splits=4))

        for i in range(1, n_resblocks):
            m_body.append(ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=1
            ))

        m_body.append(conv(n_feats, n_feats, kernel_size))

        m_tail = [conv(n_feats, out_ch, kernel_size)]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)

        return x

# frequency domain learning (FDL)
class FDL(nn.Module):
    def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False):
        super(FDL, self).__init__()
        self.loss_weight = loss_weight
        self.alpha = alpha
        self.patch_factor = patch_factor
        self.ave_spectrum = ave_spectrum
        self.log_matrix = log_matrix
        self.batch_matrix = batch_matrix

    def tensor2freq(self, x):
        patch_factor = self.patch_factor
        _, _, h, w = x.shape
        assert h % patch_factor == 0 and w % patch_factor == 0, (
            'Patch factor should be divisible by image height and width')
        patch_list = []
        patch_h = h // patch_factor
        patch_w = w // patch_factor
        for i in range(patch_factor):
            for j in range(patch_factor):
                patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])

        y = torch.stack(patch_list, 1)

        # torch.rfft is a legacy API that was removed in PyTorch 1.8+.
        return torch.rfft(y, 2, onesided=False, normalized=True)

    def loss_formulation(self, recon_freq, real_freq, matrix=None):
        if matrix is not None:
            weight_matrix = matrix.detach()
        else:
            matrix_tmp = (recon_freq - real_freq) ** 2
            matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha
            if self.log_matrix:
                matrix_tmp = torch.log(matrix_tmp + 1.0)

            if self.batch_matrix:
                matrix_tmp = matrix_tmp / matrix_tmp.max()
            else:
                matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]

            matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
            matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
            weight_matrix = matrix_tmp.clone().detach()

        assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
            'The values of spectrum weight matrix should be in the range [0, 1], '
            'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))

        tmp = (recon_freq - real_freq) ** 2
        freq_distance = tmp[..., 0] + tmp[..., 1]

        loss = weight_matrix * freq_distance
        return torch.mean(loss)

    def forward(self, pred, target, matrix=None, **kwargs):

        pred_freq = self.tensor2freq(pred)
        target_freq = self.tensor2freq(target)

        if self.ave_spectrum:
            pred_freq = torch.mean(pred_freq, 0, keepdim=True)
            target_freq = torch.mean(target_freq, 0, keepdim=True)

        return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight
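Note that `FDL.tensor2freq` relies on `torch.rfft`, which was removed from PyTorch in 1.8, so this loss would not run under the torch==2.7.0 pinned in requirements.txt. One possible adaptation (a sketch, not part of the upload) reproduces the old real/imaginary layout with the `torch.fft` module:

# Hypothetical replacement for the torch.rfft call above, for torch >= 1.8.
# norm="ortho" mirrors normalized=True; stacking real/imag mirrors onesided=False output.
import torch

def tensor2freq_fft(y):
    freq = torch.fft.fft2(y, norm="ortho")          # complex 2D spectrum over the last two dims
    return torch.stack([freq.real, freq.imag], -1)  # last dim holds (real, imag), as torch.rfft did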
test_challenge_code/architecture/HSCNN_Plus.py
ADDED
@@ -0,0 +1,77 @@
import torch.nn as nn
import torch


class dfus_block(nn.Module):
    def __init__(self, dim):
        super(dfus_block, self).__init__()
        self.conv1 = nn.Conv2d(dim, 128, 1, 1, 0, bias=False)

        self.conv_up1 = nn.Conv2d(128, 32, 3, 1, 1, bias=False)
        self.conv_up2 = nn.Conv2d(32, 16, 1, 1, 0, bias=False)

        self.conv_down1 = nn.Conv2d(128, 32, 3, 1, 1, bias=False)
        self.conv_down2 = nn.Conv2d(32, 16, 1, 1, 0, bias=False)

        self.conv_fution = nn.Conv2d(96, 32, 1, 1, 0, bias=False)

        #### activation function
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out:[b,c,h,w]
        """
        feat = self.relu(self.conv1(x))
        feat_up1 = self.relu(self.conv_up1(feat))
        feat_up2 = self.relu(self.conv_up2(feat_up1))
        feat_down1 = self.relu(self.conv_down1(feat))
        feat_down2 = self.relu(self.conv_down2(feat_down1))
        feat_fution = torch.cat([feat_up1, feat_up2, feat_down1, feat_down2], dim=1)
        feat_fution = self.relu(self.conv_fution(feat_fution))
        out = torch.cat([x, feat_fution], dim=1)
        return out

class ddfn(nn.Module):
    def __init__(self, dim, num_blocks=78):
        super(ddfn, self).__init__()

        self.conv_up1 = nn.Conv2d(dim, 32, 3, 1, 1, bias=False)
        self.conv_up2 = nn.Conv2d(32, 32, 1, 1, 0, bias=False)

        self.conv_down1 = nn.Conv2d(dim, 32, 3, 1, 1, bias=False)
        self.conv_down2 = nn.Conv2d(32, 32, 1, 1, 0, bias=False)

        dfus_blocks = [dfus_block(dim=128+32*i) for i in range(num_blocks)]
        self.dfus_blocks = nn.Sequential(*dfus_blocks)

        #### activation function
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out:[b,c,h,w]
        """
        feat_up1 = self.relu(self.conv_up1(x))
        feat_up2 = self.relu(self.conv_up2(feat_up1))
        feat_down1 = self.relu(self.conv_down1(x))
        feat_down2 = self.relu(self.conv_down2(feat_down1))
        feat_fution = torch.cat([feat_up1, feat_up2, feat_down1, feat_down2], dim=1)
        out = self.dfus_blocks(feat_fution)
        return out

class HSCNN_Plus(nn.Module):
    def __init__(self, in_channels=3, out_channels=31, num_blocks=30):
        super(HSCNN_Plus, self).__init__()

        self.ddfn = ddfn(dim=in_channels, num_blocks=num_blocks)
        self.conv_out = nn.Conv2d(128+32*num_blocks, out_channels, 1, 1, 0, bias=False)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out:[b,c,h,w]
        """
        fea = self.ddfn(x)
        out = self.conv_out(fea)
        return out
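These architecture files all map a 3-channel RGB input to a 31-band cube at the same spatial resolution. A quick shape check (a sketch; assumes the repository root is on the Python path and uses the default constructor arguments shown above) can confirm a file was copied intact:

# Smoke-test sketch: forward a dummy RGB batch through HSCNN_Plus and check the output shape.
import torch
from test_challenge_code.architecture.HSCNN_Plus import HSCNN_Plus

model = HSCNN_Plus(in_channels=3, out_channels=31, num_blocks=30).eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 64, 64))
print(out.shape)  # expected: torch.Size([1, 31, 64, 64])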
test_challenge_code/architecture/MIRNet.py
ADDED
@@ -0,0 +1,416 @@
# --- Imports --- #
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# from pdb import set_trace as stx

def get_pad_layer(pad_type):
    if(pad_type in ['refl','reflect']):
        PadLayer = nn.ReflectionPad2d
    elif(pad_type in ['repl','replicate']):
        PadLayer = nn.ReplicationPad2d
    elif(pad_type=='zero'):
        PadLayer = nn.ZeroPad2d
    else:
        print('Pad type [%s] not recognized'%pad_type)
    return PadLayer

class downsamp(nn.Module):
    def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
        super(downsamp, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
        self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))]
        self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes]
        self.stride = stride
        self.off = int((self.stride-1)/2.)
        self.channels = channels

        # print('Filter size [%i]'%filt_size)
        if(self.filt_size==1):
            a = np.array([1.,])
        elif(self.filt_size==2):
            a = np.array([1., 1.])
        elif(self.filt_size==3):
            a = np.array([1., 2., 1.])
        elif(self.filt_size==4):
            a = np.array([1., 3., 3., 1.])
        elif(self.filt_size==5):
            a = np.array([1., 4., 6., 4., 1.])
        elif(self.filt_size==6):
            a = np.array([1., 5., 10., 10., 5., 1.])
        elif(self.filt_size==7):
            a = np.array([1., 6., 15., 20., 15., 6., 1.])

        filt = torch.Tensor(a[:,None]*a[None,:])
        filt = filt/torch.sum(filt)
        self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        if(self.filt_size==1):
            if(self.pad_off==0):
                return inp[:,:,::self.stride,::self.stride]
            else:
                return self.pad(inp)[:,:,::self.stride,::self.stride]
        else:
            return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])

##########################################################################

def conv(in_channels, out_channels, kernel_size, bias=False, padding=1, stride=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias, stride=stride)


##########################################################################
##---------- Selective Kernel Feature Fusion (SKFF) ----------
class SKFF(nn.Module):
    def __init__(self, in_channels, height=3, reduction=8, bias=False):
        super(SKFF, self).__init__()

        self.height = height
        d = max(int(in_channels / reduction), 4)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(nn.Conv2d(in_channels, d, 1, padding=0, bias=bias), nn.PReLU())

        self.fcs = nn.ModuleList([])
        for i in range(self.height):
            self.fcs.append(nn.Conv2d(d, in_channels, kernel_size=1, stride=1, bias=bias))

        self.softmax = nn.Softmax(dim=1)

    def forward(self, inp_feats):
        batch_size = inp_feats[0].shape[0]
        n_feats = inp_feats[0].shape[1]

        inp_feats = torch.cat(inp_feats, dim=1)
        inp_feats = inp_feats.view(batch_size, self.height, n_feats, inp_feats.shape[2], inp_feats.shape[3])

        feats_U = torch.sum(inp_feats, dim=1)
        feats_S = self.avg_pool(feats_U)
        feats_Z = self.conv_du(feats_S)

        attention_vectors = [fc(feats_Z) for fc in self.fcs]
        attention_vectors = torch.cat(attention_vectors, dim=1)
        attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)
        # stx()
        attention_vectors = self.softmax(attention_vectors)

        feats_V = torch.sum(inp_feats * attention_vectors, dim=1)

        return feats_V

##########################################################################


##---------- Spatial Attention ----------
class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
                 bn=False, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class spatial_attn_layer(nn.Module):
    def __init__(self, kernel_size=5):
        super(spatial_attn_layer, self).__init__()
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        # import pdb;pdb.set_trace()
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcasting
        return x * scale


##########################################################################
## ------ Channel Attention --------------
class ca_layer(nn.Module):
    def __init__(self, channel, reduction=8, bias=True):
        super(ca_layer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y


##########################################################################
##---------- Dual Attention Unit (DAU) ----------
class DAU(nn.Module):
    def __init__(
            self, n_feat, kernel_size=3, reduction=8,
            bias=False, bn=False, act=nn.PReLU(), res_scale=1):
        super(DAU, self).__init__()
        modules_body = [conv(n_feat, n_feat, kernel_size, bias=bias), act, conv(n_feat, n_feat, kernel_size, bias=bias)]
        self.body = nn.Sequential(*modules_body)

        ## Spatial Attention
        self.SA = spatial_attn_layer()

        ## Channel Attention
        self.CA = ca_layer(n_feat, reduction, bias=bias)

        self.conv1x1 = nn.Conv2d(n_feat * 2, n_feat, kernel_size=1, bias=bias)

    def forward(self, x):
        res = self.body(x)
        sa_branch = self.SA(res)
        ca_branch = self.CA(res)
        res = torch.cat([sa_branch, ca_branch], dim=1)
        res = self.conv1x1(res)
        res += x
        return res


##########################################################################
##---------- Resizing Modules ----------
class ResidualDownSample(nn.Module):
    def __init__(self, in_channels, bias=False):
        super(ResidualDownSample, self).__init__()

        self.top = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0, bias=bias),
                                 nn.PReLU(),
                                 nn.Conv2d(in_channels, in_channels, 3, stride=1, padding=1, bias=bias),
                                 nn.PReLU(),
                                 downsamp(channels=in_channels, filt_size=3, stride=2),
                                 nn.Conv2d(in_channels, in_channels * 2, 1, stride=1, padding=0, bias=bias))

        self.bot = nn.Sequential(downsamp(channels=in_channels, filt_size=3, stride=2),
                                 nn.Conv2d(in_channels, in_channels * 2, 1, stride=1, padding=0, bias=bias))

    def forward(self, x):
        top = self.top(x)
        bot = self.bot(x)
        out = top + bot
        return out


class DownSample(nn.Module):
    def __init__(self, in_channels, scale_factor, stride=2, kernel_size=3):
        super(DownSample, self).__init__()
        self.scale_factor = int(np.log2(scale_factor))

        modules_body = []
        for i in range(self.scale_factor):
            modules_body.append(ResidualDownSample(in_channels))
            in_channels = int(in_channels * stride)

        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        x = self.body(x)
        return x


class ResidualUpSample(nn.Module):
    def __init__(self, in_channels, bias=False):
        super(ResidualUpSample, self).__init__()

        self.top = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0, bias=bias),
                                 nn.PReLU(),
                                 nn.ConvTranspose2d(in_channels, in_channels, 3, stride=2, padding=1, output_padding=1,
                                                    bias=bias),
                                 nn.PReLU(),
                                 nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=bias))

        self.bot = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=bias),
                                 nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=bias))

    def forward(self, x):
        top = self.top(x)
        bot = self.bot(x)
        out = top + bot
        return out


class UpSample(nn.Module):
    def __init__(self, in_channels, scale_factor, stride=2, kernel_size=3):
        super(UpSample, self).__init__()
        self.scale_factor = int(np.log2(scale_factor))

        modules_body = []
        for i in range(self.scale_factor):
            modules_body.append(ResidualUpSample(in_channels))
            in_channels = int(in_channels // stride)

        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        x = self.body(x)
        return x


##########################################################################
##---------- Multi-Scale Residual Block (MSRB) ----------
class MSRB(nn.Module):
    def __init__(self, n_feat, height, width, stride, bias):
        super(MSRB, self).__init__()

        self.n_feat, self.height, self.width = n_feat, height, width
        self.blocks = nn.ModuleList([nn.ModuleList([DAU(int(n_feat * stride ** i))] * width) for i in range(height)])

        INDEX = np.arange(0, width, 2)
        FEATS = [int((stride ** i) * n_feat) for i in range(height)]
        SCALE = [2 ** i for i in range(1, height)]

        self.last_up = nn.ModuleDict()
        for i in range(1, height):
            self.last_up.update({f'{i}': UpSample(int(n_feat * stride ** i), 2 ** i, stride)})

        self.down = nn.ModuleDict()
        self.up = nn.ModuleDict()

        i = 0
        SCALE.reverse()
        for feat in FEATS:
            for scale in SCALE[i:]:
                self.down.update({f'{feat}_{scale}': DownSample(feat, scale, stride)})
            i += 1

        i = 0
        FEATS.reverse()
        for feat in FEATS:
            for scale in SCALE[i:]:
                self.up.update({f'{feat}_{scale}': UpSample(feat, scale, stride)})
            i += 1

        self.conv_out = nn.Conv2d(n_feat, n_feat, kernel_size=3, padding=1, bias=bias)

        self.selective_kernel = nn.ModuleList([SKFF(n_feat * stride ** i, height) for i in range(height)])

    def forward(self, x):
        inp = x.clone()
        # col 1 only
        blocks_out = []
        for j in range(self.height):
            if j == 0:
                inp = self.blocks[j][0](inp)
            else:
                inp = self.blocks[j][0](self.down[f'{inp.size(1)}_{2}'](inp))
            blocks_out.append(inp)

        # rest of grid
        for i in range(1, self.width):
            # Mesh
            # Replace condition(i%2!=0) with True(Mesh) or False(Plain)
            # if i%2!=0:
            if True:
                tmp = []
                for j in range(self.height):
                    TENSOR = []
                    nfeats = (2 ** j) * self.n_feat
                    for k in range(self.height):
                        TENSOR.append(self.select_up_down(blocks_out[k], j, k))

                    selective_kernel_fusion = self.selective_kernel[j](TENSOR)
                    tmp.append(selective_kernel_fusion)
            # Plain
            else:
                tmp = blocks_out
            # Forward through either mesh or plain
            for j in range(self.height):
                blocks_out[j] = self.blocks[j][i](tmp[j])

        # Sum after grid
        out = []
        for k in range(self.height):
            out.append(self.select_last_up(blocks_out[k], k))

        out = self.selective_kernel[0](out)

        out = self.conv_out(out)
        out = out + x

        return out

    def select_up_down(self, tensor, j, k):
        if j == k:
            return tensor
        else:
            diff = 2 ** np.abs(j - k)
            if j < k:
                return self.up[f'{tensor.size(1)}_{diff}'](tensor)
            else:
                return self.down[f'{tensor.size(1)}_{diff}'](tensor)

    def select_last_up(self, tensor, k):
        if k == 0:
            return tensor
        else:
            return self.last_up[f'{k}'](tensor)


##########################################################################
##---------- Recursive Residual Group (RRG) ----------
class RRG(nn.Module):
    def __init__(self, n_feat, n_MSRB, height, width, stride, bias=False):
        super(RRG, self).__init__()
        modules_body = [MSRB(n_feat, height, width, stride, bias) for _ in range(n_MSRB)]
        modules_body.append(conv(n_feat, n_feat, kernel_size=3))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

##########################################################################
##---------- MIRNet -----------------------
class MIRNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=31, n_feat=31, kernel_size=3, stride=2, n_RRG=2, n_MSRB=1, height=3,
                 width=1, bias=False):
        super(MIRNet, self).__init__()

        self.conv_in = nn.Conv2d(in_channels, n_feat, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                                 bias=bias)

        modules_body = [RRG(n_feat, n_MSRB, height, width, stride, bias) for _ in range(n_RRG)]
        self.body = nn.Sequential(*modules_body)

        self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                                  bias=bias)

    def forward(self, x):
        b, c, h_inp, w_inp = x.shape
        hb, wb = 8, 8
        pad_h = (hb - h_inp % hb) % hb
        pad_w = (wb - w_inp % wb) % wb
        x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
        x = self.conv_in(x)
        h = self.body(x)
        h = self.conv_out(h)
        h += x
        return h[:, :, :h_inp, :w_inp]
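MIRNet.forward reflect-pads the input so height and width become multiples of 8 (so the repeated downsampling inside the multi-scale body divides evenly) and then crops the output back to the original size. A small worked example of the padding arithmetic (values only, not part of the file):

# Worked example of the pad computation in MIRNet.forward.
h_inp, w_inp, hb = 482, 512, 8
pad_h = (hb - h_inp % hb) % hb  # (8 - 482 % 8) % 8 = 6 -> padded height 488
pad_w = (hb - w_inp % hb) % hb  # 512 is already a multiple of 8 -> 0
print(pad_h, pad_w)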
test_challenge_code/architecture/MPRNet.py
ADDED
@@ -0,0 +1,350 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

##########################################################################
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias, stride=stride)


##########################################################################
## Channel Attention Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16, bias=False):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y


##########################################################################
## Channel Attention Block (CAB)
class CAB(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super(CAB, self).__init__()
        modules_body = []
        modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
        modules_body.append(act)
        modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))

        self.CA = CALayer(n_feat, reduction, bias=bias)
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res = self.CA(res)
        res += x
        return res

##########################################################################
## Supervised Attention Module
class SAM(nn.Module):
    def __init__(self, n_feat, kernel_size, bias):
        super(SAM, self).__init__()
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, 31, kernel_size, bias=bias)
        self.conv3 = conv(31, n_feat, kernel_size, bias=bias)

    def forward(self, x, x_img):
        x1 = self.conv1(x)
        img = self.conv2(x) + x_img
        x2 = torch.sigmoid(self.conv3(img))
        x1 = x1*x2
        x1 = x1+x
        return x1, img

##########################################################################
## U-Net

class Encoder(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff):
        super(Encoder, self).__init__()

        self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.encoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.encoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]

        self.encoder_level1 = nn.Sequential(*self.encoder_level1)
        self.encoder_level2 = nn.Sequential(*self.encoder_level2)
        self.encoder_level3 = nn.Sequential(*self.encoder_level3)

        self.down12 = DownSample(n_feat, scale_unetfeats)
        self.down23 = DownSample(n_feat+scale_unetfeats, scale_unetfeats)

        # Cross Stage Feature Fusion (CSFF)
        if csff:
            self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
            self.csff_enc2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias)
            self.csff_enc3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias)

            self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
            self.csff_dec2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias)
            self.csff_dec3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias)

    def forward(self, x, encoder_outs=None, decoder_outs=None):
        enc1 = self.encoder_level1(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0])

        x = self.down12(enc1)

        enc2 = self.encoder_level2(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1])

        x = self.down23(enc2)

        enc3 = self.encoder_level3(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2])

        return [enc1, enc2, enc3]

class Decoder(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats):
        super(Decoder, self).__init__()

        self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.decoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.decoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]

        self.decoder_level1 = nn.Sequential(*self.decoder_level1)
        self.decoder_level2 = nn.Sequential(*self.decoder_level2)
        self.decoder_level3 = nn.Sequential(*self.decoder_level3)

        self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act)
        self.skip_attn2 = CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act)

        self.up21 = SkipUpSample(n_feat, scale_unetfeats)
        self.up32 = SkipUpSample(n_feat+scale_unetfeats, scale_unetfeats)

    def forward(self, outs):
        enc1, enc2, enc3 = outs
        dec3 = self.decoder_level3(enc3)

        x = self.up32(dec3, self.skip_attn2(enc2))
        dec2 = self.decoder_level2(x)

        x = self.up21(dec2, self.skip_attn1(enc1))
        dec1 = self.decoder_level1(x)

        return [dec1, dec2, dec3]

##########################################################################
##---------- Resizing Modules ----------
class DownSample(nn.Module):
    def __init__(self, in_channels, s_factor):
        super(DownSample, self).__init__()
        self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
                                  nn.Conv2d(in_channels, in_channels+s_factor, 1, stride=1, padding=0, bias=False))

    def forward(self, x):
        x = self.down(x)
        return x

class UpSample(nn.Module):
    def __init__(self, in_channels, s_factor):
        super(UpSample, self).__init__()
        self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                                nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False))

    def forward(self, x):
        x = self.up(x)
        return x

class SkipUpSample(nn.Module):
    def __init__(self, in_channels, s_factor):
        super(SkipUpSample, self).__init__()
        self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                                nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False))

    def forward(self, x, y):
        x = self.up(x)
        x = x + y
        return x

##########################################################################
## Original Resolution Block (ORB)
class ORB(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab):
        super(ORB, self).__init__()
        modules_body = []
        modules_body = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(num_cab)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

##########################################################################
class ORSNet(nn.Module):
    def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab):
        super(ORSNet, self).__init__()

        self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
        self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
        self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)

        self.up_enc1 = UpSample(n_feat, scale_unetfeats)
        self.up_dec1 = UpSample(n_feat, scale_unetfeats)

        self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))
        self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))

        self.conv_enc1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_enc2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_enc3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)

        self.conv_dec1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_dec2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_dec3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)

    def forward(self, x, encoder_outs, decoder_outs):
        x = self.orb1(x)
        x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0])

        x = self.orb2(x)
        x = x + self.conv_enc2(self.up_enc1(encoder_outs[1])) + self.conv_dec2(self.up_dec1(decoder_outs[1]))

        x = self.orb3(x)
        x = x + self.conv_enc3(self.up_enc2(encoder_outs[2])) + self.conv_dec3(self.up_dec2(decoder_outs[2]))

        return x


##########################################################################
class MPRNet(nn.Module):
    def __init__(self, in_c=31, out_c=31, n_feat=31, scale_unetfeats=31, scale_orsnetfeats=31, num_cab=4, kernel_size=3, reduction=1, bias=False):
        super(MPRNet, self).__init__()

        self.conv_in = nn.Conv2d(3, in_c, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                                 bias=bias)

        act = nn.PReLU()
        self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
        self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
        self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat, kernel_size, reduction, bias=bias, act=act))

        # Cross Stage Feature Fusion (CSFF)
        self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False)
|
245 |
+
self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
|
246 |
+
|
247 |
+
self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
|
248 |
+
self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
|
249 |
+
|
250 |
+
self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab)
|
251 |
+
|
252 |
+
self.sam12 = SAM(n_feat, kernel_size=1, bias=bias)
|
253 |
+
self.sam23 = SAM(n_feat, kernel_size=1, bias=bias)
|
254 |
+
|
255 |
+
self.concat12 = conv(n_feat*2, n_feat, kernel_size, bias=bias)
|
256 |
+
self.concat23 = conv(n_feat*2, n_feat+scale_orsnetfeats, kernel_size, bias=bias)
|
257 |
+
self.tail = conv(n_feat+scale_orsnetfeats, out_c, kernel_size, bias=bias)
|
258 |
+
|
259 |
+
def forward(self, x3_img):
|
260 |
+
b, c, h_inp, w_inp = x3_img.shape
|
261 |
+
hb, wb = 8, 8
|
262 |
+
pad_h = (hb - h_inp % hb) % hb
|
263 |
+
pad_w = (wb - w_inp % wb) % wb
|
264 |
+
x3_img = F.pad(x3_img, [0, pad_w, 0, pad_h], mode='reflect')
|
265 |
+
x3_img = self.conv_in(x3_img)
|
266 |
+
|
267 |
+
# Original-resolution Image for Stage 3
|
268 |
+
H = x3_img.size(2)
|
269 |
+
W = x3_img.size(3)
|
270 |
+
|
271 |
+
# Multi-Patch Hierarchy: Split Image into four non-overlapping patches
|
272 |
+
|
273 |
+
# Two Patches for Stage 2
|
274 |
+
x2top_img = x3_img[:,:,0:int(H/2),:]
|
275 |
+
x2bot_img = x3_img[:,:,int(H/2):H,:]
|
276 |
+
|
277 |
+
# Four Patches for Stage 1
|
278 |
+
x1ltop_img = x2top_img[:,:,:,0:int(W/2)]
|
279 |
+
x1rtop_img = x2top_img[:,:,:,int(W/2):W]
|
280 |
+
x1lbot_img = x2bot_img[:,:,:,0:int(W/2)]
|
281 |
+
x1rbot_img = x2bot_img[:,:,:,int(W/2):W]
|
282 |
+
|
283 |
+
##-------------------------------------------
|
284 |
+
##-------------- Stage 1---------------------
|
285 |
+
##-------------------------------------------
|
286 |
+
## Compute Shallow Features
|
287 |
+
x1ltop = self.shallow_feat1(x1ltop_img)
|
288 |
+
x1rtop = self.shallow_feat1(x1rtop_img)
|
289 |
+
x1lbot = self.shallow_feat1(x1lbot_img)
|
290 |
+
x1rbot = self.shallow_feat1(x1rbot_img)
|
291 |
+
|
292 |
+
## Process features of all 4 patches with Encoder of Stage 1
|
293 |
+
feat1_ltop = self.stage1_encoder(x1ltop)
|
294 |
+
feat1_rtop = self.stage1_encoder(x1rtop)
|
295 |
+
feat1_lbot = self.stage1_encoder(x1lbot)
|
296 |
+
feat1_rbot = self.stage1_encoder(x1rbot)
|
297 |
+
|
298 |
+
## Concat deep features
|
299 |
+
feat1_top = [torch.cat((k,v), 3) for k,v in zip(feat1_ltop,feat1_rtop)]
|
300 |
+
feat1_bot = [torch.cat((k,v), 3) for k,v in zip(feat1_lbot,feat1_rbot)]
|
301 |
+
|
302 |
+
## Pass features through Decoder of Stage 1
|
303 |
+
res1_top = self.stage1_decoder(feat1_top)
|
304 |
+
res1_bot = self.stage1_decoder(feat1_bot)
|
305 |
+
|
306 |
+
## Apply Supervised Attention Module (SAM)
|
307 |
+
x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img)
|
308 |
+
x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img)
|
309 |
+
|
310 |
+
## Output image at Stage 1
|
311 |
+
stage1_img = torch.cat([stage1_img_top, stage1_img_bot],2)
|
312 |
+
##-------------------------------------------
|
313 |
+
##-------------- Stage 2---------------------
|
314 |
+
##-------------------------------------------
|
315 |
+
## Compute Shallow Features
|
316 |
+
x2top = self.shallow_feat2(x2top_img)
|
317 |
+
x2bot = self.shallow_feat2(x2bot_img)
|
318 |
+
|
319 |
+
## Concatenate SAM features of Stage 1 with shallow features of Stage 2
|
320 |
+
x2top_cat = self.concat12(torch.cat([x2top, x2top_samfeats], 1))
|
321 |
+
x2bot_cat = self.concat12(torch.cat([x2bot, x2bot_samfeats], 1))
|
322 |
+
|
323 |
+
## Process features of both patches with Encoder of Stage 2
|
324 |
+
feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top)
|
325 |
+
feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot)
|
326 |
+
|
327 |
+
## Concat deep features
|
328 |
+
feat2 = [torch.cat((k,v), 2) for k,v in zip(feat2_top,feat2_bot)]
|
329 |
+
|
330 |
+
## Pass features through Decoder of Stage 2
|
331 |
+
res2 = self.stage2_decoder(feat2)
|
332 |
+
|
333 |
+
## Apply SAM
|
334 |
+
x3_samfeats, stage2_img = self.sam23(res2[0], x3_img)
|
335 |
+
|
336 |
+
|
337 |
+
##-------------------------------------------
|
338 |
+
##-------------- Stage 3---------------------
|
339 |
+
##-------------------------------------------
|
340 |
+
## Compute Shallow Features
|
341 |
+
x3 = self.shallow_feat3(x3_img)
|
342 |
+
|
343 |
+
## Concatenate SAM features of Stage 2 with shallow features of Stage 3
|
344 |
+
x3_cat = self.concat23(torch.cat([x3, x3_samfeats], 1))
|
345 |
+
|
346 |
+
x3_cat = self.stage3_orsnet(x3_cat, feat2, res2)
|
347 |
+
|
348 |
+
stage3_img = self.tail(x3_cat)
|
349 |
+
|
350 |
+
return (stage3_img + x3_img)[:, :, :h_inp, :w_inp]
|
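Usage note (not part of the uploaded files): a minimal smoke test for the MPRNet variant above, assuming the repo root is on PYTHONPATH and running on CPU with a random RGB input; the spatial size is illustrative.

# Hypothetical smoke test; the input tensor and size are assumptions, not challenge data.
import torch
from test_challenge_code.architecture.MPRNet import MPRNet

model = MPRNet(num_cab=4).eval()      # same setting as model_generator('mprnet')
rgb = torch.rand(1, 3, 128, 128)      # dummy b,c,h,w RGB image
with torch.no_grad():
    hsi = model(rgb)                  # reconstructed 31-band cube
print(hsi.shape)                      # expected: torch.Size([1, 31, 128, 128])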
test_challenge_code/architecture/MST.py
ADDED
@@ -0,0 +1,313 @@
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
import math
|
6 |
+
import warnings
|
7 |
+
from torch.nn.init import _calculate_fan_in_and_fan_out
|
8 |
+
|
9 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
10 |
+
def norm_cdf(x):
|
11 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
12 |
+
|
13 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
14 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
15 |
+
"The distribution of values may be incorrect.",
|
16 |
+
stacklevel=2)
|
17 |
+
with torch.no_grad():
|
18 |
+
l = norm_cdf((a - mean) / std)
|
19 |
+
u = norm_cdf((b - mean) / std)
|
20 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
21 |
+
tensor.erfinv_()
|
22 |
+
tensor.mul_(std * math.sqrt(2.))
|
23 |
+
tensor.add_(mean)
|
24 |
+
tensor.clamp_(min=a, max=b)
|
25 |
+
return tensor
|
26 |
+
|
27 |
+
|
28 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
29 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
30 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
31 |
+
|
32 |
+
|
33 |
+
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
|
34 |
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
35 |
+
if mode == 'fan_in':
|
36 |
+
denom = fan_in
|
37 |
+
elif mode == 'fan_out':
|
38 |
+
denom = fan_out
|
39 |
+
elif mode == 'fan_avg':
|
40 |
+
denom = (fan_in + fan_out) / 2
|
41 |
+
variance = scale / denom
|
42 |
+
if distribution == "truncated_normal":
|
43 |
+
trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
|
44 |
+
elif distribution == "normal":
|
45 |
+
tensor.normal_(std=math.sqrt(variance))
|
46 |
+
elif distribution == "uniform":
|
47 |
+
bound = math.sqrt(3 * variance)
|
48 |
+
tensor.uniform_(-bound, bound)
|
49 |
+
else:
|
50 |
+
raise ValueError(f"invalid distribution {distribution}")
|
51 |
+
|
52 |
+
|
53 |
+
def lecun_normal_(tensor):
|
54 |
+
variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
|
55 |
+
|
56 |
+
|
57 |
+
class PreNorm(nn.Module):
|
58 |
+
def __init__(self, dim, fn):
|
59 |
+
super().__init__()
|
60 |
+
self.fn = fn
|
61 |
+
self.norm = nn.LayerNorm(dim)
|
62 |
+
|
63 |
+
def forward(self, x, *args, **kwargs):
|
64 |
+
x = self.norm(x)
|
65 |
+
return self.fn(x, *args, **kwargs)
|
66 |
+
|
67 |
+
|
68 |
+
class GELU(nn.Module):
|
69 |
+
def forward(self, x):
|
70 |
+
return F.gelu(x)
|
71 |
+
|
72 |
+
def conv(in_channels, out_channels, kernel_size, bias=False, padding = 1, stride = 1):
|
73 |
+
return nn.Conv2d(
|
74 |
+
in_channels, out_channels, kernel_size,
|
75 |
+
padding=(kernel_size//2), bias=bias, stride=stride)
|
76 |
+
|
77 |
+
|
78 |
+
def shift_back(inputs,step=2): # input [bs,28,256,310] output [bs, 28, 256, 256]
|
79 |
+
[bs, nC, row, col] = inputs.shape
|
80 |
+
down_sample = 256//row
|
81 |
+
step = float(step)/float(down_sample*down_sample)
|
82 |
+
out_col = row
|
83 |
+
for i in range(nC):
|
84 |
+
inputs[:,i,:,:out_col] = \
|
85 |
+
inputs[:,i,:,int(step*i):int(step*i)+out_col]
|
86 |
+
return inputs[:, :, :, :out_col]
|
87 |
+
|
88 |
+
class MaskGuidedMechanism(nn.Module):
|
89 |
+
def __init__(
|
90 |
+
self, n_feat):
|
91 |
+
super(MaskGuidedMechanism, self).__init__()
|
92 |
+
|
93 |
+
self.conv1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=True)
|
94 |
+
self.conv2 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=True)
|
95 |
+
self.depth_conv = nn.Conv2d(n_feat, n_feat, kernel_size=5, padding=2, bias=True, groups=n_feat)
|
96 |
+
|
97 |
+
def forward(self, mask_shift):
|
98 |
+
# x: b,c,h,w
|
99 |
+
[bs, nC, row, col] = mask_shift.shape
|
100 |
+
mask_shift = self.conv1(mask_shift)
|
101 |
+
attn_map = torch.sigmoid(self.depth_conv(self.conv2(mask_shift)))
|
102 |
+
res = mask_shift * attn_map
|
103 |
+
mask_emb = res + mask_shift
|
104 |
+
return mask_emb
|
105 |
+
|
106 |
+
class MS_MSA(nn.Module):
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
dim,
|
110 |
+
dim_head,
|
111 |
+
heads,
|
112 |
+
):
|
113 |
+
super().__init__()
|
114 |
+
self.num_heads = heads
|
115 |
+
self.dim_head = dim_head
|
116 |
+
self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
|
117 |
+
self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
|
118 |
+
self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
|
119 |
+
self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
|
120 |
+
self.proj = nn.Linear(dim_head * heads, dim, bias=True)
|
121 |
+
self.pos_emb = nn.Sequential(
|
122 |
+
nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
|
123 |
+
GELU(),
|
124 |
+
nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
|
125 |
+
)
|
126 |
+
self.mm = MaskGuidedMechanism(dim)
|
127 |
+
self.dim = dim
|
128 |
+
|
129 |
+
def forward(self, x_in, mask=None):
|
130 |
+
"""
|
131 |
+
x_in: [b,h,w,c]
|
132 |
+
mask: [1,h,w,c]
|
133 |
+
return out: [b,h,w,c]
|
134 |
+
"""
|
135 |
+
b, h, w, c = x_in.shape
|
136 |
+
x = x_in.reshape(b,h*w,c)
|
137 |
+
q_inp = self.to_q(x)
|
138 |
+
k_inp = self.to_k(x)
|
139 |
+
v_inp = self.to_v(x)
|
140 |
+
mask_attn = self.mm(mask.permute(0,3,1,2)).permute(0,2,3,1)
|
141 |
+
if b != 0:
|
142 |
+
mask_attn = (mask_attn[0, :, :, :]).expand([b, h, w, c])
|
143 |
+
q, k, v, mask_attn = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
|
144 |
+
(q_inp, k_inp, v_inp, mask_attn.flatten(1, 2)))
|
145 |
+
v = v * mask_attn
|
146 |
+
# q: b,heads,hw,c
|
147 |
+
q = q.transpose(-2, -1)
|
148 |
+
k = k.transpose(-2, -1)
|
149 |
+
v = v.transpose(-2, -1)
|
150 |
+
q = F.normalize(q, dim=-1, p=2)
|
151 |
+
k = F.normalize(k, dim=-1, p=2)
|
152 |
+
attn = (k @ q.transpose(-2, -1)) # A = K^T*Q
|
153 |
+
attn = attn * self.rescale
|
154 |
+
attn = attn.softmax(dim=-1)
|
155 |
+
x = attn @ v # b,heads,d,hw
|
156 |
+
x = x.permute(0, 3, 1, 2) # Transpose
|
157 |
+
x = x.reshape(b, h * w, self.num_heads * self.dim_head)
|
158 |
+
out_c = self.proj(x).view(b, h, w, c)
|
159 |
+
out_p = self.pos_emb(v_inp.reshape(b,h,w,c).permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
|
160 |
+
out = out_c + out_p
|
161 |
+
|
162 |
+
return out
|
163 |
+
|
164 |
+
class FeedForward(nn.Module):
|
165 |
+
def __init__(self, dim, mult=4):
|
166 |
+
super().__init__()
|
167 |
+
self.net = nn.Sequential(
|
168 |
+
nn.Conv2d(dim, dim * mult, 1, 1, bias=False),
|
169 |
+
GELU(),
|
170 |
+
nn.Conv2d(dim * mult, dim * mult, 3, 1, 1, bias=False, groups=dim * mult),
|
171 |
+
GELU(),
|
172 |
+
nn.Conv2d(dim * mult, dim, 1, 1, bias=False),
|
173 |
+
)
|
174 |
+
|
175 |
+
def forward(self, x):
|
176 |
+
"""
|
177 |
+
x: [b,h,w,c]
|
178 |
+
return out: [b,h,w,c]
|
179 |
+
"""
|
180 |
+
out = self.net(x.permute(0, 3, 1, 2))
|
181 |
+
return out.permute(0, 2, 3, 1)
|
182 |
+
|
183 |
+
class MSAB(nn.Module):
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
dim,
|
187 |
+
dim_head,
|
188 |
+
heads,
|
189 |
+
num_blocks,
|
190 |
+
):
|
191 |
+
super().__init__()
|
192 |
+
self.blocks = nn.ModuleList([])
|
193 |
+
for _ in range(num_blocks):
|
194 |
+
self.blocks.append(nn.ModuleList([
|
195 |
+
MS_MSA(dim=dim, dim_head=dim_head, heads=heads),
|
196 |
+
PreNorm(dim, FeedForward(dim=dim))
|
197 |
+
]))
|
198 |
+
|
199 |
+
def forward(self, x, mask):
|
200 |
+
"""
|
201 |
+
x: [b,c,h,w]
|
202 |
+
return out: [b,c,h,w]
|
203 |
+
"""
|
204 |
+
x = x.permute(0, 2, 3, 1)
|
205 |
+
for (attn, ff) in self.blocks:
|
206 |
+
x = attn(x, mask=mask.permute(0, 2, 3, 1)) + x
|
207 |
+
x = ff(x) + x
|
208 |
+
out = x.permute(0, 3, 1, 2)
|
209 |
+
return out
|
210 |
+
|
211 |
+
class MST(nn.Module):
|
212 |
+
def __init__(self, dim, stage, num_blocks):
|
213 |
+
super(MST, self).__init__()
|
214 |
+
self.dim = dim
|
215 |
+
self.stage = stage
|
216 |
+
|
217 |
+
# Input projection
|
218 |
+
self.embedding_1 = nn.Conv2d(3, self.dim, 3, 1, 1, bias=False)
|
219 |
+
self.embedding_2 = nn.Conv2d(3, self.dim, 3, 1, 1, bias=False)
|
220 |
+
|
221 |
+
# Encoder
|
222 |
+
self.encoder_layers = nn.ModuleList([])
|
223 |
+
dim_stage = dim
|
224 |
+
for i in range(stage):
|
225 |
+
self.encoder_layers.append(nn.ModuleList([
|
226 |
+
MSAB(
|
227 |
+
dim=dim_stage, num_blocks=num_blocks[i], dim_head=dim, heads=dim_stage // dim),
|
228 |
+
nn.Conv2d(dim_stage, dim_stage * 2, 4, 2, 1, bias=False),
|
229 |
+
nn.Conv2d(dim_stage, dim_stage * 2, 4, 2, 1, bias=False)
|
230 |
+
]))
|
231 |
+
dim_stage *= 2
|
232 |
+
|
233 |
+
# Bottleneck
|
234 |
+
self.bottleneck = MSAB(
|
235 |
+
dim=dim_stage, dim_head=dim, heads=dim_stage // dim, num_blocks=num_blocks[-1])
|
236 |
+
|
237 |
+
# Decoder
|
238 |
+
self.decoder_layers = nn.ModuleList([])
|
239 |
+
for i in range(stage):
|
240 |
+
self.decoder_layers.append(nn.ModuleList([
|
241 |
+
nn.ConvTranspose2d(dim_stage, dim_stage // 2, stride=2, kernel_size=2, padding=0, output_padding=0),
|
242 |
+
nn.Conv2d(dim_stage, dim_stage // 2, 1, 1, bias=False),
|
243 |
+
MSAB(
|
244 |
+
dim=dim_stage // 2, num_blocks=num_blocks[stage - 1 - i], dim_head=dim,
|
245 |
+
heads=(dim_stage // 2) // dim),
|
246 |
+
]))
|
247 |
+
dim_stage //= 2
|
248 |
+
|
249 |
+
# Output projection
|
250 |
+
self.mapping = nn.Conv2d(self.dim, 31, 3, 1, 1, bias=False)
|
251 |
+
|
252 |
+
#### activation function
|
253 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
254 |
+
|
255 |
+
def forward(self, x):
|
256 |
+
"""
|
257 |
+
x: [b,c,h,w]
|
258 |
+
return out:[b,c,h,w]
|
259 |
+
"""
|
260 |
+
b, c, h_inp, w_inp = x.shape
|
261 |
+
hb, wb = 8, 8
|
262 |
+
pad_h = (hb - h_inp % hb) % hb
|
263 |
+
pad_w = (wb - w_inp % wb) % wb
|
264 |
+
x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
|
265 |
+
|
266 |
+
# Embedding
|
267 |
+
mask = self.lrelu(self.embedding_1(x))
|
268 |
+
x = self.lrelu(self.embedding_2(x))
|
269 |
+
fea = x
|
270 |
+
|
271 |
+
# Encoder
|
272 |
+
fea_encoder = []
|
273 |
+
masks = []
|
274 |
+
for (MSAB, FeaDownSample, MaskDownSample) in self.encoder_layers:
|
275 |
+
fea = MSAB(fea, mask)
|
276 |
+
masks.append(mask)
|
277 |
+
fea_encoder.append(fea)
|
278 |
+
fea = FeaDownSample(fea)
|
279 |
+
mask = MaskDownSample(mask)
|
280 |
+
|
281 |
+
# Bottleneck
|
282 |
+
fea = self.bottleneck(fea, mask)
|
283 |
+
|
284 |
+
# Decoder
|
285 |
+
for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers):
|
286 |
+
fea = FeaUpSample(fea)
|
287 |
+
fea = Fution(torch.cat([fea, fea_encoder[self.stage-1-i]], dim=1))
|
288 |
+
mask = masks[self.stage - 1 - i]
|
289 |
+
fea = LeWinBlcok(fea, mask)
|
290 |
+
|
291 |
+
# Mapping
|
292 |
+
out = self.mapping(fea) + x
|
293 |
+
|
294 |
+
return out[:, :, :h_inp, :w_inp]
|
295 |
+
|
296 |
+
|
297 |
+
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
|
302 |
+
|
303 |
+
|
304 |
+
|
305 |
+
|
306 |
+
|
307 |
+
|
308 |
+
|
309 |
+
|
310 |
+
|
311 |
+
|
312 |
+
|
313 |
+
|
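Usage note (not part of the uploaded files): a minimal sketch for the mask-guided MST defined above, assuming the configuration used by model_generator('mst') and a random CPU input.

# Hypothetical smoke test; shapes are illustrative assumptions.
import torch
from test_challenge_code.architecture.MST import MST

model = MST(dim=31, stage=2, num_blocks=[4, 7, 5]).eval()
rgb = torch.rand(1, 3, 128, 128)      # forward pads H and W to multiples of 8
with torch.no_grad():
    out = model(rgb)
print(out.shape)                      # expected: torch.Size([1, 31, 128, 128])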
test_challenge_code/architecture/MST_Plus_Plus.py
ADDED
@@ -0,0 +1,307 @@
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
import math
|
6 |
+
import warnings
|
7 |
+
from torch.nn.init import _calculate_fan_in_and_fan_out
|
8 |
+
|
9 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
10 |
+
def norm_cdf(x):
|
11 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
12 |
+
|
13 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
14 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
15 |
+
"The distribution of values may be incorrect.",
|
16 |
+
stacklevel=2)
|
17 |
+
with torch.no_grad():
|
18 |
+
l = norm_cdf((a - mean) / std)
|
19 |
+
u = norm_cdf((b - mean) / std)
|
20 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
21 |
+
tensor.erfinv_()
|
22 |
+
tensor.mul_(std * math.sqrt(2.))
|
23 |
+
tensor.add_(mean)
|
24 |
+
tensor.clamp_(min=a, max=b)
|
25 |
+
return tensor
|
26 |
+
|
27 |
+
|
28 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
29 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
30 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
31 |
+
|
32 |
+
|
33 |
+
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
|
34 |
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
35 |
+
if mode == 'fan_in':
|
36 |
+
denom = fan_in
|
37 |
+
elif mode == 'fan_out':
|
38 |
+
denom = fan_out
|
39 |
+
elif mode == 'fan_avg':
|
40 |
+
denom = (fan_in + fan_out) / 2
|
41 |
+
variance = scale / denom
|
42 |
+
if distribution == "truncated_normal":
|
43 |
+
trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
|
44 |
+
elif distribution == "normal":
|
45 |
+
tensor.normal_(std=math.sqrt(variance))
|
46 |
+
elif distribution == "uniform":
|
47 |
+
bound = math.sqrt(3 * variance)
|
48 |
+
tensor.uniform_(-bound, bound)
|
49 |
+
else:
|
50 |
+
raise ValueError(f"invalid distribution {distribution}")
|
51 |
+
|
52 |
+
|
53 |
+
def lecun_normal_(tensor):
|
54 |
+
variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
|
55 |
+
|
56 |
+
|
57 |
+
class PreNorm(nn.Module):
|
58 |
+
def __init__(self, dim, fn):
|
59 |
+
super().__init__()
|
60 |
+
self.fn = fn
|
61 |
+
self.norm = nn.LayerNorm(dim)
|
62 |
+
|
63 |
+
def forward(self, x, *args, **kwargs):
|
64 |
+
x = self.norm(x)
|
65 |
+
return self.fn(x, *args, **kwargs)
|
66 |
+
|
67 |
+
|
68 |
+
class GELU(nn.Module):
|
69 |
+
def forward(self, x):
|
70 |
+
return F.gelu(x)
|
71 |
+
|
72 |
+
def conv(in_channels, out_channels, kernel_size, bias=False, padding = 1, stride = 1):
|
73 |
+
return nn.Conv2d(
|
74 |
+
in_channels, out_channels, kernel_size,
|
75 |
+
padding=(kernel_size//2), bias=bias, stride=stride)
|
76 |
+
|
77 |
+
|
78 |
+
def shift_back(inputs,step=2): # input [bs,28,256,310] output [bs, 28, 256, 256]
|
79 |
+
[bs, nC, row, col] = inputs.shape
|
80 |
+
down_sample = 256//row
|
81 |
+
step = float(step)/float(down_sample*down_sample)
|
82 |
+
out_col = row
|
83 |
+
for i in range(nC):
|
84 |
+
inputs[:,i,:,:out_col] = \
|
85 |
+
inputs[:,i,:,int(step*i):int(step*i)+out_col]
|
86 |
+
return inputs[:, :, :, :out_col]
|
87 |
+
|
88 |
+
class MS_MSA(nn.Module):
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
dim,
|
92 |
+
dim_head,
|
93 |
+
heads,
|
94 |
+
):
|
95 |
+
super().__init__()
|
96 |
+
self.num_heads = heads
|
97 |
+
self.dim_head = dim_head
|
98 |
+
self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
|
99 |
+
self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
|
100 |
+
self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
|
101 |
+
self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
|
102 |
+
self.proj = nn.Linear(dim_head * heads, dim, bias=True)
|
103 |
+
self.pos_emb = nn.Sequential(
|
104 |
+
nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
|
105 |
+
GELU(),
|
106 |
+
nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
|
107 |
+
)
|
108 |
+
self.dim = dim
|
109 |
+
|
110 |
+
def forward(self, x_in):
|
111 |
+
"""
|
112 |
+
x_in: [b,h,w,c]
|
113 |
+
return out: [b,h,w,c]
|
114 |
+
"""
|
115 |
+
b, h, w, c = x_in.shape
|
116 |
+
x = x_in.reshape(b,h*w,c)
|
117 |
+
q_inp = self.to_q(x)
|
118 |
+
k_inp = self.to_k(x)
|
119 |
+
v_inp = self.to_v(x)
|
120 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
|
121 |
+
(q_inp, k_inp, v_inp))
|
122 |
+
v = v
|
123 |
+
# q: b,heads,hw,c
|
124 |
+
q = q.transpose(-2, -1)
|
125 |
+
k = k.transpose(-2, -1)
|
126 |
+
v = v.transpose(-2, -1)
|
127 |
+
q = F.normalize(q, dim=-1, p=2)
|
128 |
+
k = F.normalize(k, dim=-1, p=2)
|
129 |
+
attn = (k @ q.transpose(-2, -1)) # A = K^T*Q
|
130 |
+
attn = attn * self.rescale
|
131 |
+
attn = attn.softmax(dim=-1)
|
132 |
+
x = attn @ v # b,heads,d,hw
|
133 |
+
x = x.permute(0, 3, 1, 2) # Transpose
|
134 |
+
x = x.reshape(b, h * w, self.num_heads * self.dim_head)
|
135 |
+
out_c = self.proj(x).view(b, h, w, c)
|
136 |
+
out_p = self.pos_emb(v_inp.reshape(b,h,w,c).permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
|
137 |
+
out = out_c + out_p
|
138 |
+
|
139 |
+
return out
|
140 |
+
|
141 |
+
class FeedForward(nn.Module):
|
142 |
+
def __init__(self, dim, mult=4):
|
143 |
+
super().__init__()
|
144 |
+
self.net = nn.Sequential(
|
145 |
+
nn.Conv2d(dim, dim * mult, 1, 1, bias=False),
|
146 |
+
GELU(),
|
147 |
+
nn.Conv2d(dim * mult, dim * mult, 3, 1, 1, bias=False, groups=dim * mult),
|
148 |
+
GELU(),
|
149 |
+
nn.Conv2d(dim * mult, dim, 1, 1, bias=False),
|
150 |
+
)
|
151 |
+
|
152 |
+
def forward(self, x):
|
153 |
+
"""
|
154 |
+
x: [b,h,w,c]
|
155 |
+
return out: [b,h,w,c]
|
156 |
+
"""
|
157 |
+
out = self.net(x.permute(0, 3, 1, 2))
|
158 |
+
return out.permute(0, 2, 3, 1)
|
159 |
+
|
160 |
+
class MSAB(nn.Module):
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
dim,
|
164 |
+
dim_head,
|
165 |
+
heads,
|
166 |
+
num_blocks,
|
167 |
+
):
|
168 |
+
super().__init__()
|
169 |
+
self.blocks = nn.ModuleList([])
|
170 |
+
for _ in range(num_blocks):
|
171 |
+
self.blocks.append(nn.ModuleList([
|
172 |
+
MS_MSA(dim=dim, dim_head=dim_head, heads=heads),
|
173 |
+
PreNorm(dim, FeedForward(dim=dim))
|
174 |
+
]))
|
175 |
+
|
176 |
+
def forward(self, x):
|
177 |
+
"""
|
178 |
+
x: [b,c,h,w]
|
179 |
+
return out: [b,c,h,w]
|
180 |
+
"""
|
181 |
+
x = x.permute(0, 2, 3, 1)
|
182 |
+
for (attn, ff) in self.blocks:
|
183 |
+
x = attn(x) + x
|
184 |
+
x = ff(x) + x
|
185 |
+
out = x.permute(0, 3, 1, 2)
|
186 |
+
return out
|
187 |
+
|
188 |
+
class MST(nn.Module):
|
189 |
+
def __init__(self, in_dim=31, out_dim=31, dim=31, stage=2, num_blocks=[2,4,4]):
|
190 |
+
super(MST, self).__init__()
|
191 |
+
self.dim = dim
|
192 |
+
self.stage = stage
|
193 |
+
|
194 |
+
# Input projection
|
195 |
+
self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)
|
196 |
+
|
197 |
+
# Encoder
|
198 |
+
self.encoder_layers = nn.ModuleList([])
|
199 |
+
dim_stage = dim
|
200 |
+
for i in range(stage):
|
201 |
+
self.encoder_layers.append(nn.ModuleList([
|
202 |
+
MSAB(
|
203 |
+
dim=dim_stage, num_blocks=num_blocks[i], dim_head=dim, heads=dim_stage // dim),
|
204 |
+
nn.Conv2d(dim_stage, dim_stage * 2, 4, 2, 1, bias=False),
|
205 |
+
]))
|
206 |
+
dim_stage *= 2
|
207 |
+
|
208 |
+
# Bottleneck
|
209 |
+
self.bottleneck = MSAB(
|
210 |
+
dim=dim_stage, dim_head=dim, heads=dim_stage // dim, num_blocks=num_blocks[-1])
|
211 |
+
|
212 |
+
# Decoder
|
213 |
+
self.decoder_layers = nn.ModuleList([])
|
214 |
+
for i in range(stage):
|
215 |
+
self.decoder_layers.append(nn.ModuleList([
|
216 |
+
nn.ConvTranspose2d(dim_stage, dim_stage // 2, stride=2, kernel_size=2, padding=0, output_padding=0),
|
217 |
+
nn.Conv2d(dim_stage, dim_stage // 2, 1, 1, bias=False),
|
218 |
+
MSAB(
|
219 |
+
dim=dim_stage // 2, num_blocks=num_blocks[stage - 1 - i], dim_head=dim,
|
220 |
+
heads=(dim_stage // 2) // dim),
|
221 |
+
]))
|
222 |
+
dim_stage //= 2
|
223 |
+
|
224 |
+
# Output projection
|
225 |
+
self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)
|
226 |
+
|
227 |
+
#### activation function
|
228 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
229 |
+
self.apply(self._init_weights)
|
230 |
+
|
231 |
+
def _init_weights(self, m):
|
232 |
+
if isinstance(m, nn.Linear):
|
233 |
+
trunc_normal_(m.weight, std=.02)
|
234 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
235 |
+
nn.init.constant_(m.bias, 0)
|
236 |
+
elif isinstance(m, nn.LayerNorm):
|
237 |
+
nn.init.constant_(m.bias, 0)
|
238 |
+
nn.init.constant_(m.weight, 1.0)
|
239 |
+
|
240 |
+
def forward(self, x):
|
241 |
+
"""
|
242 |
+
x: [b,c,h,w]
|
243 |
+
return out:[b,c,h,w]
|
244 |
+
"""
|
245 |
+
|
246 |
+
# Embedding
|
247 |
+
fea = self.embedding(x)
|
248 |
+
|
249 |
+
# Encoder
|
250 |
+
fea_encoder = []
|
251 |
+
for (MSAB, FeaDownSample) in self.encoder_layers:
|
252 |
+
fea = MSAB(fea)
|
253 |
+
fea_encoder.append(fea)
|
254 |
+
fea = FeaDownSample(fea)
|
255 |
+
|
256 |
+
# Bottleneck
|
257 |
+
fea = self.bottleneck(fea)
|
258 |
+
|
259 |
+
# Decoder
|
260 |
+
for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers):
|
261 |
+
fea = FeaUpSample(fea)
|
262 |
+
fea = Fution(torch.cat([fea, fea_encoder[self.stage-1-i]], dim=1))
|
263 |
+
fea = LeWinBlcok(fea)
|
264 |
+
|
265 |
+
# Mapping
|
266 |
+
out = self.mapping(fea) + x
|
267 |
+
|
268 |
+
return out
|
269 |
+
|
270 |
+
class MST_Plus_Plus(nn.Module):
|
271 |
+
def __init__(self, in_channels=3, out_channels=31, n_feat=31, stage=3):
|
272 |
+
super(MST_Plus_Plus, self).__init__()
|
273 |
+
self.stage = stage
|
274 |
+
self.conv_in = nn.Conv2d(in_channels, n_feat, kernel_size=3, padding=(3 - 1) // 2,bias=False)
|
275 |
+
modules_body = [MST(dim=31, stage=2, num_blocks=[1,1,1]) for _ in range(stage)]
|
276 |
+
self.body = nn.Sequential(*modules_body)
|
277 |
+
self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=3, padding=(3 - 1) // 2,bias=False)
|
278 |
+
|
279 |
+
def forward(self, x):
|
280 |
+
"""
|
281 |
+
x: [b,c,h,w]
|
282 |
+
return out:[b,c,h,w]
|
283 |
+
"""
|
284 |
+
b, c, h_inp, w_inp = x.shape
|
285 |
+
hb, wb = 8, 8
|
286 |
+
pad_h = (hb - h_inp % hb) % hb
|
287 |
+
pad_w = (wb - w_inp % wb) % wb
|
288 |
+
x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
|
289 |
+
x = self.conv_in(x)
|
290 |
+
h = self.body(x)
|
291 |
+
h = self.conv_out(h)
|
292 |
+
h += x
|
293 |
+
return h[:, :, :h_inp, :w_inp]
|
294 |
+
|
295 |
+
|
296 |
+
|
297 |
+
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
|
302 |
+
|
303 |
+
|
304 |
+
|
305 |
+
|
306 |
+
|
307 |
+
|
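Usage note (not part of the uploaded files): MST_Plus_Plus is the architecture loaded by app.py ("mst_plus_plus"); the sketch below assumes a random CPU input of arbitrary size and only checks that the 31-band output keeps the input resolution.

# Hypothetical smoke test; the 482x512 size is an assumption, not taken from the demo assets.
import torch
from test_challenge_code.architecture.MST_Plus_Plus import MST_Plus_Plus

model = MST_Plus_Plus(in_channels=3, out_channels=31, n_feat=31, stage=3).eval()
rgb = torch.rand(1, 3, 482, 512)      # forward pads H and W to multiples of 8 internally
with torch.no_grad():
    hsi = model(rgb)
print(hsi.shape)                      # expected: torch.Size([1, 31, 482, 512])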
test_challenge_code/architecture/Restormer.py
ADDED
@@ -0,0 +1,320 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numbers
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
|
8 |
+
##########################################################################
|
9 |
+
## Layer Norm
|
10 |
+
|
11 |
+
def to_3d(x):
|
12 |
+
return rearrange(x, 'b c h w -> b (h w) c')
|
13 |
+
|
14 |
+
|
15 |
+
def to_4d(x, h, w):
|
16 |
+
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
17 |
+
|
18 |
+
|
19 |
+
class BiasFree_LayerNorm(nn.Module):
|
20 |
+
def __init__(self, normalized_shape):
|
21 |
+
super(BiasFree_LayerNorm, self).__init__()
|
22 |
+
if isinstance(normalized_shape, numbers.Integral):
|
23 |
+
normalized_shape = (normalized_shape,)
|
24 |
+
normalized_shape = torch.Size(normalized_shape)
|
25 |
+
|
26 |
+
assert len(normalized_shape) == 1
|
27 |
+
|
28 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
29 |
+
self.normalized_shape = normalized_shape
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
sigma = x.var(-1, keepdim=True, unbiased=False)
|
33 |
+
return x / torch.sqrt(sigma + 1e-5) * self.weight
|
34 |
+
|
35 |
+
|
36 |
+
class WithBias_LayerNorm(nn.Module):
|
37 |
+
def __init__(self, normalized_shape):
|
38 |
+
super(WithBias_LayerNorm, self).__init__()
|
39 |
+
if isinstance(normalized_shape, numbers.Integral):
|
40 |
+
normalized_shape = (normalized_shape,)
|
41 |
+
normalized_shape = torch.Size(normalized_shape)
|
42 |
+
|
43 |
+
assert len(normalized_shape) == 1
|
44 |
+
|
45 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
46 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
47 |
+
self.normalized_shape = normalized_shape
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
mu = x.mean(-1, keepdim=True)
|
51 |
+
sigma = x.var(-1, keepdim=True, unbiased=False)
|
52 |
+
return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
|
53 |
+
|
54 |
+
|
55 |
+
class LayerNorm(nn.Module):
|
56 |
+
def __init__(self, dim, LayerNorm_type):
|
57 |
+
super(LayerNorm, self).__init__()
|
58 |
+
if LayerNorm_type == 'BiasFree':
|
59 |
+
self.body = BiasFree_LayerNorm(dim)
|
60 |
+
else:
|
61 |
+
self.body = WithBias_LayerNorm(dim)
|
62 |
+
|
63 |
+
def forward(self, x):
|
64 |
+
h, w = x.shape[-2:]
|
65 |
+
return to_4d(self.body(to_3d(x)), h, w)
|
66 |
+
|
67 |
+
|
68 |
+
##########################################################################
|
69 |
+
## Gated-Dconv Feed-Forward Network (GDFN)
|
70 |
+
class FeedForward(nn.Module):
|
71 |
+
def __init__(self, dim, ffn_expansion_factor, bias):
|
72 |
+
super(FeedForward, self).__init__()
|
73 |
+
|
74 |
+
hidden_features = int(dim * ffn_expansion_factor)
|
75 |
+
|
76 |
+
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
|
77 |
+
|
78 |
+
self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
|
79 |
+
groups=hidden_features * 2, bias=bias)
|
80 |
+
|
81 |
+
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
x = self.project_in(x)
|
85 |
+
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
86 |
+
x = F.gelu(x1) * x2
|
87 |
+
x = self.project_out(x)
|
88 |
+
return x
|
89 |
+
|
90 |
+
|
91 |
+
##########################################################################
|
92 |
+
## Multi-DConv Head Transposed Self-Attention (MDTA)
|
93 |
+
class Attention(nn.Module):
|
94 |
+
def __init__(self, dim, num_heads, bias):
|
95 |
+
super(Attention, self).__init__()
|
96 |
+
self.num_heads = num_heads
|
97 |
+
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
|
98 |
+
|
99 |
+
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
100 |
+
self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
|
101 |
+
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
b, c, h, w = x.shape
|
105 |
+
|
106 |
+
qkv = self.qkv_dwconv(self.qkv(x))
|
107 |
+
q, k, v = qkv.chunk(3, dim=1)
|
108 |
+
|
109 |
+
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
110 |
+
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
111 |
+
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
112 |
+
|
113 |
+
q = torch.nn.functional.normalize(q, dim=-1)
|
114 |
+
k = torch.nn.functional.normalize(k, dim=-1)
|
115 |
+
|
116 |
+
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
117 |
+
attn = attn.softmax(dim=-1)
|
118 |
+
|
119 |
+
out = (attn @ v)
|
120 |
+
|
121 |
+
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
|
122 |
+
|
123 |
+
out = self.project_out(out)
|
124 |
+
return out
|
125 |
+
|
126 |
+
|
127 |
+
##########################################################################
|
128 |
+
class TransformerBlock(nn.Module):
|
129 |
+
def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
|
130 |
+
super(TransformerBlock, self).__init__()
|
131 |
+
|
132 |
+
self.norm1 = LayerNorm(dim, LayerNorm_type)
|
133 |
+
self.attn = Attention(dim, num_heads, bias)
|
134 |
+
self.norm2 = LayerNorm(dim, LayerNorm_type)
|
135 |
+
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
|
136 |
+
|
137 |
+
def forward(self, x):
|
138 |
+
x = x + self.attn(self.norm1(x))
|
139 |
+
x = x + self.ffn(self.norm2(x))
|
140 |
+
|
141 |
+
return x
|
142 |
+
|
143 |
+
|
144 |
+
##########################################################################
|
145 |
+
## Overlapped image patch embedding with 3x3 Conv
|
146 |
+
class OverlapPatchEmbed(nn.Module):
|
147 |
+
def __init__(self, in_c=3, embed_dim=48, bias=False):
|
148 |
+
super(OverlapPatchEmbed, self).__init__()
|
149 |
+
|
150 |
+
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
|
151 |
+
|
152 |
+
def forward(self, x):
|
153 |
+
x = self.proj(x)
|
154 |
+
|
155 |
+
return x
|
156 |
+
|
157 |
+
def pixel_unshuffle(input, downscale_factor):
|
158 |
+
'''
|
159 |
+
input: batchSize * c * k*w * k*h
|
160 |
+
downscale_factor: k
|
161 |
+
batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h
|
162 |
+
'''
|
163 |
+
c = input.shape[1]
|
164 |
+
kernel = torch.zeros(size = [downscale_factor * downscale_factor * c, 1, downscale_factor, downscale_factor],
|
165 |
+
device = input.device)
|
166 |
+
for y in range(downscale_factor):
|
167 |
+
for x in range(downscale_factor):
|
168 |
+
kernel[x + y * downscale_factor::downscale_factor * downscale_factor, 0, y, x] = 1
|
169 |
+
return F.conv2d(input, kernel, stride = downscale_factor, groups = c)
|
170 |
+
|
171 |
+
class PixelUnShuffle(nn.Module):
|
172 |
+
def __init__(self, downscale_factor):
|
173 |
+
super(PixelUnShuffle, self).__init__()
|
174 |
+
self.downscale_factor = downscale_factor
|
175 |
+
|
176 |
+
def forward(self, input):
|
177 |
+
'''
|
178 |
+
input: batchSize * c * k*w * k*h
|
179 |
+
downscale_factor: k
|
180 |
+
batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h
|
181 |
+
'''
|
182 |
+
return pixel_unshuffle(input, self.downscale_factor)
|
183 |
+
|
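Usage note (not part of the uploaded files): the conv-based pixel_unshuffle above should behave like torch.nn.functional.pixel_unshuffle (available in recent PyTorch releases); the sketch below compares the two on a random tensor under that assumption.

# Hypothetical equivalence check; torch.nn.functional.pixel_unshuffle is the reference.
import torch
import torch.nn.functional as F
from test_challenge_code.architecture.Restormer import pixel_unshuffle

x = torch.rand(2, 4, 8, 8)
custom = pixel_unshuffle(x, 2)            # grouped strided-conv implementation defined above
builtin = F.pixel_unshuffle(x, 2)
print(custom.shape)                       # torch.Size([2, 16, 4, 4])
print(torch.allclose(custom, builtin))    # expected: True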
184 |
+
##########################################################################
|
185 |
+
## Resizing modules
|
186 |
+
class Downsample(nn.Module):
|
187 |
+
def __init__(self, n_feat):
|
188 |
+
super(Downsample, self).__init__()
|
189 |
+
|
190 |
+
self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
|
191 |
+
PixelUnShuffle(2))
|
192 |
+
|
193 |
+
def forward(self, x):
|
194 |
+
return self.body(x)
|
195 |
+
|
196 |
+
|
197 |
+
class Upsample(nn.Module):
|
198 |
+
def __init__(self, n_feat):
|
199 |
+
super(Upsample, self).__init__()
|
200 |
+
|
201 |
+
self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
|
202 |
+
nn.PixelShuffle(2))
|
203 |
+
|
204 |
+
def forward(self, x):
|
205 |
+
return self.body(x)
|
206 |
+
|
207 |
+
|
208 |
+
##########################################################################
|
209 |
+
##---------- Restormer -----------------------
|
210 |
+
class Restormer(nn.Module):
|
211 |
+
def __init__(self,
|
212 |
+
inp_channels=3,
|
213 |
+
out_channels=31,
|
214 |
+
dim=48,
|
215 |
+
num_blocks=[2, 3, 3, 4],
|
216 |
+
num_refinement_blocks=3,
|
217 |
+
heads=[1, 2, 4, 8],
|
218 |
+
ffn_expansion_factor=2.66,
|
219 |
+
bias=False,
|
220 |
+
LayerNorm_type='WithBias', ## Other option 'BiasFree'
|
221 |
+
dual_pixel_task=True ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
|
222 |
+
):
|
223 |
+
|
224 |
+
super(Restormer, self).__init__()
|
225 |
+
|
226 |
+
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
|
227 |
+
|
228 |
+
self.encoder_level1 = nn.Sequential(*[
|
229 |
+
TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias,
|
230 |
+
LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
231 |
+
|
232 |
+
self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
|
233 |
+
self.encoder_level2 = nn.Sequential(*[
|
234 |
+
TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
|
235 |
+
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
236 |
+
|
237 |
+
self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3
|
238 |
+
self.encoder_level3 = nn.Sequential(*[
|
239 |
+
TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,
|
240 |
+
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
|
241 |
+
|
242 |
+
self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4
|
243 |
+
self.latent = nn.Sequential(*[
|
244 |
+
TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor,
|
245 |
+
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
|
246 |
+
|
247 |
+
self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3
|
248 |
+
self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias)
|
249 |
+
self.decoder_level3 = nn.Sequential(*[
|
250 |
+
TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,
|
251 |
+
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
|
252 |
+
|
253 |
+
self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2
|
254 |
+
self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
|
255 |
+
self.decoder_level2 = nn.Sequential(*[
|
256 |
+
TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
|
257 |
+
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
258 |
+
|
259 |
+
self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
|
260 |
+
|
261 |
+
self.decoder_level1 = nn.Sequential(*[
|
262 |
+
TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
|
263 |
+
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
264 |
+
|
265 |
+
self.refinement = nn.Sequential(*[
|
266 |
+
TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
|
267 |
+
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
|
268 |
+
|
269 |
+
#### For Dual-Pixel Defocus Deblurring Task ####
|
270 |
+
self.dual_pixel_task = dual_pixel_task
|
271 |
+
if self.dual_pixel_task:
|
272 |
+
self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias)
|
273 |
+
###########################
|
274 |
+
|
275 |
+
self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
|
276 |
+
|
277 |
+
def forward(self, inp_img):
|
278 |
+
b, c, h_inp, w_inp = inp_img.shape
|
279 |
+
hb, wb = 8, 8
|
280 |
+
pad_h = (hb - h_inp % hb) % hb
|
281 |
+
pad_w = (wb - w_inp % wb) % wb
|
282 |
+
inp_img = F.pad(inp_img, [0, pad_w, 0, pad_h], mode='reflect')
|
283 |
+
|
284 |
+
inp_enc_level1 = self.patch_embed(inp_img)
|
285 |
+
out_enc_level1 = self.encoder_level1(inp_enc_level1)
|
286 |
+
|
287 |
+
inp_enc_level2 = self.down1_2(out_enc_level1)
|
288 |
+
out_enc_level2 = self.encoder_level2(inp_enc_level2)
|
289 |
+
|
290 |
+
inp_enc_level3 = self.down2_3(out_enc_level2)
|
291 |
+
out_enc_level3 = self.encoder_level3(inp_enc_level3)
|
292 |
+
|
293 |
+
inp_enc_level4 = self.down3_4(out_enc_level3)
|
294 |
+
latent = self.latent(inp_enc_level4)
|
295 |
+
|
296 |
+
inp_dec_level3 = self.up4_3(latent)
|
297 |
+
inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
|
298 |
+
inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
|
299 |
+
out_dec_level3 = self.decoder_level3(inp_dec_level3)
|
300 |
+
|
301 |
+
inp_dec_level2 = self.up3_2(out_dec_level3)
|
302 |
+
inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
|
303 |
+
inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
|
304 |
+
out_dec_level2 = self.decoder_level2(inp_dec_level2)
|
305 |
+
|
306 |
+
inp_dec_level1 = self.up2_1(out_dec_level2)
|
307 |
+
inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
|
308 |
+
out_dec_level1 = self.decoder_level1(inp_dec_level1)
|
309 |
+
|
310 |
+
out_dec_level1 = self.refinement(out_dec_level1)
|
311 |
+
|
312 |
+
#### For Dual-Pixel Defocus Deblurring Task ####
|
313 |
+
if self.dual_pixel_task:
|
314 |
+
out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
|
315 |
+
out_dec_level1 = self.output(out_dec_level1)
|
316 |
+
###########################
|
317 |
+
else:
|
318 |
+
out_dec_level1 = self.output(out_dec_level1) + inp_img
|
319 |
+
|
320 |
+
return out_dec_level1[:, :, :h_inp, :w_inp]
|
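Usage note (not part of the uploaded files): a minimal sketch for the adapted Restormer above, assuming its default constructor arguments (3-channel input, 31-channel output, dual_pixel_task=True) and a random CPU input.

# Hypothetical smoke test; input size is an illustrative assumption.
import torch
from test_challenge_code.architecture.Restormer import Restormer

model = Restormer().eval()
rgb = torch.rand(1, 3, 64, 64)        # padded internally to multiples of 8
with torch.no_grad():
    hsi = model(rgb)
print(hsi.shape)                      # expected: torch.Size([1, 31, 64, 64])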
test_challenge_code/architecture/__init__.py
ADDED
@@ -0,0 +1,41 @@
import torch
from .edsr import EDSR
from .HDNet import HDNet
from .hinet import HINet
from .hrnet import SGN
from .HSCNN_Plus import HSCNN_Plus
from .MIRNet import MIRNet
from .MPRNet import MPRNet
from .MST import MST
from .MST_Plus_Plus import MST_Plus_Plus
from .Restormer import Restormer

def model_generator(method, pretrained_model_path=None):
    if method == 'mirnet':
        model = MIRNet(n_RRG=3, n_MSRB=1, height=3, width=1).cuda()
    elif method == 'mst_plus_plus':
        model = MST_Plus_Plus().cuda()
    elif method == 'mst':
        model = MST(dim=31, stage=2, num_blocks=[4, 7, 5]).cuda()
    elif method == 'hinet':
        model = HINet(depth=4).cuda()
    elif method == 'mprnet':
        model = MPRNet(num_cab=4).cuda()
    elif method == 'restormer':
        model = Restormer().cuda()
    elif method == 'edsr':
        model = EDSR().cuda()
    elif method == 'hdnet':
        model = HDNet().cuda()
    elif method == 'hrnet':
        model = SGN().cuda()
    elif method == 'hscnn_plus':
        model = HSCNN_Plus().cuda()
    else:
        print(f'Method {method} is not defined !!!!')
    if pretrained_model_path is not None:
        print(f'load model from {pretrained_model_path}')
        checkpoint = torch.load(pretrained_model_path)
        model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()},
                              strict=True)
    return model
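Usage note (not part of the uploaded files): a sketch of calling this factory, assuming a CUDA device is available (every branch above calls .cuda(); app.py instead uses the test_develop_code variant on CPU) and assuming the bundled mst_plus_plus.pth stores a 'state_dict' entry as this loader expects.

# Hypothetical usage of model_generator under the assumptions stated above.
from test_challenge_code.architecture import model_generator

model = model_generator('mst_plus_plus', 'mst_plus_plus.pth')
model.eval()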
test_challenge_code/architecture/edsr.py
ADDED
@@ -0,0 +1,87 @@
import torch.nn as nn

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)

class BasicBlock(nn.Sequential):
    def __init__(
        self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
        bn=True, act=nn.ReLU(True)):

        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)

        super(BasicBlock, self).__init__(*m)

class ResBlock(nn.Module):
    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class EDSR(nn.Module):
    def __init__(self, conv=default_conv):
        super(EDSR, self).__init__()

        n_resblocks = 32
        n_feats = 64
        kernel_size = 3
        n_colors = 3
        out_channels = 31
        act = nn.ReLU(True)

        # define head module
        m_head = [conv(n_colors, n_feats, kernel_size)]

        # define body module
        m_body = [
            ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=1
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [
            conv(n_feats, out_channels, kernel_size)
        ]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)

        return x
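Usage note (not part of the uploaded files): a minimal sketch for the EDSR-style baseline above, assuming a random CPU input; since all convolutions use "same" padding, any input size works and only the channel count changes.

# Hypothetical smoke test; input size is an illustrative assumption.
import torch
from test_challenge_code.architecture.edsr import EDSR

model = EDSR().eval()
rgb = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    hsi = model(rgb)
print(hsi.shape)                      # expected: torch.Size([1, 31, 64, 64])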
test_challenge_code/architecture/hinet.py
ADDED
@@ -0,0 +1,212 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
def conv3x3(in_chn, out_chn, bias=True):
|
6 |
+
layer = nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias)
|
7 |
+
return layer
|
8 |
+
|
9 |
+
def conv_down(in_chn, out_chn, bias=False):
|
10 |
+
layer = nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias)
|
11 |
+
return layer
|
12 |
+
|
13 |
+
def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
|
14 |
+
return nn.Conv2d(
|
15 |
+
in_channels, out_channels, kernel_size,
|
16 |
+
padding=(kernel_size//2), bias=bias, stride = stride)
|
17 |
+
|
18 |
+
## Supervised Attention Module
|
19 |
+
class SAM(nn.Module):
|
20 |
+
def __init__(self, n_feat, kernel_size=3, bias=True):
|
21 |
+
super(SAM, self).__init__()
|
22 |
+
self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
|
23 |
+
self.conv2 = conv(n_feat, n_feat, kernel_size, bias=bias)
|
24 |
+
self.conv3 = conv(n_feat, n_feat, kernel_size, bias=bias)
|
25 |
+
|
26 |
+
def forward(self, x, x_img):
|
27 |
+
x1 = self.conv1(x)
|
28 |
+
img = self.conv2(x) + x_img
|
29 |
+
x2 = torch.sigmoid(self.conv3(img))
|
30 |
+
x1 = x1*x2
|
31 |
+
x1 = x1+x
|
32 |
+
return x1, img
|
33 |
+
|
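Usage note (not part of the uploaded files): a shape sketch for the Supervised Attention Module above, assuming 31 feature channels to match how HINet instantiates it; both inputs must share the same channel count because conv2(x) is added to x_img directly.

# Hypothetical shape check under the assumptions stated above.
import torch
from test_challenge_code.architecture.hinet import SAM

sam = SAM(n_feat=31)
feats = torch.rand(1, 31, 64, 64)     # decoder features
img = torch.rand(1, 31, 64, 64)       # 31-channel image produced by conv_in
attended, restored = sam(feats, img)
print(attended.shape, restored.shape) # both torch.Size([1, 31, 64, 64])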
34 |
+
class HINet(nn.Module):
|
35 |
+
|
36 |
+
def __init__(self, in_chn=31, out_chn=31, wf=31, depth=4, relu_slope=0.2, hin_position_left=0, hin_position_right=4):
|
37 |
+
super(HINet, self).__init__()
|
38 |
+
|
39 |
+
self.conv_in = nn.Conv2d(3, in_chn, kernel_size=3, padding=(3 - 1) // 2,
|
40 |
+
bias=False)
|
41 |
+
self.depth = depth
|
42 |
+
self.down_path_1 = nn.ModuleList()
|
43 |
+
self.down_path_2 = nn.ModuleList()
|
44 |
+
self.conv_01 = nn.Conv2d(in_chn, wf, 3, 1, 1)
|
45 |
+
self.conv_02 = nn.Conv2d(in_chn, wf, 3, 1, 1)
|
46 |
+
|
47 |
+
prev_channels = self.get_input_chn(wf)
|
48 |
+
for i in range(depth): #0,1,2,3,4
|
49 |
+
use_HIN = True if hin_position_left <= i and i <= hin_position_right else False
|
50 |
+
downsample = True if (i+1) < depth else False
|
51 |
+
self.down_path_1.append(UNetConvBlock(prev_channels, (2**i) * wf, downsample, relu_slope, use_HIN=use_HIN))
|
52 |
+
self.down_path_2.append(UNetConvBlock(prev_channels, (2**i) * wf, downsample, relu_slope, use_csff=downsample, use_HIN=use_HIN))
|
53 |
+
prev_channels = (2**i) * wf
|
54 |
+
|
55 |
+
self.up_path_1 = nn.ModuleList()
|
56 |
+
self.up_path_2 = nn.ModuleList()
|
57 |
+
self.skip_conv_1 = nn.ModuleList()
|
58 |
+
self.skip_conv_2 = nn.ModuleList()
|
59 |
+
for i in reversed(range(depth - 1)):
|
60 |
+
self.up_path_1.append(UNetUpBlock(prev_channels, (2**i)*wf, relu_slope))
|
61 |
+
self.up_path_2.append(UNetUpBlock(prev_channels, (2**i)*wf, relu_slope))
|
62 |
+
self.skip_conv_1.append(nn.Conv2d((2**i)*wf, (2**i)*wf, 3, 1, 1))
|
63 |
+
self.skip_conv_2.append(nn.Conv2d((2**i)*wf, (2**i)*wf, 3, 1, 1))
|
64 |
+
prev_channels = (2**i)*wf
|
65 |
+
self.sam12 = SAM(prev_channels)
|
66 |
+
self.cat12 = nn.Conv2d(prev_channels*2, prev_channels, 1, 1, 0)
|
67 |
+
|
68 |
+
self.last = conv3x3(prev_channels, out_chn, bias=True)
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
|
72 |
+
b, c, h_inp, w_inp = x.shape
|
73 |
+
hb, wb = 16, 16
|
74 |
+
pad_h = (hb - h_inp % hb) % hb
|
75 |
+
pad_w = (wb - w_inp % wb) % wb
|
76 |
+
x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
|
77 |
+
|
78 |
+
image = self.conv_in(x)
|
79 |
+
|
80 |
+
#stage 1
|
81 |
+
x1 = self.conv_01(image)
|
82 |
+
encs = []
|
83 |
+
decs = []
|
84 |
+
for i, down in enumerate(self.down_path_1):
|
85 |
+
if (i+1) < self.depth:
|
86 |
+
x1, x1_up = down(x1)
|
87 |
+
encs.append(x1_up)
|
88 |
+
else:
|
89 |
+
x1 = down(x1)
|
90 |
+
|
91 |
+
for i, up in enumerate(self.up_path_1):
|
92 |
+
x1 = up(x1, self.skip_conv_1[i](encs[-i-1]))
|
93 |
+
decs.append(x1)
|
94 |
+
|
95 |
+
sam_feature, out_1 = self.sam12(x1, image)
|
96 |
+
#stage 2
|
97 |
+
x2 = self.conv_02(image)
|
98 |
+
x2 = self.cat12(torch.cat([x2, sam_feature], dim=1))
|
99 |
+
blocks = []
|
100 |
+
for i, down in enumerate(self.down_path_2):
|
101 |
+
if (i+1) < self.depth:
|
102 |
+
x2, x2_up = down(x2, encs[i], decs[-i-1])
|
103 |
+
blocks.append(x2_up)
|
104 |
+
else:
|
105 |
+
x2 = down(x2)
|
106 |
+
|
107 |
+
for i, up in enumerate(self.up_path_2):
|
108 |
+
x2 = up(x2, self.skip_conv_2[i](blocks[-i-1]))
|
109 |
+
|
110 |
+
out_2 = self.last(x2)
|
111 |
+
out_2 = out_2 + image
|
112 |
+
return out_2[:, :, :h_inp, :w_inp]
|
113 |
+
|
114 |
+
def get_input_chn(self, in_chn):
|
115 |
+
return in_chn
|
116 |
+
|
117 |
+
def _initialize(self):
|
118 |
+
gain = nn.init.calculate_gain('leaky_relu', 0.20)
|
119 |
+
for m in self.modules():
|
120 |
+
if isinstance(m, nn.Conv2d):
|
121 |
+
nn.init.orthogonal_(m.weight, gain=gain)
|
122 |
+
if not m.bias is None:
|
123 |
+
nn.init.constant_(m.bias, 0)
|
124 |
+
|
125 |
+
|
126 |
+
class UNetConvBlock(nn.Module):
|
127 |
+
def __init__(self, in_size, out_size, downsample, relu_slope, use_csff=False, use_HIN=False):
|
128 |
+
super(UNetConvBlock, self).__init__()
|
129 |
+
self.downsample = downsample
|
130 |
+
self.identity = nn.Conv2d(in_size, out_size, 1, 1, 0)
|
131 |
+
self.use_csff = use_csff
|
132 |
+
|
133 |
+
self.conv_1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True)
|
134 |
+
self.relu_1 = nn.LeakyReLU(relu_slope, inplace=False)
|
135 |
+
self.conv_2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True)
|
136 |
+
self.relu_2 = nn.LeakyReLU(relu_slope, inplace=False)
|
137 |
+
|
138 |
+
if downsample and use_csff:
|
139 |
+
self.csff_enc = nn.Conv2d(out_size, out_size, 3, 1, 1)
|
140 |
+
self.csff_dec = nn.Conv2d(out_size, out_size, 3, 1, 1)
|
141 |
+
|
142 |
+
if use_HIN:
|
143 |
+
self.norm = nn.InstanceNorm2d(((out_size+1)//2), affine=True)
|
144 |
+
self.use_HIN = use_HIN
|
145 |
+
|
146 |
+
if downsample:
|
147 |
+
self.downsample = conv_down(out_size, out_size, bias=False)
|
148 |
+
|
149 |
+
def forward(self, x, enc=None, dec=None):
|
150 |
+
out = self.conv_1(x)
|
151 |
+
|
152 |
+
if self.use_HIN:
|
153 |
+
out_1, out_2 = torch.chunk(out, 2, dim=1)
|
154 |
+
out = torch.cat([self.norm(out_1), out_2], dim=1)
|
155 |
+
out = self.relu_1(out)
|
156 |
+
out = self.relu_2(self.conv_2(out))
|
157 |
+
|
158 |
+
out += self.identity(x)
|
159 |
+
if enc is not None and dec is not None:
|
160 |
+
assert self.use_csff
|
161 |
+
out = out + self.csff_enc(enc) + self.csff_dec(dec)
|
162 |
+
if self.downsample:
|
163 |
+
out_down = self.downsample(out)
|
164 |
+
return out_down, out
|
165 |
+
else:
|
166 |
+
return out
|
167 |
+
|
168 |
+
|
169 |
+
class UNetUpBlock(nn.Module):
|
170 |
+
def __init__(self, in_size, out_size, relu_slope):
|
171 |
+
super(UNetUpBlock, self).__init__()
|
172 |
+
self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True)
|
173 |
+
self.conv_block = UNetConvBlock(in_size, out_size, False, relu_slope)
|
174 |
+
|
175 |
+
def forward(self, x, bridge):
|
176 |
+
up = self.up(x)
|
177 |
+
out = torch.cat([up, bridge], 1)
|
178 |
+
out = self.conv_block(out)
|
179 |
+
return out
|
180 |
+
|
181 |
+
class Subspace(nn.Module):
|
182 |
+
|
183 |
+
def __init__(self, in_size, out_size):
|
184 |
+
super(Subspace, self).__init__()
|
185 |
+
self.blocks = nn.ModuleList()
|
186 |
+
self.blocks.append(UNetConvBlock(in_size, out_size, False, 0.2))
|
187 |
+
self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True)
|
188 |
+
|
189 |
+
def forward(self, x):
|
190 |
+
sc = self.shortcut(x)
|
191 |
+
for i in range(len(self.blocks)):
|
192 |
+
x = self.blocks[i](x)
|
193 |
+
return x + sc
|
194 |
+
|
195 |
+
class skip_blocks(nn.Module):
|
196 |
+
|
197 |
+
def __init__(self, in_size, out_size, repeat_num=1):
|
198 |
+
super(skip_blocks, self).__init__()
|
199 |
+
self.blocks = nn.ModuleList()
|
200 |
+
self.re_num = repeat_num
|
201 |
+
mid_c = 128
|
202 |
+
self.blocks.append(UNetConvBlock(in_size, mid_c, False, 0.2))
|
203 |
+
for i in range(self.re_num - 2):
|
204 |
+
self.blocks.append(UNetConvBlock(mid_c, mid_c, False, 0.2))
|
205 |
+
self.blocks.append(UNetConvBlock(mid_c, out_size, False, 0.2))
|
206 |
+
self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True)
|
207 |
+
|
208 |
+
def forward(self, x):
|
209 |
+
sc = self.shortcut(x)
|
210 |
+
for m in self.blocks:
|
211 |
+
x = m(x)
|
212 |
+
return x + sc
|
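The HINet above is a two-stage U-Net: stage one encodes and decodes features derived from the RGB input, the SAM block feeds an intermediate image-domain estimate back in, and stage two refines it with the cross-stage feature fusion (csff) convolutions. A minimal shape-check sketch, assuming this file is importable under the package layout shown (the input size is only an example, since forward() pads to a multiple of 16 and crops back):

    import torch
    from test_challenge_code.architecture.hinet import HINet  # assumed import path

    model = HINet().eval()               # defaults: 3-channel RGB in, 31 spectral bands out
    rgb = torch.randn(1, 3, 482, 512)    # any H x W; padding/cropping is handled inside forward()
    with torch.no_grad():
        hsi = model(rgb)
    print(hsi.shape)                     # expected: torch.Size([1, 31, 482, 512])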
test_challenge_code/architecture/hrnet.py
ADDED
@@ -0,0 +1,484 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
# ----------------------------------------
# Conv2d Block
# ----------------------------------------
class Conv2dLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, pad_type='zero',
                 activation='lrelu', norm='none', sn=False):
        super(Conv2dLayer, self).__init__()
        # Initialize the padding scheme
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # Initialize the normalization type
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == 'in':
            self.norm = nn.InstanceNorm2d(out_channels)
        elif norm == 'ln':
            self.norm = LayerNorm(out_channels)
        elif norm == 'none':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # Initialize the activation function
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # Initialize the convolution layers
        if sn:
            self.conv2d = SpectralNorm(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation))
        else:
            self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation)

    def forward(self, x):
        x = self.pad(x)
        x = self.conv2d(x)
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x


class TransposeConv2dLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, pad_type='zero',
                 activation='lrelu', norm='none', sn=False, scale_factor=2):
        super(TransposeConv2dLayer, self).__init__()
        # Initialize the conv scheme
        self.scale_factor = scale_factor
        self.conv2d = Conv2dLayer(in_channels, out_channels, kernel_size, stride, padding, dilation, pad_type,
                                  activation, norm, sn)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
        x = self.conv2d(x)
        return x


class ResConv2dLayer(nn.Module):
    def __init__(self, in_channels, kernel_size=3, stride=1, padding=1, dilation=1, pad_type='zero', activation='lrelu',
                 norm='none', sn=False, scale_factor=2):
        super(ResConv2dLayer, self).__init__()
        # Initialize the conv scheme
        self.conv2d = nn.Sequential(
            Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm,
                        sn),
            Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation='none',
                        norm=norm, sn=sn)
        )

    def forward(self, x):
        residual = x
        out = self.conv2d(x)
        out = 0.1 * out + residual
        return out


class DenseConv2dLayer_5C(nn.Module):
    def __init__(self, in_channels, latent_channels, kernel_size=3, stride=1, padding=1, dilation=1, pad_type='zero',
                 activation='lrelu', norm='none', sn=False):
        super(DenseConv2dLayer_5C, self).__init__()
        # dense convolutions
        self.conv1 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type,
                                 activation, norm, sn)
        self.conv2 = Conv2dLayer(in_channels + latent_channels, latent_channels, kernel_size, stride, padding, dilation,
                                 pad_type, activation, norm, sn)
        self.conv3 = Conv2dLayer(in_channels + latent_channels * 2, latent_channels, kernel_size, stride, padding,
                                 dilation, pad_type, activation, norm, sn)
        self.conv4 = Conv2dLayer(in_channels + latent_channels * 3, latent_channels, kernel_size, stride, padding,
                                 dilation, pad_type, activation, norm, sn)
        self.conv5 = Conv2dLayer(in_channels + latent_channels * 4, in_channels, kernel_size, stride, padding, dilation,
                                 pad_type, activation, norm, sn)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5


class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, in_channels, latent_channels, kernel_size=3, stride=1, padding=1, dilation=1, pad_type='zero',
                 activation='lrelu', norm='none', sn=False):
        super(ResidualDenseBlock_5C, self).__init__()
        # dense convolutions
        self.conv1 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type,
                                 activation, norm, sn)
        self.conv2 = Conv2dLayer(in_channels + latent_channels, latent_channels, kernel_size, stride, padding, dilation,
                                 pad_type, activation, norm, sn)
        self.conv3 = Conv2dLayer(in_channels + latent_channels * 2, latent_channels, kernel_size, stride, padding,
                                 dilation, pad_type, activation, norm, sn)
        self.conv4 = Conv2dLayer(in_channels + latent_channels * 3, latent_channels, kernel_size, stride, padding,
                                 dilation, pad_type, activation, norm, sn)
        self.conv5 = Conv2dLayer(in_channels + latent_channels * 4, in_channels, kernel_size, stride, padding, dilation,
                                 pad_type, activation, norm, sn)

    def forward(self, x):
        residual = x
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        x5 = 0.1 * x5 + residual
        return x5


# ----------------------------------------
# Layer Norm
# ----------------------------------------
class LayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-8, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = Parameter(torch.Tensor(num_features).uniform_())
            self.beta = Parameter(torch.zeros(num_features))

    def forward(self, x):
        # layer norm
        shape = [-1] + [1] * (x.dim() - 1)  # for 4d input: [-1, 1, 1, 1]
        if x.size(0) == 1:
            # These two lines run much faster in pytorch 0.4 than the two lines listed below.
            mean = x.view(-1).mean().view(*shape)
            std = x.view(-1).std().view(*shape)
        else:
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)
        x = (x - mean) / (std + self.eps)
        # if it is learnable
        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)  # for 4d input: [1, -1, 1, 1]
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


# ----------------------------------------
# Spectral Norm Block
# ----------------------------------------
def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


# ----------------------------------------
# Non-local Block
# ----------------------------------------
class Self_Attn(nn.Module):
    """ Self attention Layer for Feature Map dimension"""

    def __init__(self, in_dim, latent_dim=8):
        super(Self_Attn, self).__init__()
        self.channel_in = in_dim
        self.channel_latent = in_dim // latent_dim
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // latent_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // latent_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
            inputs :
                x : input feature maps(B X C X H X W)
            returns :
                out : self attention value + input feature
                attention: B X N X N (N is Height * Width)
        """
        batchsize, C, height, width = x.size()
        # proj_query: reshape to B x N x c, N = H x W
        proj_query = self.query_conv(x).view(batchsize, -1, height * width).permute(0, 2, 1)
        # proj_key: reshape to B x c x N, N = H x W
        proj_key = self.key_conv(x).view(batchsize, -1, height * width)
        # transpose check, energy: B x N x N, N = H x W
        energy = torch.bmm(proj_query, proj_key)
        # attention: B x N x N, N = H x W
        attention = self.softmax(energy)
        # proj_value is normal convolution, B x C x N
        proj_value = self.value_conv(x).view(batchsize, -1, height * width)
        # out: B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batchsize, C, height, width)

        out = self.gamma * out + x
        return out


# ----------------------------------------
# Global Block
# ----------------------------------------
class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class GlobalBlock(nn.Module):
    def __init__(self, in_channels, kernel_size, stride=1, padding=0, dilation=1, pad_type='zero', activation='lrelu',
                 norm='none', sn=False, reduction=8):
        super(GlobalBlock, self).__init__()
        self.conv1 = Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation,
                                 norm, sn)
        self.conv2 = Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation,
                                 norm, sn)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # residual
        residual = x
        # Squeeze-and-Excitation (SE)
        b, c, _, _ = x.size()
        x = self.conv1(x)
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        y = x * y.expand_as(x)
        y = self.conv2(x)
        # addition
        out = 0.1 * y + residual
        return out

def pixel_unshuffle(input, downscale_factor):
    '''
    input: batchSize * c * k*w * k*h
    downscale_factor: k
    batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h
    '''
    c = input.shape[1]
    kernel = torch.zeros(size = [downscale_factor * downscale_factor * c, 1, downscale_factor, downscale_factor],
                         device = input.device)
    for y in range(downscale_factor):
        for x in range(downscale_factor):
            kernel[x + y * downscale_factor::downscale_factor * downscale_factor, 0, y, x] = 1
    return F.conv2d(input, kernel, stride = downscale_factor, groups = c)

class PixelUnShuffle(nn.Module):
    def __init__(self, downscale_factor):
        super(PixelUnShuffle, self).__init__()
        self.downscale_factor = downscale_factor

    def forward(self, input):
        '''
        input: batchSize * c * k*w * k*h
        downscale_factor: k
        batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h
        '''
        return pixel_unshuffle(input, self.downscale_factor)

# ----------------------------------------
# Initialize the networks
# ----------------------------------------
def weights_init(net, init_type = 'normal', init_gain = 0.02):
    """Initialize network weights.
    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal
    In our paper, we choose the default setting: zero mean Gaussian distribution with a standard deviation of 0.02
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and classname.find('Conv') != -1:
            if init_type == 'normal':
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(m.weight.data, gain = init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(m.weight.data, a = 0, mode = 'fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(m.weight.data, gain = init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)

    # apply the initialization function <init_func>
    print('initialize network with %s type' % init_type)
    net.apply(init_func)

# ----------------------------------------
# Generator
# ----------------------------------------
class SGN(nn.Module):
    def __init__(self, in_channels=3, out_channels=31, start_channels=64, pad='zero', activ='lrelu', norm='none', ):
        super(SGN, self).__init__()
        # Top subnetwork, K = 3
        self.top1 = Conv2dLayer(in_channels * (4 ** 3), start_channels * (2 ** 3), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.top21 = ResidualDenseBlock_5C(start_channels * (2 ** 3), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.top22 = GlobalBlock(start_channels * (2 ** 3), 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4)
        self.top3 = Conv2dLayer(start_channels * (2 ** 3), start_channels * (2 ** 3), 1, 1, 0, pad_type = pad, activation = activ, norm = norm)
        # Middle subnetwork, K = 2
        self.mid1 = Conv2dLayer(in_channels * (4 ** 2), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.mid2 = Conv2dLayer(int(start_channels * (2 ** 2 + 2 ** 3 / 4)), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.mid31 = ResidualDenseBlock_5C(start_channels * (2 ** 2), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.mid32 = GlobalBlock(start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4)
        self.mid4 = Conv2dLayer(start_channels * (2 ** 2), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        # Bottom subnetwork, K = 1
        self.bot1 = Conv2dLayer(in_channels * (4 ** 1), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.bot2 = Conv2dLayer(int(start_channels * (2 ** 1 + 2 ** 2 / 4)), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.bot31 = ResidualDenseBlock_5C(start_channels * (2 ** 1), start_channels * (2 ** 0), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.bot32 = ResidualDenseBlock_5C(start_channels * (2 ** 1), start_channels * (2 ** 0), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.bot33 = GlobalBlock(start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4)
        self.bot4 = Conv2dLayer(start_channels * (2 ** 1), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        # Mainstream
        self.main1 = Conv2dLayer(in_channels, start_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.main2 = Conv2dLayer(int(start_channels * (2 ** 0 + 2 ** 1 / 4)), start_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.main31 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.main32 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.main33 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.main34 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.main35 = GlobalBlock(start_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4)
        self.main4 = Conv2dLayer(start_channels, out_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)

    def forward(self, x):
        b, c, h_inp, w_inp = x.shape
        hb, wb = 8, 8
        pad_h = (hb - h_inp % hb) % hb
        pad_w = (wb - w_inp % wb) % wb
        x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
        # PixelUnShuffle                 input: batch * 3 * 256 * 256
        x1 = pixel_unshuffle(x, 2)     # out: batch * 12 * 128 * 128
        x2 = pixel_unshuffle(x, 4)     # out: batch * 48 * 64 * 64
        x3 = pixel_unshuffle(x, 8)     # out: batch * 192 * 32 * 32
        # Top subnetwork                 suppose the start_channels = 32
        x3 = self.top1(x3)             # out: batch * 256 * 32 * 32
        x3 = self.top21(x3)            # out: batch * 256 * 32 * 32
        x3 = self.top22(x3)            # out: batch * 256 * 32 * 32
        x3 = self.top3(x3)             # out: batch * 256 * 32 * 32
        x3 = F.pixel_shuffle(x3, 2)    # out: batch * 64 * 64 * 64, ready to be concatenated
        # Middle subnetwork
        x2 = self.mid1(x2)             # out: batch * 128 * 64 * 64
        x2 = torch.cat((x2, x3), 1)    # out: batch * (128 + 64) * 64 * 64
        x2 = self.mid2(x2)             # out: batch * 128 * 64 * 64
        x2 = self.mid31(x2)            # out: batch * 128 * 64 * 64
        x2 = self.mid32(x2)            # out: batch * 128 * 64 * 64
        x2 = self.mid4(x2)             # out: batch * 128 * 64 * 64
        x2 = F.pixel_shuffle(x2, 2)    # out: batch * 32 * 128 * 128, ready to be concatenated
        # Bottom subnetwork
        x1 = self.bot1(x1)             # out: batch * 64 * 128 * 128
        x1 = torch.cat((x1, x2), 1)    # out: batch * (64 + 32) * 128 * 128
        x1 = self.bot2(x1)             # out: batch * 64 * 128 * 128
        x1 = self.bot31(x1)            # out: batch * 64 * 128 * 128
        x1 = self.bot32(x1)            # out: batch * 64 * 128 * 128
        x1 = self.bot33(x1)            # out: batch * 64 * 128 * 128
        x1 = self.bot4(x1)             # out: batch * 64 * 128 * 128
        x1 = F.pixel_shuffle(x1, 2)    # out: batch * 16 * 256 * 256, ready to be concatenated
        # U-Net generator with skip connections from encoder to decoder
        x = self.main1(x)              # out: batch * 32 * 256 * 256
        x = torch.cat((x, x1), 1)      # out: batch * (32 + 16) * 256 * 256
        x = self.main2(x)              # out: batch * 32 * 256 * 256
        x = self.main31(x)             # out: batch * 32 * 256 * 256
        x = self.main32(x)             # out: batch * 32 * 256 * 256
        x = self.main33(x)             # out: batch * 32 * 256 * 256
        x = self.main34(x)             # out: batch * 32 * 256 * 256
        x = self.main35(x)             # out: batch * 32 * 256 * 256
        x = self.main4(x)              # out: batch * 3 * 256 * 256

        return x[:, :, :h_inp, :w_inp]
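The SGN above builds a resolution pyramid with the custom pixel_unshuffle (a space-to-depth operation: B x C x kH x kW becomes B x k*k*C x H x W), runs the top, middle and bottom subnetworks on the 1/8, 1/4 and 1/2 scale copies, and merges them back into the full-resolution mainstream, which ends in out_channels=31 bands. A rough usage sketch under the same import-path assumption as above (the 64 x 64 size is only an example, since forward() pads to a multiple of 8 and crops back):

    import torch
    from test_challenge_code.architecture.hrnet import SGN, pixel_unshuffle  # assumed import path

    x = torch.randn(1, 3, 64, 64)
    print(pixel_unshuffle(x, 2).shape)   # torch.Size([1, 12, 32, 32]): k*k*C channels, H/k x W/k spatial

    model = SGN().eval()                 # defaults: in_channels=3, out_channels=31, start_channels=64
    with torch.no_grad():
        out = model(x)
    print(out.shape)                     # expected: torch.Size([1, 31, 64, 64])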