|
import os |
|
import sys |
|
import torch |
|
import numpy as np |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from copy import deepcopy |
|
from torch.utils.data import DataLoader |
|
from torchvision import transforms |
|
from PIL import Image |
|
from tqdm import tqdm |
|
|
|
|
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) |
|
from models.hybrid_fusion import EnhancedHybridFusionClassifier |
|
from utils.dataset_loader import ImageNpyDataset |
|
|
|
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over *dataloader*.

    Puts *model* in train mode, steps *optimizer* once per batch, and
    returns the loss averaged over all samples in the dataset.
    """
    model.train()
    running_loss = 0.0
    progress = tqdm(dataloader, desc="π Training", leave=False)
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        predictions = model(inputs)
        batch_loss = criterion(predictions, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so the final division yields a per-sample mean
        # even when the last batch is smaller.
        running_loss += batch_loss.item() * inputs.size(0)
    return running_loss / len(dataloader.dataset)
|
|
|
def eval_model(model, dataloader, criterion, device):
    """Evaluate *model* on *dataloader* without gradient tracking.

    Returns a ``(mean_loss, accuracy)`` tuple, where predictions are
    obtained by thresholding the model output at 0.5.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, targets).item() * inputs.size(0)
            # Binary decision: output above 0.5 counts as the positive class.
            hard_preds = (outputs > 0.5).float()
            n_correct += (hard_preds == targets).sum().item()
    n_samples = len(dataloader.dataset)
    return loss_sum / n_samples, n_correct / n_samples
|
|
|
def main():
    """Two-stage training pipeline for the hybrid fusion classifier.

    Stage 1 trains only the fusion head (backbone frozen via
    ``train_base=False``) for 30 epochs, checkpointing the best model by
    validation accuracy. Stage 2 unfreezes feature-extractor parameters
    and fine-tunes everything at a lower learning rate for 25 epochs.
    Writes ``hybrid_fusion_best.pth``, ``hybrid_fusion_best_finetuned.pth``
    and ``hybrid_fusion_final.pth`` to the working directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"π» Using device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

    # NOTE(review): .npy path/label files are expected in the working
    # directory — confirm against how ImageNpyDataset resolves paths.
    train_ds = ImageNpyDataset("train_paths.npy", "train_labels.npy", augment=True)
    val_ds = ImageNpyDataset("val_paths.npy", "val_labels.npy")

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=16,
                              pin_memory=True, persistent_workers=True, prefetch_factor=2)
    val_loader = DataLoader(val_ds, batch_size=32, num_workers=16,
                            pin_memory=True, persistent_workers=True, prefetch_factor=2)

    model = EnhancedHybridFusionClassifier(train_base=False).to(device)
    # NOTE(review): BCELoss requires the model to emit probabilities
    # (sigmoid applied inside the model) — confirm in hybrid_fusion.py.
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # mode='max' because we step the scheduler on validation accuracy.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

    best_model_wts = deepcopy(model.state_dict())
    best_acc = 0

    # ---- Stage 1: train the classifier head only ----
    print("\nπΈ Training Enhanced Hybrid Fusion Classifier head only...")
    for epoch in range(30):
        print(f"\nπ Epoch {epoch+1}...")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = eval_model(model, val_loader, criterion, device)
        # Fix: this print statement was previously split across two physical
        # lines (a mangled emoji introduced a line break), which is a syntax
        # error; it is now a single statement.
        print(f" β Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")

        scheduler.step(val_acc)
        torch.cuda.empty_cache()

        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = deepcopy(model.state_dict())
            torch.save(model.state_dict(), "hybrid_fusion_best.pth")
            print("β Saved new best model")

    # ---- Stage 2: fine-tune the top feature-extractor layers ----
    print("\nπ§ Fine-tuning top feature extractor layers...")
    model.load_state_dict(best_model_wts)

    # Unfreeze backbone parameters whose names mark them as feature layers.
    for name, param in model.named_parameters():
        if "features" in name or "layer" in name:
            param.requires_grad = True

    # Fresh optimizer/scheduler at a lower LR so the newly unfrozen
    # parameters do not inherit stale Adam moment estimates.
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)
    best_finetune_acc = 0.0
    best_finetune_wts = deepcopy(model.state_dict())

    for epoch in range(25):
        print(f"\n[Fine-tune] Epoch {epoch+1}...")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        _, val_acc = eval_model(model, val_loader, criterion, device)
        print(f" π Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Acc={val_acc:.4f}")

        scheduler.step(val_acc)
        torch.cuda.empty_cache()

        if val_acc > best_finetune_acc:
            best_finetune_acc = val_acc
            best_finetune_wts = deepcopy(model.state_dict())
            torch.save(best_finetune_wts, "hybrid_fusion_best_finetuned.pth")
            print("β Saved fine-tuned best model")

    # Restore the best fine-tuned weights before writing the final artifact.
    model.load_state_dict(best_finetune_wts)
    torch.save(model.state_dict(), "hybrid_fusion_final.pth")
    # Fix: this print was also split across two physical lines; joined into
    # one statement matching the file's single-line status-print style.
    print("β Final hybrid fusion model saved as hybrid_fusion_final.pth")
|
|
|
# Script entry point: run the two-stage training pipeline.
if __name__ == "__main__":

    main()
|
|