import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project imports: put the repo root on sys.path so the script also runs
# when launched from its own directory.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from models.hybrid_fusion import EnhancedHybridFusionClassifier
from utils.dataset_loader import ImageNpyDataset
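# NOTE: assumptions about the project modules above, inferred from their use here:
#   - ImageNpyDataset yields (image_tensor, float_label) batches built from
#     .npy files of image paths and labels
#   - EnhancedHybridFusionClassifier outputs sigmoid probabilities in [0, 1],
#     which is what the nn.BCELoss criterion below requires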

def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return the mean loss per sample."""
    model.train()
    total_loss = 0.0
    loop = tqdm(dataloader, desc="πŸŒ€ Training", leave=False)
    for x, y in loop:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)  # weight by batch size
        loop.set_postfix(loss=loss.item())     # surface the running loss in the bar
    return total_loss / len(dataloader.dataset)

def eval_model(model, dataloader, criterion, device):
    """Evaluate the model; return (mean loss per sample, accuracy)."""
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            total_loss += loss.item() * x.size(0)
            # The model emits probabilities (see BCELoss in main), so 0.5 is
            # the natural decision threshold for the binary prediction.
            preds = (out > 0.5).float()
            correct += (preds == y).sum().item()
    acc = correct / len(dataloader.dataset)
    return total_loss / len(dataloader.dataset), acc

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"πŸ’» Using device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

    # Datasets with augmentation
    train_ds = ImageNpyDataset("train_paths.npy", "train_labels.npy", augment=True)
    val_ds   = ImageNpyDataset("val_paths.npy", "val_labels.npy")

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=16,
                              pin_memory=True, persistent_workers=True, prefetch_factor=2)
    val_loader   = DataLoader(val_ds, batch_size=32, num_workers=16,
                              pin_memory=True, persistent_workers=True, prefetch_factor=2)

    model = EnhancedHybridFusionClassifier(train_base=False).to(device)
    criterion = nn.BCELoss()
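    # BCELoss expects probabilities in [0, 1]; if the model instead emitted raw
    # logits, nn.BCEWithLogitsLoss would be the numerically stabler pairing.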
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

    best_model_wts = deepcopy(model.state_dict())
    best_acc = 0.0

    print("\nπŸ”Έ Training Enhanced Hybrid Fusion Classifier head only...")
    for epoch in range(30):
        print(f"\nπŸ” Epoch {epoch+1}...")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = eval_model(model, val_loader, criterion, device)
        print(f" βœ… Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")

        scheduler.step(val_acc)
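        # Effectively a no-op without CUDA; on GPU it releases cached blocks between epochs.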
        torch.cuda.empty_cache()

        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = deepcopy(model.state_dict())
            torch.save(model.state_dict(), "hybrid_fusion_best.pth")
            print("βœ“ Saved new best model")

    # Fine-tuning
    print("\nπŸ”§ Fine-tuning top feature extractor layers...")
    model.load_state_dict(best_model_wts)

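    # Heuristic unfreeze: parameter names containing "features" (VGG/DenseNet-style)
    # or "layer" (ResNet-style) are assumed to mark the backbone blocks in this model.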
    for name, param in model.named_parameters():
        if "features" in name or "layer" in name:
            param.requires_grad = True

    # Restart Adam at a lower LR, passing only the parameters that are now trainable.
    optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)
    best_finetune_acc = 0.0
    best_finetune_wts = deepcopy(model.state_dict())

    for epoch in range(25):
        print(f"\n[Fine-tune] Epoch {epoch+1}...")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        _, val_acc = eval_model(model, val_loader, criterion, device)
        print(f" πŸ” Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Acc={val_acc:.4f}")

        scheduler.step(val_acc)
        torch.cuda.empty_cache()

        if val_acc > best_finetune_acc:
            best_finetune_acc = val_acc
            best_finetune_wts = deepcopy(model.state_dict())
            torch.save(best_finetune_wts, "hybrid_fusion_best_finetuned.pth")
            print("βœ“ Saved fine-tuned best model")

    model.load_state_dict(best_finetune_wts)
    torch.save(model.state_dict(), "hybrid_fusion_final.pth")
    print("βœ… Final hybrid fusion model saved as hybrid_fusion_final.pth")

if __name__ == "__main__":
    main()