from typing import Optional

import torch
import torch.nn as nn
from torchvision import models


class ResNet50Classifier(nn.Module):
    """Binary classifier built on a torchvision ResNet-50 backbone.

    The backbone's final ``fc`` layer is replaced by a small MLP head
    (BatchNorm -> Dropout -> Linear -> ReLU -> BatchNorm -> Dropout ->
    Linear -> Sigmoid) that emits a single probability per sample.

    Args:
        train_base: If False (default), every backbone parameter gets
            ``requires_grad=False``; the new head is always trainable.
        weights: Weights passed to ``torchvision.models.resnet50``.
            Defaults to ``ResNet50_Weights.DEFAULT`` (the original
            hard-coded choice, which may download weights on first use);
            pass ``None`` for a randomly initialised backbone.
        dropout: Dropout probability used by both dropout layers in the head.
        hidden_dim: Width of the hidden layer in the head.

    Notes:
        * The head ends in ``nn.Sigmoid``, so outputs are probabilities in
          (0, 1): pair it with ``nn.BCELoss``, not ``nn.BCEWithLogitsLoss``
          (which would apply a second sigmoid).
        * Freezing via ``requires_grad`` does not stop the backbone's
          BatchNorm layers from updating their running statistics while the
          module is in ``train()`` mode.
        * The head's ``BatchNorm1d`` layers require more than one sample per
          batch in training mode.
    """

    def __init__(
        self,
        train_base: bool = False,
        *,
        weights: Optional[models.ResNet50_Weights] = models.ResNet50_Weights.DEFAULT,
        dropout: float = 0.5,
        hidden_dim: int = 128,
    ) -> None:
        super().__init__()
        self.base_model = models.resnet50(weights=weights)

        # Freeze (or unfreeze) the backbone *before* swapping in the new
        # head, so the head's freshly created parameters stay trainable.
        for param in self.base_model.parameters():
            param.requires_grad = train_base

        in_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Dropout(dropout),
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and head; return per-sample probabilities.

        The final ``Linear(hidden_dim, 1)`` + ``Sigmoid`` gives an output of
        shape ``(N, 1)``. ``x`` is presumably an image batch of shape
        ``(N, 3, H, W)``, normalised as the chosen ``weights`` expect —
        confirm against the data pipeline.
        """
        return self.base_model(x)