File size: 1,841 Bytes
c17bef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
from torchvision import models

class EnhancedHybridFusionClassifier(nn.Module):
    """Binary classifier that fuses features from three pretrained backbones.

    EfficientNet-B0, ResNet-50 and MobileNetV2 (torchvision ``DEFAULT``
    weights) each embed the same input batch. Their pooled feature vectors
    are concatenated and passed through an MLP head that ends in a sigmoid,
    so the model emits one value in ``[0, 1]`` per sample.

    Submodule attribute names (``eff``, ``eff_features``, ``eff_pool``,
    ``res``, ``mob``, ``classifier``) are part of the ``state_dict`` layout
    and must not be renamed.

    Args:
        train_base: If ``False`` (default) the three backbones are frozen:
            their parameters do not require gradients and they are kept in
            eval mode even while the rest of the model is in training mode.
            If ``True`` the backbones are fine-tuned together with the head.
    """

    def __init__(self, train_base: bool = False) -> None:
        super().__init__()
        self.train_base = train_base

        # EfficientNetB0: only the convolutional trunk and its pooling layer
        # are used by forward(); aliases are kept for checkpoint compatibility.
        self.eff = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
        self.eff_features = self.eff.features
        self.eff_pool = self.eff.avgpool
        # Read the feature width from the stock head instead of hard-coding
        # it, mirroring how the other two backbones are handled below.
        eff_out = self.eff.classifier[1].in_features
        # The stock ImageNet head is never called by forward(). Freeze it
        # unconditionally so it is not reported as trainable and does not
        # show up as an "unused parameter" under DistributedDataParallel.
        self.eff.classifier.requires_grad_(False)

        # ResNet50: replace the final FC layer so the model returns features.
        self.res = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        res_out = self.res.fc.in_features
        self.res.fc = nn.Identity()

        # MobileNetV2: drop the classifier so the model returns features.
        self.mob = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)
        mob_out = self.mob.classifier[1].in_features
        self.mob.classifier = nn.Identity()

        # Optionally freeze the backbone weights.
        if not train_base:
            self.eff_features.requires_grad_(False)
            self.res.requires_grad_(False)
            self.mob.requires_grad_(False)

        # Final classifier head over the concatenated backbone features.
        total_features = eff_out + res_out + mob_out
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(total_features),
            nn.Dropout(0.5),
            nn.Linear(total_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.4),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        # nn.Module starts in training mode without calling train(); call it
        # once so frozen backbones begin in eval mode as well.
        self.train(True)

    def train(self, mode: bool = True):
        """Set training mode, keeping frozen backbones in eval mode.

        Clearing ``requires_grad`` alone does not stop BatchNorm layers from
        updating their running statistics, nor does it disable stochastic
        regularisation inside the backbones. When the backbones are frozen
        they are therefore forced back to eval mode after the usual switch.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            This module.
        """
        super().train(mode)
        # getattr: instances unpickled from an older version of this class
        # lack the attribute; fall back to the previous behaviour for them.
        if mode and not getattr(self, "train_base", True):
            self.eff.eval()  # also covers the eff_features / eff_pool aliases
            self.res.eval()
            self.mob.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one sigmoid score per sample.

        Args:
            x: Image batch fed unchanged to all three backbones.
                Presumably ``(batch, 3, H, W)`` with ImageNet-style
                normalisation -- confirm against the data pipeline.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
        """
        eff_x = self.eff_pool(self.eff_features(x))
        eff_x = torch.flatten(eff_x, 1)
        res_x = self.res(x)
        mob_x = self.mob(x)
        # Concatenate along the feature axis; order must match the width
        # computation in __init__ (EfficientNet, ResNet, MobileNet).
        fusion = torch.cat([eff_x, res_x, mob_x], dim=1)
        return self.classifier(fusion)