import torch
import torch.nn as nn
from torchvision import models


class EnhancedHybridFusionClassifier(nn.Module):
    """Binary classifier fusing pooled features from three torchvision CNNs.

    The same input batch is passed through EfficientNet-B0, ResNet-50 and
    MobileNetV2, each with its ImageNet classification head bypassed. The
    three feature vectors are concatenated along dim 1 and fed to an MLP head
    that ends in ``nn.Sigmoid``, so ``forward`` returns values in (0, 1) with
    shape ``(N, 1)``.

    NOTE(review): because the head already applies a sigmoid, pair it with
    ``nn.BCELoss``; ``nn.BCEWithLogitsLoss`` would apply a second sigmoid.

    Args:
        train_base: If False (default), every backbone parameter is frozen
            AND the backbones are kept in eval mode (see ``train``), so their
            BatchNorm running statistics and stochastic-depth/dropout layers
            do not change while only the fusion head is trained.
        pretrained: If True (default), load torchvision's default ImageNet
            weights for each backbone (this may download files). Pass False
            to build randomly initialised backbones, e.g. for offline use.
    """

    def __init__(self, train_base=False, *, pretrained=True):
        super().__init__()
        self.train_base = train_base

        # EfficientNetB0: features + avgpool are used directly; its own
        # classifier is never called in forward().
        self.eff = models.efficientnet_b0(
            weights=models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
        )
        # Aliases of submodules of self.eff (kept for state_dict compatibility).
        self.eff_features = self.eff.features
        self.eff_pool = self.eff.avgpool
        # Read the width from the model, like the other two backbones
        # (1280 for efficientnet_b0).
        eff_out = self.eff.classifier[1].in_features

        # ResNet50: replacing fc with Identity makes it return pooled features.
        self.res = models.resnet50(
            weights=models.ResNet50_Weights.DEFAULT if pretrained else None
        )
        res_out = self.res.fc.in_features
        self.res.fc = nn.Identity()

        # MobileNetV2: replacing the classifier (dropout + linear) with
        # Identity makes it return pooled features.
        self.mob = models.mobilenet_v2(
            weights=models.MobileNet_V2_Weights.DEFAULT if pretrained else None
        )
        mob_out = self.mob.classifier[1].in_features
        self.mob.classifier = nn.Identity()

        # Optionally freeze. All of self.eff is frozen (not just its
        # features) so its unused classifier does not show up as trainable.
        if not train_base:
            for backbone in (self.eff, self.res, self.mob):
                for p in backbone.parameters():
                    p.requires_grad = False

        # Final classifier head over the concatenated backbone features.
        total_features = eff_out + res_out + mob_out
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(total_features),
            nn.Dropout(0.5),
            nn.Linear(total_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.4),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        # Modules start in training mode; route through train() so frozen
        # backbones are put into eval mode from the start.
        self.train(self.training)

    def train(self, mode=True):
        """Set training mode, keeping frozen backbones in eval mode.

        ``requires_grad = False`` alone does not freeze BatchNorm running
        statistics or disable stochastic depth / dropout. When the backbones
        are frozen (``train_base=False``), they are therefore forced back to
        eval mode after the normal recursive mode switch. ``eval()`` calls
        this method with ``mode=False``.

        Returns:
            self, as ``nn.Module.train`` does.
        """
        super().train(mode)
        if not self.train_base:
            for backbone in (self.eff, self.res, self.mob):
                backbone.eval()
        return self

    def forward(self, x):
        """Return sigmoid scores of shape ``(N, 1)`` for the input batch ``x``.

        ``x`` is fed unchanged to all three backbones. Presumably it is an
        ImageNet-normalised RGB batch ``(N, 3, H, W)`` (TODO: confirm against
        the data pipeline).
        """
        eff_x = self.eff_pool(self.eff_features(x))
        eff_x = torch.flatten(eff_x, 1)
        res_x = self.res(x)
        mob_x = self.mob(x)
        fusion = torch.cat([eff_x, res_x, mob_x], dim=1)
        return self.classifier(fusion)