Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This source code is licensed under the Apache License, Version 2.0 | |
| # found in the LICENSE file in the root directory of this source tree. | |
| import torch | |
| import torch.nn as nn | |
| import torchvision | |
class RNet(nn.Module):
    """Convolutional network mapping a square image to ``n_classes`` scores.

    Six convolutional blocks are applied in sequence; each of the first
    five is followed by a 2x2 max-pool. The feature map is then flattened
    and passed through four fully connected layers. No activation is
    applied after the final linear layer, so the output is unnormalised.

    Args:
        n_channels: Number of channels of the input image.
        n_classes: Size of the output vector.
        n_pix: Height and width of the (square) input image. Must be at
            least ``2 ** 5`` so the five pooling steps leave at least one
            pixel.
        filters: Output channel count of each of the six conv blocks.
        pool: Accepted for interface compatibility but not used; pooling
            is always 2x2 with stride 2.
        kernel_size: Kernel size of the second convolution in each block.
        n_meta: Number of extra per-sample features concatenated to the
            output of the first fully connected layer.

    Raises:
        ValueError: If ``n_pix`` is too small for the five pooling steps.
    """

    # Number of 2x2 max-pool steps applied in ``forward``.
    _N_POOL = 5

    def __init__(
        self,
        n_channels=3,
        n_classes=13,
        n_pix=256,
        filters=(8, 16, 32, 64, 64, 128),
        pool=(2, 2),
        kernel_size=(3, 3),
        n_meta=0,
    ) -> None:
        super().__init__()

        def conv_block(in_filters, out_filters, kernel_size):
            """Build a 1x1 conv -> batch norm -> ReLU -> ``kernel_size`` conv stack."""
            return nn.Sequential(
                # First conv mixes information across channels only (size 1).
                nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding="same"),
                nn.BatchNorm2d(out_filters),
                nn.ReLU(),
                nn.Conv2d(
                    out_filters, out_filters, kernel_size=kernel_size, padding="same"
                ),
            )

        def fc_block(in_features, out_features):
            """Build a linear layer followed by ReLU."""
            return nn.Sequential(
                nn.Linear(in_features=in_features, out_features=out_features),
                nn.ReLU(),
            )

        # Each 2x2/stride-2 max-pool floors the spatial size, and repeated
        # floor-halving equals a single floor division by 2 ** n_pool. Using
        # float division here would over-count the features whenever n_pix
        # is not a multiple of 2 ** n_pool and make fc1 reject its input.
        side = int(n_pix) // 2**self._N_POOL
        if side < 1:
            raise ValueError(
                f"n_pix must be at least {2 ** self._N_POOL} to survive "
                f"{self._N_POOL} pooling steps, got {n_pix}"
            )

        self.n_meta = n_meta

        self.pool = nn.MaxPool2d(2, 2)
        self.input_layer = conv_block(n_channels, filters[0], kernel_size)
        self.conv_block1 = conv_block(filters[0], filters[1], kernel_size)
        self.conv_block2 = conv_block(filters[1], filters[2], kernel_size)
        self.conv_block3 = conv_block(filters[2], filters[3], kernel_size)
        self.conv_block4 = conv_block(filters[3], filters[4], kernel_size)
        self.conv_block5 = conv_block(filters[4], filters[5], kernel_size)

        self.fc1 = fc_block(in_features=filters[5] * side**2, out_features=64)
        self.fc2 = fc_block(in_features=64 + n_meta, out_features=64)
        self.fc3 = fc_block(in_features=64, out_features=32)
        self.fc4 = nn.Linear(in_features=32, out_features=n_classes)

    def forward(self, x, meta=None):
        """Run the network on a batch of images.

        Args:
            x: Image batch; the layer sizes require shape
                ``(batch, n_channels, n_pix, n_pix)``.
            meta: Optional tensor of shape ``(batch, n_meta)`` that is
                concatenated to the output of the first fully connected
                layer. Required when the model was built with
                ``n_meta > 0``.

        Returns:
            Tensor of shape ``(batch, n_classes)`` with the raw outputs of
            the last linear layer.

        Raises:
            ValueError: If the model expects metadata but ``meta`` is None.
        """
        x = self.pool(self.input_layer(x))
        x = self.pool(self.conv_block1(x))
        x = self.pool(self.conv_block2(x))
        x = self.pool(self.conv_block3(x))
        x = self.pool(self.conv_block4(x))
        x = self.conv_block5(x)  # last block is not pooled

        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = self.fc1(x)

        # fc2 is sized for 64 + n_meta inputs, so the metadata joins here.
        if meta is not None:
            x = torch.cat([x, meta], dim=1)
        elif self.n_meta > 0:
            raise ValueError(
                f"model was built with n_meta={self.n_meta} but no meta "
                "tensor was passed to forward()"
            )

        x = self.fc2(x)
        x = self.fc3(x)
        return self.fc4(x)