import torch
import torch.nn as nn
from torchvision import models


class EfficientNetB4Classifier(nn.Module):
    """Single-output binary classifier on top of a torchvision EfficientNet-B4.

    The EfficientNet-B4 ``features`` trunk and ``avgpool`` produce a pooled
    feature vector. That vector is fed through a BatchNorm/Dropout/Linear MLP
    head ending in ``nn.Sigmoid``.

    The model therefore returns probabilities, not logits. Pair it with
    ``nn.BCELoss``, not ``nn.BCEWithLogitsLoss``, which would apply a second
    sigmoid.

    Note: the head starts with ``nn.BatchNorm1d``, so batches seen in training
    mode must contain more than one sample.
    """

    def __init__(
        self,
        train_base=False,
        *,
        weights=models.EfficientNet_B4_Weights.DEFAULT,
    ):
        """Build the backbone and the classification head.

        Args:
            train_base: If False (default), the backbone's ``features`` trunk
                is frozen. Its parameters get ``requires_grad=False`` and it is
                kept in eval mode (see ``train``). If True, the trunk is
                fine-tuned together with the head.
            weights: Passed to ``torchvision.models.efficientnet_b4``.
                Defaults to the pretrained ``EfficientNet_B4_Weights.DEFAULT``,
                which torchvision may download. Pass ``None`` for random
                initialisation, e.g. offline or in tests.
        """
        super().__init__()
        self.train_base = train_base
        self.base_model = models.efficientnet_b4(weights=weights)
        for param in self.base_model.features.parameters():
            param.requires_grad = train_base
        # torchvision's own ImageNet head (base_model.classifier) is never used
        # in forward(). Keep the module so state_dict keys are unchanged, but
        # freeze it so its parameters are not reported as trainable-but-unused
        # (e.g. by DistributedDataParallel).
        for param in self.base_model.classifier.parameters():
            param.requires_grad = False

        # Width of the pooled feature vector (1792 for EfficientNet-B4). It is
        # read from the backbone rather than hard-coded in two places.
        num_features = self.base_model.classifier[1].in_features
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(num_features),
            nn.Dropout(0.5),
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
        # Modules start in training mode. Apply the frozen-trunk policy now so
        # it holds even if the caller never calls .train() explicitly.
        self.train(self.training)

    def train(self, mode=True):
        """Set training mode, keeping a frozen backbone trunk in eval mode.

        With ``train_base=False`` the trunk's weights are not optimized.
        Left in training mode, however, its BatchNorm layers would still
        overwrite their pretrained running statistics. Its dropout-style
        layers (torchvision's EfficientNet uses stochastic depth) would also
        stay active.

        Forcing ``base_model.features`` into eval mode keeps the frozen trunk
        a fixed, deterministic feature extractor. ``eval()`` routes through
        this method as well.
        """
        super().train(mode)
        # getattr: tolerate whole-model pickles saved before this attribute
        # existed; they keep the old behaviour (no forced eval).
        if not getattr(self, "train_base", True):
            self.base_model.features.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-sample probabilities of shape ``(N, 1)`` in ``(0, 1)``.

        Args:
            x: Image batch. Presumably ``(N, 3, H, W)``, normalized the way
               the chosen ``weights`` expect. TODO: confirm against the data
               pipeline.
        """
        x = self.base_model.features(x)
        x = self.base_model.avgpool(x)
        x = torch.flatten(x, 1)  # pooled maps -> (N, num_features)
        return self.classifier(x)