import torch.nn as nn from torchvision import models class MobileNetV2Classifier(nn.Module): def __init__(self, train_base=False): super().__init__() self.base_model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT) for param in self.base_model.features.parameters(): param.requires_grad = train_base in_features = self.base_model.classifier[1].in_features self.base_model.classifier = nn.Sequential( nn.BatchNorm1d(in_features), nn.Dropout(0.5), nn.Linear(in_features, 128), nn.ReLU(), nn.BatchNorm1d(128), nn.Dropout(0.5), nn.Linear(128, 1), nn.Sigmoid() ) def forward(self, x): return self.base_model(x)