import torch
import torch.nn as nn
import torch.nn.functional as F

norm_dict = {'BATCH': nn.BatchNorm2d, 'INSTANCE': nn.InstanceNorm2d, 'GROUP': nn.GroupNorm}
NUM_GROUPS = 16

__all__ = ['ConvNorm', 'ConvBlock', 'ConvBottleNeck', 'ResBlock', 'ResBottleneck', 'PromptResBlock',
           'PromptResBottleneck', 'PromptAttentionModule', 'norm_dict', 'SobelEdge']

class Identity(nn.Module):
    """
    Identity mapping for building a residual connection
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

class ConvNorm(nn.Module):
    """
    Convolution and normalization
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, leaky=True, norm='INSTANCE', activation=True):
        super().__init__()
        # determine basic attributes
        self.norm_type = norm
        padding = (kernel_size - 1) // 2

        # activation: LeakyReLU or standard ReLU
        if activation:
            self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        else:
            self.act = None

        # instantiate layers
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        if norm in ['BATCH', 'INSTANCE']:
            norm_layer = norm_dict[norm]
            self.norm = norm_layer(out_channels)
        elif norm == 'GROUP':
            # group norm is applied to the input (pre-activation ordering), hence in_channels
            norm_layer = norm_dict[norm]
            self.norm = norm_layer(NUM_GROUPS, in_channels)
        elif norm == 'NONE':
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f'Normalization type {norm} not implemented')

    def basic_forward(self, x):
        # conv -> norm -> act
        x = self.conv(x)
        x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x

    def group_forward(self, x):
        # pre-activation ordering: norm -> act -> conv
        x = self.norm(x)
        if self.act:
            x = self.act(x)
        x = self.conv(x)
        return x

    def forward(self, x):
        if self.norm_type in ['BATCH', 'INSTANCE']:
            return self.basic_forward(x)
        else:
            return self.group_forward(x)
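
# Example usage of ConvNorm (a minimal sketch; sizes are illustrative assumptions,
# not values taken from this module):
#   conv_in = ConvNorm(16, 32, kernel_size=3, stride=2, norm='INSTANCE')  # conv -> norm -> act
#   conv_gn = ConvNorm(16, 32, kernel_size=3, stride=1, norm='GROUP')     # norm -> act -> conv
#   y = conv_in(torch.randn(1, 16, 64, 64))  # (1, 32, 32, 32)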

class PromptAttentionModule(nn.Module):
    """
    Channel attention gated by an external prompt embedding
    """
    def __init__(self, in_channels: int, prompt_channels: int, mid_channels: int) -> None:
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.conv_down = nn.Linear(in_channels, mid_channels)
        self.prompt_down = nn.Linear(prompt_channels, mid_channels)
        self.fc = nn.Linear(2 * mid_channels, in_channels)

    def forward(self, x: torch.Tensor, prompt_in: torch.Tensor):
        """
        Args:
            x: (B, C_im, H, W)
            prompt_in: (B, C_prompt)
        """
        x_gap = self.gap(x).squeeze(-1).squeeze(-1)       # (B, C_im)
        x_gap = self.conv_down(x_gap)                     # (B, C_mid)
        prompt_down = self.prompt_down(prompt_in)         # (B, C_mid)
        gating = torch.cat([x_gap, prompt_down], dim=-1)  # (B, 2 * C_mid)
        gating = torch.sigmoid(self.fc(F.relu(gating)))[..., None, None]  # (B, C_im, 1, 1)
        return x * gating
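
# Example usage of PromptAttentionModule (a minimal sketch; tensor sizes below are
# illustrative assumptions, not taken from the original module):
#   attn = PromptAttentionModule(in_channels=64, prompt_channels=512, mid_channels=16)
#   feat = torch.randn(2, 64, 32, 32)
#   prompt = torch.randn(2, 512)
#   gated = attn(feat, prompt)   # (2, 64, 32, 32), channel-wise gated features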

class ConvBlock(nn.Module):
    """
    Convolutional block: two 3x3 ConvNorm layers
    """
    def __init__(self, in_channels, out_channels, stride=1, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        # activation: LeakyReLU or standard ReLU
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        if self.norm_type != 'GROUP':
            # group-norm blocks defer activation to the next (pre-activation) layer
            out = self.act(out)
        return out

class ResBlock(nn.Module):
    """
    Residual block: two 3x3 ConvNorm layers with an identity/projection shortcut
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)
        # project the shortcut with a 1x1 conv when the shape changes
        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        identity = self.id(identity)
        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)
        if self.dropout:
            out = self.dropout(out)
        return out

class ConvBottleNeck(nn.Module):
    """
    Convolutional bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand
    """
    def __init__(self, in_channels, out_channels, stride=1, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        middle_channels = in_channels // 4
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        if self.norm_type != 'GROUP':
            out = self.act(out)
        return out

class ResBottleneck(nn.Module):
    """
    Residual bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand with a shortcut
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        middle_channels = in_channels // 4
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)
        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        identity = self.id(identity)
        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)
        if self.dropout:
            out = self.dropout(out)
        return out

class PromptResBlock(nn.Module):
    """
    Residual block with prompt-gated channel attention
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)
        self.attn = PromptAttentionModule(out_channels, 512, out_channels // 4)
        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x, prompt_in):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.attn(out, prompt_in)
        identity = self.id(identity)
        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)
        if self.dropout:
            out = self.dropout(out)
        return out

class PromptResBottleneck(nn.Module):
    """
    Residual bottleneck block with prompt-gated channel attention
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        middle_channels = in_channels // 4
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)
        self.attn = PromptAttentionModule(out_channels, 512, out_channels // 4)
        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x, prompt_in):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.attn(out, prompt_in)
        identity = self.id(identity)
        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)
        if self.dropout:
            out = self.dropout(out)
        return out

class SobelEdge(nn.Module):
    """
    Fixed (non-trainable) depthwise Sobel filter for edge extraction
    """
    def __init__(self, input_dim, channels, kernel_size=3, stride=1):
        super().__init__()
        conv = getattr(nn, 'Conv%dd' % input_dim)
        self.filter = conv(channels, channels, kernel_size, stride, padding=(kernel_size - 1) // 2,
                           groups=channels, bias=False)
        sobel = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
        # broadcast the 2D Sobel kernel to one depthwise filter per channel
        sobel_kernel = torch.tensor(sobel, dtype=torch.float32).unsqueeze(0).expand(
            [channels, 1] + [kernel_size] * input_dim)
        self.filter.weight = nn.Parameter(sobel_kernel, requires_grad=False)

    def forward(self, x):
        with torch.no_grad():
            out = self.filter(x)
        return out
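

if __name__ == '__main__':
    # Minimal smoke test (a sketch: the shapes and hyper-parameters below are
    # illustrative assumptions, not values prescribed by this module).
    x = torch.randn(2, 32, 64, 64)
    prompt = torch.randn(2, 512)

    block = ResBlock(32, 64, stride=2, norm='INSTANCE')
    print(block(x).shape)                 # torch.Size([2, 64, 32, 32])

    prompt_block = PromptResBlock(32, 64, stride=2, norm='INSTANCE')
    print(prompt_block(x, prompt).shape)  # torch.Size([2, 64, 32, 32])

    edge = SobelEdge(input_dim=2, channels=32)
    print(edge(x).shape)                  # torch.Size([2, 32, 64, 64])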