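"""Two-stage training script for a binary EfficientNet-B4 classifier.

Stage 1 trains only the classification head with a frozen backbone; stage 2
unfreezes the top feature blocks and fine-tunes them. The best weights from
each stage are checkpointed to disk.
"""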
import sys
import os
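# Make the project root importable so `models` and `utils` resolve when the
# script is run from its own directory.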
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from copy import deepcopy
from models.efficientnet_b4 import EfficientNetB4Classifier
from utils.dataset_loader import ImageNpyDataset

def train_epoch(model, dataloader, criterion, optimizer, device):
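    """Run one training epoch and return the loss averaged over all samples."""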
    model.train()
    total_loss = 0
    for i, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        print(f"\rπŸŒ€ Batch {i+1}/{len(dataloader)}", end="")
    return total_loss / len(dataloader.dataset)

def eval_model(model, dataloader, criterion, device):
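    """Evaluate on a dataloader and return (average loss, accuracy)."""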
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            total_loss += loss.item() * x.size(0)
            preds = (out > 0.5).float()  # threshold predicted probabilities at 0.5
            correct += (preds == y).sum().item()
    acc = correct / len(dataloader.dataset)
    return total_loss / len(dataloader.dataset), acc

def main():
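    """Stage 1: train the classifier head only; stage 2: fine-tune the top backbone blocks."""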
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"πŸ’» Using device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

    train_ds = ImageNpyDataset("train_paths.npy", "train_labels.npy", augment=True)
    val_ds   = ImageNpyDataset("val_paths.npy", "val_labels.npy")

    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
    val_loader   = DataLoader(val_ds, batch_size=16, num_workers=4, pin_memory=True)

    # train_base=False keeps the backbone frozen so only the head is trained.
    model = EfficientNetB4Classifier(train_base=False).to(device)
    criterion = nn.BCELoss()  # assumes the model's final layer applies a sigmoid
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    best_model_wts = deepcopy(model.state_dict())
    best_acc = 0
    patience = 10           # epochs without improvement before early stopping
    lr_patience = 3         # epochs without improvement before halving the LR
    cooldown = 0            # resets after each LR reduction
    epochs_no_improve = 0   # resets only when a new best accuracy is found

    print("πŸ”Έ Training EfficientNetB4 head only...")
    for epoch in range(30):
        print(f"\nπŸ” Starting Epoch {epoch+1}...")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = eval_model(model, val_loader, criterion, device)
        print(f" βœ… Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = deepcopy(model.state_dict())
            torch.save(model.state_dict(), "efficientnetb4_best.pth")
            cooldown = 0
            epochs_no_improve = 0
            print("βœ“ Saved new best model")
        else:
            cooldown += 1
            epochs_no_improve += 1
            if cooldown >= lr_patience:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.5
                print("πŸ” Reduced LR")
                cooldown = 0
            # Early stopping uses its own counter: resetting `cooldown` after an
            # LR reduction would otherwise prevent this check from ever firing.
            if epochs_no_improve >= patience:
                print("⏹️ Early stopping")
                break

    print("\nπŸ”Έ Fine-tuning top EfficientNetB4 layers...")
    model.load_state_dict(best_model_wts)
    # Unfreeze the last five feature blocks of the backbone; earlier layers stay frozen.
    for param in model.base_model.features[-5:].parameters():
        param.requires_grad = True

    # Rebuild the optimizer so it covers the newly unfrozen parameters;
    # passing only trainable parameters keeps frozen weights out of Adam's state.
    optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=1e-4)
    best_finetune_acc = 0.0
    best_finetune_wts = deepcopy(model.state_dict())

    for epoch in range(20):
        print(f"\n[Fine-tune] Epoch {epoch+1}...")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        _, val_acc = eval_model(model, val_loader, criterion, device)
        print(f" πŸ” Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Acc={val_acc:.4f}")

        if val_acc > best_finetune_acc:
            best_finetune_acc = val_acc
            best_finetune_wts = deepcopy(model.state_dict())
            torch.save(best_finetune_wts, "efficientnetb4_best_finetuned.pth")
            print("βœ“ Saved fine-tuned best model")

    model.load_state_dict(best_finetune_wts)
    torch.save(model.state_dict(), "efficientnetb4_final.pth")
    print("βœ… Training complete. Final model saved as efficientnetb4_final.pth")

if __name__ == "__main__":
    main()