# Commit c17bef1 (roqueselopeta): "Initial commit with clean project files".
import torch
import torch.nn as nn
from torchvision import models
class EnhancedHybridFusionClassifier(nn.Module):
    """Single-output classifier fusing pooled features from three backbones.

    The same input batch is run through torchvision's EfficientNet-B0,
    ResNet-50 and MobileNetV2. Their pooled feature vectors are concatenated
    along dim 1 and passed through an MLP head (BatchNorm -> Dropout ->
    Linear -> ReLU stacks) that ends in one sigmoid unit.

    Args:
        train_base: If False (default), all backbone weights are frozen. The
            backbones are also kept in eval mode even while the model is in
            train mode, so their BatchNorm running statistics (and
            EfficientNet's stochastic depth) behave as at inference time. If
            True, the backbones are fine-tuned together with the head.
        pretrained: Keyword-only. Load torchvision's DEFAULT pretrained
            weights for each backbone (default). Pass False to build randomly
            initialised backbones without downloading anything.

    The output of ``forward`` is a probability in (0, 1) with shape
    ``(batch, 1)``. Pair it with ``nn.BCELoss``; ``nn.BCEWithLogitsLoss``
    expects raw logits. The head starts with ``nn.BatchNorm1d``, so in
    training mode every batch must contain more than one sample.
    """

    def __init__(self, train_base=False, *, pretrained=True):
        super().__init__()
        # Read by train() to decide whether the backbones follow the
        # model's train/eval mode.
        self.train_base = train_base

        # --- EfficientNet-B0: forward() uses only `features` and `avgpool`.
        self.eff = models.efficientnet_b0(
            weights=models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
        )
        # These are aliases of submodules of self.eff (the same module
        # objects), so their weights appear under both names in state_dict.
        # Kept unchanged so existing checkpoints still load.
        self.eff_features = self.eff.features
        self.eff_pool = self.eff.avgpool
        # Derive the width from the model (as for the other backbones)
        # instead of hard-coding 1280.
        eff_out = self.eff.classifier[1].in_features
        # EfficientNet's own classification head is never called in forward().
        # It stays registered so state_dict keys are unchanged, but it must
        # never be trainable: otherwise it is a block of parameters that
        # never receive gradients.
        for p in self.eff.classifier.parameters():
            p.requires_grad = False

        # --- ResNet-50: replacing fc with Identity yields the pooled features.
        self.res = models.resnet50(
            weights=models.ResNet50_Weights.DEFAULT if pretrained else None
        )
        res_out = self.res.fc.in_features
        self.res.fc = nn.Identity()

        # --- MobileNetV2: replacing the classifier with Identity yields the
        # pooled features.
        self.mob = models.mobilenet_v2(
            weights=models.MobileNet_V2_Weights.DEFAULT if pretrained else None
        )
        mob_out = self.mob.classifier[1].in_features
        self.mob.classifier = nn.Identity()

        # Optionally freeze every backbone parameter.
        if not train_base:
            for backbone in (self.eff, self.res, self.mob):
                for p in backbone.parameters():
                    p.requires_grad = False

        # --- Fusion head over the concatenated feature vector.
        total_features = eff_out + res_out + mob_out
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(total_features),
            nn.Dropout(0.5),
            nn.Linear(total_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.4),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        # Modules start in train mode. Apply the frozen-backbone policy right
        # away, so a forward pass before any explicit .train() call does not
        # update the frozen BatchNorm statistics.
        self.train()

    def train(self, mode: bool = True) -> "EnhancedHybridFusionClassifier":
        """Set train/eval mode, keeping frozen backbones in eval mode.

        ``requires_grad=False`` alone does not stop BatchNorm layers from
        updating their running statistics in train mode, nor dropout or
        stochastic depth from firing. So when ``train_base`` is False, the
        backbones are switched back to eval mode after the standard recursive
        mode change. ``eval()`` routes through here as ``train(False)``.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            self, as ``nn.Module.train`` does.
        """
        super().train(mode)
        # Instances that lack the attribute (e.g. a fully pickled model from
        # the previous version of this class) keep the old behaviour: the
        # backbones follow the model's mode.
        if not getattr(self, "train_base", True):
            for backbone in (self.eff, self.res, self.mob):
                backbone.eval()
        return self

    def forward(self, x):
        """Return per-sample probabilities of shape ``(batch, 1)``.

        Args:
            x: Image batch fed unchanged to all three backbones. Presumably
                ``(B, 3, H, W)``, normalised as the pretrained backbones
                expect. TODO: confirm against the data pipeline.
        """
        # EfficientNet: conv features -> adaptive avg pool -> flatten to (B, C).
        eff_x = self.eff_pool(self.eff_features(x))
        eff_x = torch.flatten(eff_x, 1)
        # ResNet / MobileNet already pool and flatten internally; their
        # heads are Identity.
        res_x = self.res(x)
        mob_x = self.mob(x)
        fusion = torch.cat([eff_x, res_x, mob_x], dim=1)
        return self.classifier(fusion)