|
import torch |
|
import torch.nn as nn |
|
from torchvision import models |
|
|
|
class EnhancedHybridFusionClassifier(nn.Module):
    """Binary classifier that fuses pooled features from three ImageNet backbones.

    The input batch is run through EfficientNet-B0 (feature extractor +
    average pool), ResNet-50 and MobileNet-V2, each with its ImageNet
    classification head bypassed. The three feature vectors are concatenated
    and passed to an MLP head ending in a single sigmoid unit.

    Args:
        train_base: If False (default), every backbone parameter is frozen
            (``requires_grad = False``) and the backbones are kept in eval
            mode, so only ``self.classifier`` is trained. If True, the
            backbones are fine-tuned together with the head.

    Notes:
        * Input is presumably a ``(B, 3, H, W)`` image batch normalised the
          way the torchvision ImageNet weights expect -- TODO confirm against
          the data pipeline.
        * The head already applies ``nn.Sigmoid``, so outputs are
          probabilities; pair with ``nn.BCELoss`` (not ``BCEWithLogitsLoss``).
        * The head begins with ``nn.BatchNorm1d``, which raises in training
          mode when the batch has a single sample.
    """

    def __init__(self, train_base=False):
        super().__init__()
        # Remembered so train() can keep frozen backbones in eval mode.
        self.train_base = train_base

        # --- EfficientNet-B0: its feature extractor and pooling are called
        # directly in forward(); the model's own classifier is never run.
        self.eff = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
        self.eff_features = self.eff.features
        self.eff_pool = self.eff.avgpool
        # Derive the feature width from the model, like the other backbones,
        # instead of hard-coding it.
        eff_out = self.eff.classifier[1].in_features
        # The EfficientNet classifier is dead weight: freeze it unconditionally
        # so its parameters are never handed to an optimizer (and never trip
        # DDP's unused-parameter check). The module itself is kept so the
        # state_dict keys stay unchanged for existing checkpoints.
        for p in self.eff.classifier.parameters():
            p.requires_grad = False

        # --- ResNet-50: swap the final FC for Identity so it returns features.
        self.res = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        res_out = self.res.fc.in_features
        self.res.fc = nn.Identity()

        # --- MobileNet-V2: swap the classifier for Identity likewise.
        self.mob = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)
        mob_out = self.mob.classifier[1].in_features
        self.mob.classifier = nn.Identity()

        if not train_base:
            # self.eff contains self.eff_features, so this covers it too.
            for backbone in (self.eff, self.res, self.mob):
                for p in backbone.parameters():
                    p.requires_grad = False

        total_features = eff_out + res_out + mob_out
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(total_features),
            nn.Dropout(0.5),
            nn.Linear(total_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.4),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        # A freshly built module is in training mode; apply the
        # frozen-backbone eval policy from train() immediately.
        self.train(self.training)

    def train(self, mode=True):
        """Set training mode, keeping frozen backbones in eval mode.

        With ``train_base=False`` the backbone weights receive no gradients,
        but in training mode their BatchNorm layers would still normalise
        with per-batch statistics and overwrite the pretrained running
        statistics (and any other train-only layers would stay active).
        Forcing the backbones to eval mode makes them fixed feature
        extractors. ``eval()`` routes through here as ``train(False)``.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        # getattr default: instances unpickled from before train_base existed
        # keep the old behaviour (no forced eval).
        if not getattr(self, "train_base", True):
            for backbone in (self.eff, self.res, self.mob):
                backbone.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return probabilities of shape ``(B, 1)`` for the input batch ``x``."""
        # EfficientNet is run piecewise (features -> pool -> flatten) because
        # its own classifier is bypassed rather than replaced.
        eff_x = torch.flatten(self.eff_pool(self.eff_features(x)), 1)
        res_x = self.res(x)
        mob_x = self.mob(x)
        fusion = torch.cat([eff_x, res_x, mob_x], dim=1)
        return self.classifier(fusion)