josedolot commited on
Commit
309a856
·
1 Parent(s): 4b573e4

Upload encoders/xception.py

Browse files
Files changed (1) hide show
  1. encoders/xception.py +66 -0
encoders/xception.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch.nn as nn
3
+
4
+ from pretrainedmodels.models.xception import pretrained_settings
5
+ from pretrainedmodels.models.xception import Xception
6
+
7
+ from ._base import EncoderMixin
8
+
9
+
10
+ class XceptionEncoder(Xception, EncoderMixin):
11
+
12
+ def __init__(self, out_channels, *args, depth=5, **kwargs):
13
+ super().__init__(*args, **kwargs)
14
+
15
+ self._out_channels = out_channels
16
+ self._depth = depth
17
+ self._in_channels = 3
18
+
19
+ # modify padding to maintain output shape
20
+ self.conv1.padding = (1, 1)
21
+ self.conv2.padding = (1, 1)
22
+
23
+ del self.fc
24
+
25
+ def make_dilated(self, stage_list, dilation_list):
26
+ raise ValueError("Xception encoder does not support dilated mode "
27
+ "due to pooling operation for downsampling!")
28
+
29
+ def get_stages(self):
30
+ return [
31
+ nn.Identity(),
32
+ nn.Sequential(self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu),
33
+ self.block1,
34
+ self.block2,
35
+ nn.Sequential(self.block3, self.block4, self.block5, self.block6, self.block7,
36
+ self.block8, self.block9, self.block10, self.block11),
37
+ nn.Sequential(self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4),
38
+ ]
39
+
40
+ def forward(self, x):
41
+ stages = self.get_stages()
42
+
43
+ features = []
44
+ for i in range(self._depth + 1):
45
+ x = stages[i](x)
46
+ features.append(x)
47
+
48
+ return features
49
+
50
+ def load_state_dict(self, state_dict):
51
+ # remove linear
52
+ state_dict.pop('fc.bias', None)
53
+ state_dict.pop('fc.weight', None)
54
+
55
+ super().load_state_dict(state_dict)
56
+
57
+
58
+ xception_encoders = {
59
+ 'xception': {
60
+ 'encoder': XceptionEncoder,
61
+ 'pretrained_settings': pretrained_settings['xception'],
62
+ 'params': {
63
+ 'out_channels': (3, 64, 128, 256, 728, 2048),
64
+ }
65
+ },
66
+ }