| # models/efficientnet_b0.py | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
class EfficientNetB0Classifier(nn.Module):
    """Binary classifier built on a pretrained torchvision EfficientNet-B0 backbone.

    The backbone's ``features`` and ``avgpool`` stages produce a 1280-dim
    embedding per sample. A small MLP head (BatchNorm/Dropout/Linear/ReLU)
    maps it to a single sigmoid probability.
    """

    def __init__(self, train_base: bool = False):
        """
        Initialize EfficientNetB0-based binary classifier.

        :param train_base: If True, allows fine-tuning the base model. If False,
            the parameters of ``base_model.features`` get ``requires_grad=False``.
            :meth:`train` also keeps the backbone in eval mode, so its BatchNorm
            running statistics are not updated either.
        """
        super().__init__()
        # Loads torchvision's default pretrained weights (may download them on first use).
        self.base_model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
        for param in self.base_model.features.parameters():
            param.requires_grad = train_base
        # NOTE: torchvision's own ``base_model.classifier`` head is never used in
        # forward(). It is kept so existing state_dict keys stay unchanged.
        self.classifier = nn.Sequential(
            # BatchNorm1d layers require batch size > 1 while in training mode.
            nn.BatchNorm1d(1280),
            nn.Dropout(0.5),
            nn.Linear(1280, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),
            nn.Linear(128, 1),
            # Outputs probabilities, so pair with nn.BCELoss (not BCEWithLogitsLoss).
            nn.Sigmoid(),
        )

    def train(self, mode: bool = True) -> "EfficientNetB0Classifier":
        """
        Set training/eval mode, keeping a fully frozen backbone in eval mode.

        ``requires_grad=False`` alone does not freeze BatchNorm running
        statistics (or other train-only behaviour such as stochastic depth)
        inside ``base_model``, because those follow the train/eval flag.

        After the regular ``nn.Module.train`` call, the backbone is put back
        into eval mode when no parameter of ``base_model.features`` requires
        grad. If any of those parameters is unfrozen later (e.g. for a
        fine-tuning stage), the next ``train()`` call leaves the backbone in
        training mode.

        :param mode: True for training mode, False for evaluation mode.
        :return: self, matching ``nn.Module.train``.
        """
        super().train(mode)
        backbone_frozen = not any(
            p.requires_grad for p in self.base_model.features.parameters()
        )
        if mode and backbone_frozen:
            self.base_model.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network.

        :param x: Input image batch. Presumably (N, 3, H, W), normalized like
            the pretrained weights expect. TODO confirm against the data pipeline.
        :return: Tensor of shape (N, 1) holding sigmoid probabilities.
        """
        x = self.base_model.features(x)
        x = self.base_model.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)