Upload encoders/xception.py
encoders/xception.py · ADDED · +66 -0
@@ -0,0 +1,66 @@
+import re
+import torch.nn as nn
+
+from pretrainedmodels.models.xception import pretrained_settings
+from pretrainedmodels.models.xception import Xception
+
+from ._base import EncoderMixin
+
+
+class XceptionEncoder(Xception, EncoderMixin):
+
+    def __init__(self, out_channels, *args, depth=5, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self._out_channels = out_channels
+        self._depth = depth
+        self._in_channels = 3
+
+        # modify padding to maintain output shape
+        self.conv1.padding = (1, 1)
+        self.conv2.padding = (1, 1)
+
+        del self.fc
+
+    def make_dilated(self, stage_list, dilation_list):
+        raise ValueError("Xception encoder does not support dilated mode "
+                         "due to pooling operation for downsampling!")
+
+    def get_stages(self):
+        return [
+            nn.Identity(),
+            nn.Sequential(self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu),
+            self.block1,
+            self.block2,
+            nn.Sequential(self.block3, self.block4, self.block5, self.block6, self.block7,
+                          self.block8, self.block9, self.block10, self.block11),
+            nn.Sequential(self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4),
+        ]
+
+    def forward(self, x):
+        stages = self.get_stages()
+
+        features = []
+        for i in range(self._depth + 1):
+            x = stages[i](x)
+            features.append(x)
+
+        return features
+
+    def load_state_dict(self, state_dict):
+        # remove linear
+        state_dict.pop('fc.bias', None)
+        state_dict.pop('fc.weight', None)
+
+        super().load_state_dict(state_dict)
+
+
+xception_encoders = {
+    'xception': {
+        'encoder': XceptionEncoder,
+        'pretrained_settings': pretrained_settings['xception'],
+        'params': {
+            'out_channels': (3, 64, 128, 256, 728, 2048),
+        }
+    },
+}
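For reference, a minimal usage sketch for the registry defined above. It is not part of this commit: it assumes the file lives in an `encoders` package next to the `_base` module that provides `EncoderMixin`, and that `torch` and `pretrainedmodels` are installed; the import path and the 256x256 input below are illustrative assumptions.

import torch

from encoders.xception import xception_encoders  # import path is an assumption

# Build the encoder from its registry entry: XceptionEncoder with the
# declared per-stage output channels (3, 64, 128, 256, 728, 2048).
entry = xception_encoders['xception']
encoder = entry['encoder'](**entry['params'])
encoder.eval()

# Optionally restore ImageNet weights. load_state_dict() above drops the
# classifier's fc.weight / fc.bias before delegating to the parent class,
# so the checkpoint loads cleanly into the truncated model.
# url = entry['pretrained_settings']['imagenet']['url']
# encoder.load_state_dict(torch.hub.load_state_dict_from_url(url))

# forward() returns one feature map per stage: depth=5 yields 6 tensors at
# strides 1, 2, 4, 8, 16, 32 relative to the input.
with torch.no_grad():
    features = encoder(torch.randn(1, 3, 256, 256))

for f in features:
    print(tuple(f.shape))
# (1, 3, 256, 256), (1, 64, 128, 128), (1, 128, 64, 64),
# (1, 256, 32, 32), (1, 728, 16, 16), (1, 2048, 8, 8)

The commented-out weight loading reads the checkpoint URL from the `pretrained_settings` dict re-exported from `pretrainedmodels`; whether to load weights is left to the caller, since the Space may supply its own checkpoint.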