import torch
import torch.nn as nn
import torch.nn.functional as F

norm_dict = {'BATCH': nn.BatchNorm2d, 'INSTANCE': nn.InstanceNorm2d, 'GROUP': nn.GroupNorm}
NUM_GROUPS = 16

__all__ = ['ConvNorm', 'ConvBlock', 'ConvBottleNeck', 'ResBlock', 'ResBottleneck',
           'PromptResBlock', 'PromptResBottleneck', 'PromptAttentionModule',
           'norm_dict', 'SobelEdge']


class Identity(nn.Module):
    """
    Identity mapping for building a residual connection
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class ConvNorm(nn.Module):
    """
    Convolution and normalization
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, leaky=True, norm='INSTANCE', activation=True):
        super().__init__()
        # determine basic attributes
        self.norm_type = norm
        padding = (kernel_size - 1) // 2

        # activation, support LeakyReLU and common ReLU
        if activation:
            self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
            # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        else:
            self.act = None

        # instantiate layers
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        if norm in ['BATCH', 'INSTANCE']:
            norm_layer = norm_dict[norm]
            self.norm = norm_layer(out_channels)
        elif norm == 'GROUP':
            norm_layer = norm_dict[norm]
            # group norm is applied before the convolution (see group_forward), hence in_channels
            self.norm = norm_layer(NUM_GROUPS, in_channels)
        elif norm == 'NONE':
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f'Normalization type {norm} not implemented')

    def basic_forward(self, x):
        # post-activation order: conv -> norm -> act
        x = self.conv(x)
        x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x

    def group_forward(self, x):
        # pre-activation order: norm -> act -> conv
        x = self.norm(x)
        if self.act:
            x = self.act(x)
        x = self.conv(x)
        return x

    def forward(self, x):
        if self.norm_type in ['BATCH', 'INSTANCE']:
            return self.basic_forward(x)
        else:
            return self.group_forward(x)


class PromptAttentionModule(nn.Module):
    def __init__(self, in_channels: int, prompt_channels: int, mid_channels: int) -> None:
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.conv_down = nn.Linear(in_channels, mid_channels)
        self.prompt_down = nn.Linear(prompt_channels, mid_channels)
        self.fc = nn.Linear(2 * mid_channels, in_channels)

    def forward(self, x: torch.Tensor, prompt_in: torch.Tensor):
        """
        Args:
            x: (B, C_im, H, W)
            prompt_in: (B, C_prompt)
        """
        x_gap = self.gap(x).squeeze(-1).squeeze(-1)       # (B, C_im)
        x_gap = self.conv_down(x_gap)                     # (B, C_mid)
        prompt_down = self.prompt_down(prompt_in)         # (B, C_mid)
        gating = torch.cat([x_gap, prompt_down], dim=-1)  # (B, 2 * C_mid)
        # torch.sigmoid replaces the deprecated F.sigmoid
        gating = torch.sigmoid(self.fc(F.relu(gating)))[..., None, None]  # (B, C_im, 1, 1)
        return x * gating


class ConvBlock(nn.Module):
    """
    Convolutional blocks
    """
    def __init__(self, in_channels, out_channels, stride=1, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        # activation, support LeakyReLU and common ReLU
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        if self.norm_type != 'GROUP':
            out = self.act(out)
        return out
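
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal shape check for PromptAttentionModule, assuming a 64-channel
# feature map and a 512-dimensional prompt embedding; the helper name
# _demo_prompt_attention is hypothetical.
def _demo_prompt_attention():
    attn = PromptAttentionModule(in_channels=64, prompt_channels=512, mid_channels=16)
    x = torch.randn(2, 64, 32, 32)   # (B, C_im, H, W)
    prompt = torch.randn(2, 512)     # (B, C_prompt)
    out = attn(x, prompt)
    assert out.shape == x.shape      # per-channel gating preserves the shape
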
class ResBlock(nn.Module):
    """
    Residual blocks
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)
        # project the identity branch when channels or resolution change
        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        identity = self.id(identity)
        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)
        if self.dropout:
            out = self.dropout(out)
        return out


class ConvBottleNeck(nn.Module):
    """
    Convolutional bottleneck blocks
    """
    def __init__(self, in_channels, out_channels, stride=1, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        middle_channels = in_channels // 4
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        if self.norm_type != 'GROUP':
            out = self.act(out)
        return out


class ResBottleneck(nn.Module):
    """
    Residual bottleneck blocks
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        middle_channels = in_channels // 4
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)
        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        identity = self.id(identity)
        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)
        if self.dropout:
            out = self.dropout(out)
        return out


class PromptResBlock(nn.Module):
    """
    Residual blocks with prompt-guided channel attention
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        self.conv1 = ConvNorm(in_channels, out_channels, 3, stride, leaky, norm, True)
        self.conv2 = ConvNorm(out_channels, out_channels, 3, 1, leaky, norm, False)
        # prompt embeddings are expected to be 512-dimensional
        self.attn = PromptAttentionModule(out_channels, 512, out_channels // 4)
        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x, prompt_in):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.attn(out, prompt_in)
        identity = self.id(identity)
        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)
        if self.dropout:
            out = self.dropout(out)
        return out
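
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal sketch of how PromptResBlock threads a prompt through its forward
# pass; the prompt width is fixed at 512 by the attention module above, and
# the helper name _demo_prompt_res_block is hypothetical.
def _demo_prompt_res_block():
    block = PromptResBlock(in_channels=32, out_channels=64, stride=2)
    x = torch.randn(2, 32, 64, 64)
    prompt = torch.randn(2, 512)
    out = block(x, prompt)
    assert out.shape == (2, 64, 32, 32)  # stride 2 halves the spatial size
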
class PromptResBottleneck(nn.Module):
    """
    Residual bottleneck blocks with prompt-guided channel attention
    """
    def __init__(self, in_channels, out_channels, stride=1, use_dropout=False, leaky=False, norm='INSTANCE'):
        super().__init__()
        self.norm_type = norm
        middle_channels = in_channels // 4
        self.act = nn.LeakyReLU() if leaky else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout2d(p=0.1) if use_dropout else None
        # self.act = nn.ELU() if leaky else nn.ReLU(inplace=True)
        self.conv1 = ConvNorm(in_channels, middle_channels, 1, 1, leaky, norm, True)
        self.conv2 = ConvNorm(middle_channels, middle_channels, 3, stride, leaky, norm, True)
        self.conv3 = ConvNorm(middle_channels, out_channels, 1, 1, leaky, norm, False)
        # prompt embeddings are expected to be 512-dimensional
        self.attn = PromptAttentionModule(out_channels, 512, out_channels // 4)
        need_map = in_channels != out_channels or stride != 1
        self.id = ConvNorm(in_channels, out_channels, 1, stride, leaky, norm, False) if need_map else Identity()

    def forward(self, x, prompt_in):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.attn(out, prompt_in)
        identity = self.id(identity)
        out = out + identity
        if self.norm_type != 'GROUP':
            out = self.act(out)
        if self.dropout:
            out = self.dropout(out)
        return out


class SobelEdge(nn.Module):
    def __init__(self, input_dim, channels, kernel_size=3, stride=1):
        super().__init__()
        conv = getattr(nn, 'Conv%dd' % input_dim)
        # depthwise convolution (groups=channels) with a fixed, non-trainable kernel
        self.filter = conv(channels, channels, kernel_size, stride,
                           padding=(kernel_size - 1) // 2, groups=channels, bias=False)
        # 3x3 Sobel kernel for the vertical (y) gradient; it is 2-D, so
        # kernel_size must be 3, and for input_dim > 2 the expand below simply
        # broadcasts the same 2-D kernel along the extra axis
        sobel = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
        sobel_kernel = torch.tensor(sobel, dtype=torch.float32).unsqueeze(0).expand(
            [channels, 1] + [kernel_size] * input_dim)
        self.filter.weight = nn.Parameter(sobel_kernel, requires_grad=False)

    def forward(self, x):
        with torch.no_grad():
            out = self.filter(x)
        return out
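
# --- Illustrative smoke test (not part of the original module) ---
# A quick sanity pass over a plain residual block and the fixed Sobel filter,
# assuming 2-D inputs; the shapes below are arbitrary examples.
if __name__ == '__main__':
    res = ResBlock(16, 32, stride=2, norm='INSTANCE')
    feat = res(torch.randn(1, 16, 64, 64))
    print(feat.shape)    # torch.Size([1, 32, 32, 32])

    edge = SobelEdge(input_dim=2, channels=1)
    edges = edge(torch.rand(1, 1, 28, 28))
    print(edges.shape)   # torch.Size([1, 1, 28, 28])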