# train/train_efficientnet_b4.py
import sys
import os
# Make the project root importable so models/ and utils/ resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from copy import deepcopy
from models.efficientnet_b4 import EfficientNetB4Classifier
from utils.dataset_loader import ImageNpyDataset
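
# A minimal sketch of what this script assumes EfficientNetB4Classifier
# provides (an assumption about models/efficientnet_b4.py, not its actual
# contents): a torchvision EfficientNet-B4 whose backbone can be frozen via
# train_base, exposing base_model.features for selective unfreezing, and a
# single sigmoid output so BCELoss and the 0.5 threshold below apply.
#
#     from torchvision import models
#
#     class EfficientNetB4Classifier(nn.Module):
#         def __init__(self, train_base=False):
#             super().__init__()
#             self.base_model = models.efficientnet_b4(weights="DEFAULT")
#             for p in self.base_model.features.parameters():
#                 p.requires_grad = train_base
#             in_feats = self.base_model.classifier[1].in_features
#             self.base_model.classifier = nn.Sequential(
#                 nn.Linear(in_feats, 1), nn.Sigmoid()
#             )
#
#         def forward(self, x):
#             return self.base_model(x).squeeze(1)
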
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-sample loss."""
    model.train()
    total_loss = 0.0
    for i, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per-sample.
        total_loss += loss.item() * x.size(0)
        print(f"\r🌀 Batch {i+1}/{len(dataloader)}", end="")
    return total_loss / len(dataloader.dataset)

def eval_model(model, dataloader, criterion, device):
    """Evaluate on a dataloader; return (mean per-sample loss, accuracy)."""
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            total_loss += loss.item() * x.size(0)
            # Binarize the sigmoid probabilities at 0.5.
            preds = (out > 0.5).float()
            correct += (preds == y).sum().item()
    acc = correct / len(dataloader.dataset)
    return total_loss / len(dataloader.dataset), acc

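# Note on the 0.5 threshold above: it assumes the model's forward pass ends in
# a sigmoid, matching the nn.BCELoss used in main(). If the model returned raw
# logits instead, the equivalent (and more numerically stable) pairing would be:
#
#     criterion = nn.BCEWithLogitsLoss()  # fuses sigmoid + BCE
#     preds = (out > 0.0).float()         # threshold logits at 0, not 0.5
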
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"💻 Using device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

    train_ds = ImageNpyDataset("train_paths.npy", "train_labels.npy", augment=True)
    val_ds = ImageNpyDataset("val_paths.npy", "val_labels.npy")
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=16, num_workers=4, pin_memory=True)

    # Stage 1: train only the classification head; the backbone stays frozen.
    model = EfficientNetB4Classifier(train_base=False).to(device)
    criterion = nn.BCELoss()  # expects sigmoid probabilities and float targets
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    best_model_wts = deepcopy(model.state_dict())
    best_acc = 0.0
    patience = 10      # epochs without improvement before early stopping
    lr_patience = 3    # epochs without improvement before halving the LR
    epochs_no_improve = 0
    cooldown = 0

    print("🔸 Training EfficientNetB4 head only...")
    for epoch in range(30):
        print(f"\n🔁 Starting Epoch {epoch+1}...")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = eval_model(model, val_loader, criterion, device)
        print(f" ✅ Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = deepcopy(model.state_dict())
            torch.save(model.state_dict(), "efficientnetb4_best.pth")
            epochs_no_improve = 0
            cooldown = 0
            print("✓ Saved new best model")
        else:
            # Two separate counters: `cooldown` drives LR reduction and
            # resets each time the LR is halved, while `epochs_no_improve`
            # keeps counting so early stopping can still fire. A single
            # shared counter would reset at lr_patience and never reach
            # patience, so early stopping would never trigger.
            epochs_no_improve += 1
            cooldown += 1
            if cooldown >= lr_patience:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.5
                print("🔁 Reduced LR")
                cooldown = 0
            if epochs_no_improve >= patience:
                print("⏹️ Early stopping")
                break
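
    # torchvision's EfficientNet-B4 `features` module has nine children (the
    # stem, seven MBConv stages, and the final 1x1 conv), so features[-5:]
    # presumably unfreezes the last four stages plus that head conv; the exact
    # slice depends on how models/efficientnet_b4.py builds the backbone.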
print("\nπŸ”Έ Fine-tuning top EfficientNetB4 layers...")
model.load_state_dict(best_model_wts)
for param in model.base_model.features[-5:].parameters():
param.requires_grad = True
optimizer = optim.Adam(model.parameters(), lr=1e-4)
best_finetune_acc = 0.0
best_finetune_wts = deepcopy(model.state_dict())
for epoch in range(20):
print(f"\n[Fine-tune] Epoch {epoch+1}...")
train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
_, val_acc = eval_model(model, val_loader, criterion, device)
print(f" πŸ” Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Acc={val_acc:.4f}")
if val_acc > best_finetune_acc:
best_finetune_acc = val_acc
best_finetune_wts = deepcopy(model.state_dict())
torch.save(best_finetune_wts, "efficientnetb4_best_finetuned.pth")
print("βœ“ Saved fine-tuned best model")
model.load_state_dict(best_finetune_wts)
torch.save(model.state_dict(), "efficientnetb4_final.pth")
print("βœ… Training complete. Final model saved as efficientnetb4_final.pth")
if __name__ == "__main__":
    main()
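
# Expected usage, assuming the .npy index files (train_paths.npy,
# train_labels.npy, val_paths.npy, val_labels.npy) sit in the working
# directory:
#
#     python train/train_efficientnet_b4.py
#
# Checkpoints written to the working directory:
#     efficientnetb4_best.pth            best head-only weights
#     efficientnetb4_best_finetuned.pth  best fine-tuned weights
#     efficientnetb4_final.pth           final copy of the best fine-tuned weights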