import torch.nn as nn | |
from torchvision import models | |
class ResNet50Classifier(nn.Module):
    """Binary classifier: a torchvision ResNet-50 backbone with a small MLP head.

    The backbone's final ``fc`` layer is replaced by
    BatchNorm1d -> Dropout -> Linear(hidden_dim) -> ReLU -> BatchNorm1d
    -> Dropout -> Linear(1) -> Sigmoid, so the model outputs one probability
    in (0, 1) per sample.

    Notes:
        * The head ends in ``nn.Sigmoid``; pair it with ``nn.BCELoss``, not
          ``nn.BCEWithLogitsLoss`` (which would apply a second sigmoid).
        * The head contains ``nn.BatchNorm1d`` layers, so in training mode a
          batch of size 1 raises an error; train with batches of at least 2.
        * ``train_base=False`` only disables gradients for the backbone. Its
          BatchNorm layers still update their running statistics whenever the
          module is in ``train()`` mode.

    Args:
        train_base: If False (the default), freeze every backbone parameter so
            only the new head is trained. If True, fine-tune the whole network.
        weights: Passed through to ``torchvision.models.resnet50(weights=...)``.
            Defaults to ``"DEFAULT"``, equivalent to
            ``models.ResNet50_Weights.DEFAULT``; pass ``None`` for a randomly
            initialized backbone (no weight download).
        hidden_dim: Width of the hidden layer in the head (default 128).
        dropout: Dropout probability for both dropout layers in the head
            (default 0.5).
    """

    def __init__(self, train_base=False, *, weights="DEFAULT", hidden_dim=128,
                 dropout=0.5):
        super().__init__()
        self.base_model = models.resnet50(weights=weights)

        # Freeze (or unfreeze) the backbone. This runs before the new head is
        # attached below, so the head's freshly created parameters keep their
        # default requires_grad=True and are always trainable.
        for param in self.base_model.parameters():
            param.requires_grad = train_base

        # Width of the pooled feature vector that fed the original fc layer.
        in_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Dropout(dropout),
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-sample probabilities of shape ``(N, 1)`` for the batch ``x``.

        ``x`` is passed unchanged to the ResNet-50 backbone, which expects a
        batch of 3-channel images shaped ``(N, 3, H, W)``.
        """
        return self.base_model(x)