# Mbi2Spi/base/base_modules.py
import torch
import torch.nn as nn
import torch.nn.functional as F
norm_dict = {'BATCH': nn.BatchNorm2d, 'INSTANCE': nn.InstanceNorm2d, 'GROUP': nn.GroupNorm}
NUM_GROUPS = 16

__all__ = ['ConvNorm', 'ConvBlock', 'ConvBottleNeck', 'ResBlock', 'ResBottleneck', 'PromptResBlock', 'PromptResBottleneck', 'PromptAttentionModule', 'norm_dict', 'SobelEdge']


class Identity(nn.Module):
    """
    Identity mapping for building a residual connection
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

class ConvNorm(nn.Module):
    """
    Convolution and normalization (pre-activation ordering when GroupNorm is used)
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, leaky=True, norm='INSTANCE', activation=True):
        super().__init__()
        # determine basic attributes
        self.norm_type = norm
        padding = (kernel_size - 1) // 2

        # activation: LeakyReLU (if leaky) or plain ReLU
        if activation:
            self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
            # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        else:
            self.act = None

        # instantiate layers
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        if norm in ['BATCH', 'INSTANCE']:
            norm_layer = norm_dict[norm]
            self.norm = norm_layer(out_channels)
        elif norm == 'GROUP':
            # GroupNorm is applied before the convolution (see group_forward), hence in_channels here
            norm_layer = norm_dict[norm]
            self.norm = norm_layer(NUM_GROUPS, in_channels)
        elif norm == 'NONE':
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f'Normalization type {norm} not implemented')

    def basic_forward(self, x):
        # post-activation ordering: conv -> norm -> act
        x = self.conv(x)
        x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x

    def group_forward(self, x):
        # pre-activation ordering: norm -> act -> conv
        x = self.norm(x)
        if self.act:
            x = self.act(x)
        x = self.conv(x)
        return x

    def forward(self, x):
        if self.norm_type in ['BATCH', 'INSTANCE']:
            return self.basic_forward(x)
        else:
            return self.group_forward(x)

class PromptAttentionModule(nn.Module):
    """
    Channel-wise gating of an image feature map conditioned on a prompt embedding
    """
    def __init__(self, in_channels: int, prompt_channels: int, mid_channels: int) -> None:
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.conv_down = nn.Linear(in_channels, mid_channels)
        self.prompt_down = nn.Linear(prompt_channels, mid_channels)
        self.fc = nn.Linear(2 * mid_channels, in_channels)

    def forward(self, x: torch.Tensor, prompt_in: torch.Tensor):
        """
        Args:
            x: (B, C_im, H, W)
            prompt_in: (B, C_prompt)
        """
        x_gap = self.gap(x).squeeze(-1).squeeze(-1)  # (B, C_im)
        x_gap = self.conv_down(x_gap)  # (B, C_mid)
        prompt_down = self.prompt_down(prompt_in)  # (B, C_mid)
        gating = torch.cat([x_gap, prompt_down], dim=-1)  # (B, 2 * C_mid)
        gating = torch.sigmoid(self.fc(F.relu(gating)))[..., None, None]  # (B, C_im, 1, 1)
        return x * gating
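
# Shape sketch for PromptAttentionModule (illustrative values, not part of the original file):
#   attn = PromptAttentionModule(in_channels=64, prompt_channels=512, mid_channels=16)
#   x, prompt = torch.randn(2, 64, 32, 32), torch.randn(2, 512)
#   attn(x, prompt).shape  # torch.Size([2, 64, 32, 32]) -- x re-weighted channel-wise by the prompt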

class ConvBlock(nn.Module):
    """
    Two-layer convolutional block
    """
    def __init__(self, in_channels, out_channels, stride=1, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        # activation: LeakyReLU (if leaky) or plain ReLU
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        # GROUP norm uses pre-activation ConvNorm layers, so the trailing activation is skipped in that case
        if self.norm_type != 'GROUP':
            out = self.act(out)
        return out

class ResBlock(nn.Module):
    """
    Residual blocks
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)
        # 1x1 projection on the skip connection when the shape changes
        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        identity = self.id(identity)
        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)
        if self.dropout:
            out = self.dropout(out)
        return out

class ConvBottleNeck(nn.Module):
    """
    Convolutional bottleneck blocks
    """
    def __init__(self, in_channels, out_channels, stride=1, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        # bottleneck: squeeze to in_channels // 4, 3x3 conv, then expand to out_channels
        middle_channels = in_channels // 4
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        if self.norm_type != 'GROUP':
            out = self.act(out)
        return out

class ResBottleneck(nn.Module):
    """
    Residual bottleneck blocks
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        middle_channels = in_channels // 4
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)
        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        identity = self.id(identity)
        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)
        if self.dropout:
            out = self.dropout(out)
        return out

class PromptResBlock(nn.Module):
    """
    Residual block with prompt-conditioned channel attention
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)
        # prompt embedding dimension is hard-coded to 512 here
        self.attn = PromptAttentionModule(out_channels, 512, out_channels // 4)
        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x, prompt_in):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.attn(out, prompt_in)
        identity = self.id(identity)
        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)
        if self.dropout:
            out = self.dropout(out)
        return out

class PromptResBottleneck(nn.Module):
    """
    Residual bottleneck block with prompt-conditioned channel attention
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        middle_channels = in_channels // 4
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)
        # prompt embedding dimension is hard-coded to 512 here
        self.attn = PromptAttentionModule(out_channels, 512, out_channels // 4)
        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x, prompt_in):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.attn(out, prompt_in)
        identity = self.id(identity)
        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)
        if self.dropout:
            out = self.dropout(out)
        return out
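
# Usage sketch (illustrative, not from the original file): unlike ResBlock/ResBottleneck,
# the prompt variants take a second argument, a (B, 512) prompt embedding.
#   block = PromptResBottleneck(64, 128, stride=2)
#   y = block(torch.randn(1, 64, 64, 64), torch.randn(1, 512))  # -> (1, 128, 32, 32)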

class SobelEdge(nn.Module):
    """
    Fixed (non-trainable) depthwise Sobel filter for edge extraction
    """
    def __init__(self, input_dim, channels, kernel_size=3, stride=1):
        super().__init__()
        conv = getattr(nn, 'Conv%dd' % input_dim)
        # depthwise convolution: one filter per channel (groups=channels)
        self.filter = conv(channels, channels, kernel_size, stride, padding=(kernel_size - 1) // 2,
                           groups=channels, bias=False)
        # 3x3 Sobel kernel (vertical gradient); defined in 2D, so kernel_size is expected to be 3
        sobel = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
        sobel_kernel = torch.tensor(sobel, dtype=torch.float32).unsqueeze(0).expand([channels, 1] + [kernel_size] * input_dim)
        self.filter.weight = nn.Parameter(sobel_kernel, requires_grad=False)

    def forward(self, x):
        with torch.no_grad():
            out = self.filter(x)
        return out
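

if __name__ == '__main__':
    # Minimal shape-check sketch (illustrative, not part of the original module).
    x = torch.randn(2, 32, 64, 64)
    prompt = torch.randn(2, 512)  # 512 matches the prompt dimension hard-coded in the prompt blocks
    print(ResBlock(32, 64, stride=2)(x).shape)                # torch.Size([2, 64, 32, 32])
    print(ResBottleneck(32, 64, stride=2)(x).shape)           # torch.Size([2, 64, 32, 32])
    print(PromptResBlock(32, 64, stride=2)(x, prompt).shape)  # torch.Size([2, 64, 32, 32])
    print(SobelEdge(input_dim=2, channels=32)(x).shape)       # torch.Size([2, 32, 64, 64])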