|
import sys |
|
import os |
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torch.utils.data import DataLoader |
|
from copy import deepcopy |
|
from models.efficientnet_b4 import EfficientNetB4Classifier |
|
from utils.dataset_loader import ImageNpyDataset |
|
|
|
def train_epoch(model, dataloader, criterion, optimizer, device): |
|
model.train() |
|
total_loss = 0 |
|
for i, (x, y) in enumerate(dataloader): |
|
x, y = x.to(device), y.to(device) |
|
optimizer.zero_grad() |
|
out = model(x) |
|
loss = criterion(out, y) |
|
loss.backward() |
|
optimizer.step() |
|
total_loss += loss.item() * x.size(0) |
|
print(f"\rπ Batch {i+1}/{len(dataloader)}", end="") |
|
return total_loss / len(dataloader.dataset) |
|
|
|
def eval_model(model, dataloader, criterion, device): |
|
model.eval() |
|
total_loss = 0 |
|
correct = 0 |
|
with torch.no_grad(): |
|
for x, y in dataloader: |
|
x, y = x.to(device), y.to(device) |
|
out = model(x) |
|
loss = criterion(out, y) |
|
total_loss += loss.item() * x.size(0) |
|
preds = (out > 0.5).float() |
|
correct += (preds == y).sum().item() |
|
acc = correct / len(dataloader.dataset) |
|
return total_loss / len(dataloader.dataset), acc |
|
|
|
def main(): |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
print(f"π» Using device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}") |
|
|
|
train_ds = ImageNpyDataset("train_paths.npy", "train_labels.npy", augment=True) |
|
val_ds = ImageNpyDataset("val_paths.npy", "val_labels.npy") |
|
|
|
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4, pin_memory=True) |
|
val_loader = DataLoader(val_ds, batch_size=16, num_workers=4, pin_memory=True) |
|
|
|
model = EfficientNetB4Classifier(train_base=False).to(device) |
|
criterion = nn.BCELoss() |
|
optimizer = optim.Adam(model.parameters(), lr=1e-4) |
|
|
|
best_model_wts = deepcopy(model.state_dict()) |
|
best_acc = 0 |
|
patience = 10 |
|
lr_patience = 3 |
|
cooldown = 0 |
|
|
|
print("πΈ Training EfficientNetB4 head only...") |
|
for epoch in range(30): |
|
print(f"\nπ Starting Epoch {epoch+1}...") |
|
train_loss = train_epoch(model, train_loader, criterion, optimizer, device) |
|
val_loss, val_acc = eval_model(model, val_loader, criterion, device) |
|
print(f" β
Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}") |
|
|
|
if val_acc > best_acc: |
|
best_acc = val_acc |
|
best_model_wts = deepcopy(model.state_dict()) |
|
torch.save(model.state_dict(), "efficientnetb4_best.pth") |
|
cooldown = 0 |
|
print("β Saved new best model") |
|
else: |
|
cooldown += 1 |
|
if cooldown >= lr_patience: |
|
for param_group in optimizer.param_groups: |
|
param_group['lr'] *= 0.5 |
|
print("π Reduced LR") |
|
cooldown = 0 |
|
if cooldown >= patience: |
|
print("βΉοΈ Early stopping") |
|
break |
|
|
|
print("\nπΈ Fine-tuning top EfficientNetB4 layers...") |
|
model.load_state_dict(best_model_wts) |
|
for param in model.base_model.features[-5:].parameters(): |
|
param.requires_grad = True |
|
|
|
optimizer = optim.Adam(model.parameters(), lr=1e-4) |
|
best_finetune_acc = 0.0 |
|
best_finetune_wts = deepcopy(model.state_dict()) |
|
|
|
for epoch in range(20): |
|
print(f"\n[Fine-tune] Epoch {epoch+1}...") |
|
train_loss = train_epoch(model, train_loader, criterion, optimizer, device) |
|
_, val_acc = eval_model(model, val_loader, criterion, device) |
|
print(f" π Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Acc={val_acc:.4f}") |
|
|
|
if val_acc > best_finetune_acc: |
|
best_finetune_acc = val_acc |
|
best_finetune_wts = deepcopy(model.state_dict()) |
|
torch.save(best_finetune_wts, "efficientnetb4_best_finetuned.pth") |
|
print("β Saved fine-tuned best model") |
|
|
|
model.load_state_dict(best_finetune_wts) |
|
torch.save(model.state_dict(), "efficientnetb4_final.pth") |
|
print("β
Training complete. Final model saved as efficientnetb4_final.pth") |
|
|
|
if __name__ == "__main__": |
|
main() |
|
|