# fish-freshness-classifier / train / train_hybrid_fusion.py
# (Hugging Face viewer residue, kept for provenance: roqueselopeta,
#  "Initial commit with clean project files", commit c17bef1)
import os
import sys
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
# Project imports
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from models.hybrid_fusion import EnhancedHybridFusionClassifier
from utils.dataset_loader import ImageNpyDataset
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run a single optimization pass over *dataloader*.

    Puts *model* in train mode, performs the usual zero-grad / forward /
    backward / step cycle per batch, and returns the dataset-wide mean
    per-sample loss (batch losses are re-weighted by batch size before
    averaging, so uneven final batches are handled correctly).
    """
    model.train()
    running_loss = 0.0
    progress = tqdm(dataloader, desc="πŸŒ€ Training", leave=False)
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        predictions = model(inputs)
        batch_loss = criterion(predictions, targets)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so the final average is truly per-sample.
        running_loss += batch_loss.item() * inputs.size(0)
    return running_loss / len(dataloader.dataset)
def eval_model(model, dataloader, criterion, device):
    """Evaluate *model* on *dataloader* without gradient tracking.

    Returns a ``(mean_loss, accuracy)`` tuple, both averaged over the
    whole dataset. Predictions are obtained by thresholding the model's
    (presumably sigmoid) outputs at 0.5 — binary classification.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # Weight by batch size so uneven final batches average correctly.
            loss_sum += criterion(outputs, targets).item() * inputs.size(0)
            hard_preds = outputs.gt(0.5).float()
            n_correct += (hard_preds == targets).sum().item()
    n_samples = len(dataloader.dataset)
    return loss_sum / n_samples, n_correct / n_samples
def main():
    """Two-stage training driver for the hybrid fusion classifier.

    Stage 1 trains only the fusion head (backbone frozen via
    ``train_base=False``); stage 2 unfreezes feature-extractor layers and
    fine-tunes at a 10x lower learning rate. The best checkpoint of each
    stage is written to disk, and the overall best fine-tuned weights are
    saved as the final model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"πŸ’» Using device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
    # Datasets with augmentation
    # NOTE(review): .npy path/label files are expected in the working
    # directory — confirm against how this script is launched.
    train_ds = ImageNpyDataset("train_paths.npy", "train_labels.npy", augment=True)
    val_ds = ImageNpyDataset("val_paths.npy", "val_labels.npy")
    # num_workers=16 is hard-coded; assumes a machine with many cores.
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=16,
                              pin_memory=True, persistent_workers=True, prefetch_factor=2)
    val_loader = DataLoader(val_ds, batch_size=32, num_workers=16,
                            pin_memory=True, persistent_workers=True, prefetch_factor=2)
    # train_base=False: backbone parameters start frozen; only the head learns.
    model = EnhancedHybridFusionClassifier(train_base=False).to(device)
    # BCELoss expects probabilities in [0, 1] — presumably the model ends in
    # a sigmoid; verify against models/hybrid_fusion.py.
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # mode='max' because the scheduler is stepped on validation *accuracy*.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)
    best_model_wts = deepcopy(model.state_dict())
    best_acc = 0
    print("\nπŸ”Έ Training Enhanced Hybrid Fusion Classifier head only...")
    for epoch in range(30):
        print(f"\nπŸ” Epoch {epoch+1}...")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = eval_model(model, val_loader, criterion, device)
        print(f" βœ… Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")
        scheduler.step(val_acc)
        torch.cuda.empty_cache()
        # Checkpoint on best validation accuracy (kept in memory and on disk).
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = deepcopy(model.state_dict())
            torch.save(model.state_dict(), "hybrid_fusion_best.pth")
            print("βœ“ Saved new best model")
    # Fine-tuning
    print("\nπŸ”§ Fine-tuning top feature extractor layers...")
    # Restart stage 2 from the best head-only weights, not the last epoch's.
    model.load_state_dict(best_model_wts)
    # Unfreeze backbone parameters by name; "features"/"layer" presumably
    # match the torchvision backbone naming — confirm against the model.
    for name, param in model.named_parameters():
        if "features" in name or "layer" in name:
            param.requires_grad = True
    # Fresh optimizer/scheduler at a 10x lower LR for fine-tuning.
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)
    best_finetune_acc = 0.0
    best_finetune_wts = deepcopy(model.state_dict())
    for epoch in range(25):
        print(f"\n[Fine-tune] Epoch {epoch+1}...")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        _, val_acc = eval_model(model, val_loader, criterion, device)
        print(f" πŸ” Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Acc={val_acc:.4f}")
        scheduler.step(val_acc)
        torch.cuda.empty_cache()
        if val_acc > best_finetune_acc:
            best_finetune_acc = val_acc
            best_finetune_wts = deepcopy(model.state_dict())
            torch.save(best_finetune_wts, "hybrid_fusion_best_finetuned.pth")
            print("βœ“ Saved fine-tuned best model")
    # Restore the best fine-tuned weights before the final export.
    model.load_state_dict(best_finetune_wts)
    torch.save(model.state_dict(), "hybrid_fusion_final.pth")
    print("βœ… Final hybrid fusion model saved as hybrid_fusion_final.pth")
# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()