josedolot commited on
Commit
ca27009
·
1 Parent(s): 34d220d

Upload encoders/senet.py

Browse files
Files changed (1) hide show
  1. encoders/senet.py +174 -0
encoders/senet.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
2
+
3
+ Attributes:
4
+
5
+ _out_channels (list of int): specify number of channels for each encoder feature tensor
6
+ _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
7
+ _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
8
+
9
+ Methods:
10
+
11
+ forward(self, x: torch.Tensor)
12
+ produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
13
+ shape NCHW (features should be sorted in descending order according to spatial resolution, starting
14
+ with resolution same as input `x` tensor).
15
+
16
+ Input: `x` with shape (1, 3, 64, 64)
17
+ Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
18
+ [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
19
+ (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
20
+
21
+ also should support number of features according to specified depth, e.g. if depth = 5,
22
+ number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
23
+ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24
+ """
25
+
26
+ import torch.nn as nn
27
+
28
+ from pretrainedmodels.models.senet import (
29
+ SENet,
30
+ SEBottleneck,
31
+ SEResNetBottleneck,
32
+ SEResNeXtBottleneck,
33
+ pretrained_settings,
34
+ )
35
+ from ._base import EncoderMixin
36
+
37
+
38
class SENetEncoder(SENet, EncoderMixin):
    """SENet backbone adapted as a segmentation encoder.

    Wraps the ``pretrainedmodels`` SENet, removes the classification head,
    and exposes one feature map per downsampling stage via ``forward``.
    """

    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)

        self._out_channels = out_channels
        self._depth = depth
        self._in_channels = 3

        # The classification head is unused for dense prediction; drop it so
        # it neither consumes memory nor appears in the module tree.
        del self.last_linear
        del self.avg_pool

    def get_stages(self):
        """Return the list of stage modules, one per feature resolution.

        Stage 0 is the identity (input resolution).  The last op of
        ``layer0`` (the pooling step) is shifted into the next stage so that
        stage 1 still emits a stride-2 feature map.
        """
        stages = [nn.Identity(), self.layer0[:-1]]
        stages.append(nn.Sequential(self.layer0[-1], self.layer1))
        stages.extend([self.layer2, self.layer3, self.layer4])
        return stages

    def forward(self, x):
        """Run the first ``self._depth + 1`` stages, collecting each output.

        Returns features sorted from highest to lowest spatial resolution.
        """
        features = []
        out = x
        for stage in self.get_stages()[: self._depth + 1]:
            out = stage(out)
            features.append(out)
        return features

    def load_state_dict(self, state_dict, **kwargs):
        # The classifier weights were deleted in __init__; strip them from
        # the checkpoint so strict loading does not raise on unexpected keys.
        for key in ("last_linear.bias", "last_linear.weight"):
            state_dict.pop(key, None)
        super().load_state_dict(state_dict, **kwargs)
73
+
74
+
75
# Parameters identical across every se_resnet / se_resnext variant below;
# spliced into each entry's "params" with ** so the resulting dicts are
# exactly the same as if written out in full.
_SE_RESNET_COMMON = {
    "downsample_kernel_size": 1,
    "downsample_padding": 0,
    "dropout_p": None,
    "inplanes": 64,
    "input_3x3": False,
    "num_classes": 1000,
    "reduction": 16,
}

# Registry of SENet-family encoders: maps a model name to its encoder class,
# pretrained-weight settings, and constructor parameters.
senet_encoders = {
    "senet154": {
        "encoder": SENetEncoder,
        "pretrained_settings": pretrained_settings["senet154"],
        "params": {
            "out_channels": (3, 128, 256, 512, 1024, 2048),
            "block": SEBottleneck,
            "dropout_p": 0.2,
            "groups": 64,
            "layers": [3, 8, 36, 3],
            "num_classes": 1000,
            "reduction": 16,
        },
    },
    "se_resnet50": {
        "encoder": SENetEncoder,
        "pretrained_settings": pretrained_settings["se_resnet50"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": SEResNetBottleneck,
            "layers": [3, 4, 6, 3],
            "groups": 1,
            **_SE_RESNET_COMMON,
        },
    },
    "se_resnet101": {
        "encoder": SENetEncoder,
        "pretrained_settings": pretrained_settings["se_resnet101"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": SEResNetBottleneck,
            "layers": [3, 4, 23, 3],
            "groups": 1,
            **_SE_RESNET_COMMON,
        },
    },
    "se_resnet152": {
        "encoder": SENetEncoder,
        "pretrained_settings": pretrained_settings["se_resnet152"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": SEResNetBottleneck,
            "layers": [3, 8, 36, 3],
            "groups": 1,
            **_SE_RESNET_COMMON,
        },
    },
    "se_resnext50_32x4d": {
        "encoder": SENetEncoder,
        "pretrained_settings": pretrained_settings["se_resnext50_32x4d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": SEResNeXtBottleneck,
            "layers": [3, 4, 6, 3],
            "groups": 32,
            **_SE_RESNET_COMMON,
        },
    },
    "se_resnext101_32x4d": {
        "encoder": SENetEncoder,
        "pretrained_settings": pretrained_settings["se_resnext101_32x4d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": SEResNeXtBottleneck,
            "layers": [3, 4, 23, 3],
            "groups": 32,
            **_SE_RESNET_COMMON,
        },
    },
}