josedolot commited on
Commit
89564ab
·
1 Parent(s): 05dfc74

Upload backbone.py

Browse files
Files changed (1) hide show
  1. backbone.py +143 -0
backbone.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import timm
4
+
5
+ from hybridnets.model import BiFPN, Regressor, Classifier, BiFPNDecoder
6
+ from utils.utils import Anchors
7
+ from hybridnets.model import SegmentationHead
8
+
9
+ from encoders import get_encoder
10
+
11
+ class HybridNetsBackbone(nn.Module):
12
+ def __init__(self, num_classes=80, compound_coef=0, seg_classes=1, backbone_name=None, **kwargs):
13
+ super(HybridNetsBackbone, self).__init__()
14
+ self.compound_coef = compound_coef
15
+
16
+ self.seg_classes = seg_classes
17
+
18
+ self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6, 7]
19
+ self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384, 384]
20
+ self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8, 8]
21
+ self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
22
+ self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5, 5]
23
+ self.pyramid_levels = [5, 5, 5, 5, 5, 5, 5, 5, 6]
24
+ self.anchor_scale = [1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,]
25
+ self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
26
+ self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
27
+ conv_channel_coef = {
28
+ # the channels of P3/P4/P5.
29
+ 0: [40, 112, 320],
30
+ 1: [40, 112, 320],
31
+ 2: [48, 120, 352],
32
+ 3: [48, 136, 384],
33
+ 4: [56, 160, 448],
34
+ 5: [64, 176, 512],
35
+ 6: [72, 200, 576],
36
+ 7: [72, 200, 576],
37
+ 8: [80, 224, 640],
38
+ }
39
+
40
+ num_anchors = len(self.aspect_ratios) * self.num_scales
41
+
42
+ self.bifpn = nn.Sequential(
43
+ *[BiFPN(self.fpn_num_filters[self.compound_coef],
44
+ conv_channel_coef[compound_coef],
45
+ True if _ == 0 else False,
46
+ attention=True if compound_coef < 6 else False,
47
+ use_p8=compound_coef > 7)
48
+ for _ in range(self.fpn_cell_repeats[compound_coef])])
49
+
50
+ self.num_classes = num_classes
51
+ self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
52
+ num_layers=self.box_class_repeats[self.compound_coef],
53
+ pyramid_levels=self.pyramid_levels[self.compound_coef])
54
+
55
+ '''Modified by Dat Vu'''
56
+ # self.decoder = DecoderModule()
57
+ self.bifpndecoder = BiFPNDecoder(pyramid_channels=self.fpn_num_filters[self.compound_coef])
58
+
59
+ self.segmentation_head = SegmentationHead(
60
+ in_channels=64,
61
+ out_channels=self.seg_classes+1 if self.seg_classes > 1 else self.seg_classes,
62
+ activation='softmax2d' if self.seg_classes > 1 else 'sigmoid',
63
+ kernel_size=1,
64
+ upsampling=4,
65
+ )
66
+
67
+ self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
68
+ num_classes=num_classes,
69
+ num_layers=self.box_class_repeats[self.compound_coef],
70
+ pyramid_levels=self.pyramid_levels[self.compound_coef])
71
+
72
+ self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef],
73
+ pyramid_levels=(torch.arange(self.pyramid_levels[self.compound_coef]) + 3).tolist(),
74
+ **kwargs)
75
+
76
+ if backbone_name:
77
+ # Use timm to create another backbone that you prefer
78
+ # https://github.com/rwightman/pytorch-image-models
79
+ self.encoder = timm.create_model(backbone_name, pretrained=True, features_only=True, out_indices=(2,3,4)) # P3,P4,P5
80
+ else:
81
+ # EfficientNet_Pytorch
82
+ self.encoder = get_encoder(
83
+ 'efficientnet-b' + str(self.backbone_compound_coef[compound_coef]),
84
+ in_channels=3,
85
+ depth=5,
86
+ weights='imagenet',
87
+ )
88
+
89
+ self.initialize_decoder(self.bifpndecoder)
90
+ self.initialize_head(self.segmentation_head)
91
+ self.initialize_decoder(self.bifpn)
92
+
93
+ def freeze_bn(self):
94
+ for m in self.modules():
95
+ if isinstance(m, nn.BatchNorm2d):
96
+ m.eval()
97
+
98
+ def forward(self, inputs):
99
+ max_size = inputs.shape[-1]
100
+
101
+ # p1, p2, p3, p4, p5 = self.backbone_net(inputs)
102
+ p2, p3, p4, p5 = self.encoder(inputs)[-4:] # self.backbone_net(inputs)
103
+
104
+ features = (p3, p4, p5)
105
+
106
+ features = self.bifpn(features)
107
+
108
+ p3,p4,p5,p6,p7 = features
109
+
110
+ outputs = self.bifpndecoder((p2,p3,p4,p5,p6,p7))
111
+
112
+ segmentation = self.segmentation_head(outputs)
113
+
114
+ regression = self.regressor(features)
115
+ classification = self.classifier(features)
116
+ anchors = self.anchors(inputs, inputs.dtype)
117
+
118
+ return features, regression, classification, anchors, segmentation
119
+
120
+ def initialize_decoder(self, module):
121
+ for m in module.modules():
122
+
123
+ if isinstance(m, nn.Conv2d):
124
+ nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
125
+ if m.bias is not None:
126
+ nn.init.constant_(m.bias, 0)
127
+
128
+ elif isinstance(m, nn.BatchNorm2d):
129
+ nn.init.constant_(m.weight, 1)
130
+ nn.init.constant_(m.bias, 0)
131
+
132
+ elif isinstance(m, nn.Linear):
133
+ nn.init.xavier_uniform_(m.weight)
134
+ if m.bias is not None:
135
+ nn.init.constant_(m.bias, 0)
136
+
137
+
138
+ def initialize_head(self, module):
139
+ for m in module.modules():
140
+ if isinstance(m, (nn.Linear, nn.Conv2d)):
141
+ nn.init.xavier_uniform_(m.weight)
142
+ if m.bias is not None:
143
+ nn.init.constant_(m.bias, 0)