# Evaluate several trained binary classifiers on the held-out test set and
# save comparison plots plus a metrics summary CSV.
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, roc_curve, auc
)


# --- Dataset class ---
class ImageNpyDataset(torch.utils.data.Dataset):
    """Loads image paths and labels from .npy files; images are read lazily per item."""

    def __init__(self, paths_file, labels_file, img_size=(224, 224)):
        self.image_paths = np.load(paths_file, allow_pickle=True).astype(str)
        self.labels = np.load(labels_file, allow_pickle=True).astype(np.float32)
        self.transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            # ImageNet mean/std normalisation
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        image = self.transform(image)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return image, label


# --- Evaluation ---
def evaluate_model(model, dataloader, device):
    model.eval()
    y_true, y_pred, y_prob = [], [], []
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            # view(-1) keeps a 1-D tensor even when the last batch holds a single sample
            out = model(x).view(-1)
            # Binarise at 0.5 (assumes the model outputs a sigmoid probability per sample)
            pred = (out > 0.5).float()
            y_true.extend(y.view(-1).cpu().numpy().tolist())
            y_pred.extend(pred.cpu().numpy().tolist())
            y_prob.extend(out.cpu().numpy().tolist())
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        "y_true": y_true,
        "y_pred": y_pred,
        "y_prob": y_prob,
    }


if __name__ == "__main__":
    import models.efficientnet_b0 as b0
    import models.efficientnet_b4 as b4
    import models.mobilenetv2 as mv2
    import models.resnet50 as rsn
    from models.hybrid_fusion import EnhancedHybridFusionClassifier

    # Model class paired with its best checkpoint
    model_classes = {
        "EfficientNetB0": (b0.EfficientNetB0Classifier, "results_efficientnet_b0/efficientnet_best9912.pth"),
        "EfficientNetB4": (b4.EfficientNetB4Classifier, "results_efficientnet_b4/efficientnetb4_best9799.pth"),
        "MobileNetV2": (mv2.MobileNetV2Classifier, "results_mobilenetv2/mobilenetv2_best9598.pth"),
        "ResNet50": (rsn.ResNet50Classifier, "results_resnet50/resnet50_best9849.pth"),
        "HybridFusion": (EnhancedHybridFusionClassifier, "results_hybrid_fusion/hybrid_fusion_best9799.pth"),
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    test_ds = ImageNpyDataset("test_paths.npy", "test_labels.npy")
    test_loader = DataLoader(test_ds, batch_size=32)

    results = {}
    for name, (cls, path) in model_classes.items():
        print(f"🔍 Evaluating {name}...")
        model = cls()
        model.load_state_dict(torch.load(path, map_location=device))
        model.to(device)
        results[name] = evaluate_model(model, test_loader, device)

    os.makedirs("results", exist_ok=True)

    # Confusion Matrices
    fig, axes = plt.subplots(3, 2, figsize=(13, 15))
    for ax, (name, m) in zip(axes.flat, results.items()):
        cm = confusion_matrix(m["y_true"], m["y_pred"])
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=["Fresh", "Not Fresh"],
                    yticklabels=["Fresh", "Not Fresh"], ax=ax)
        ax.set_title(f"{name}")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
    # Hide any leftover empty panel (3x2 grid, five models)
    for ax in list(axes.flat)[len(results):]:
        ax.axis("off")
    fig.tight_layout()
    fig.savefig("results/confusion_matrices_comparison.png", dpi=300)

    # ROC Curves
    plt.figure(figsize=(8, 6))
    for name, m in results.items():
        fpr, tpr, _ = roc_curve(m["y_true"], m["y_prob"])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f"{name} (AUC={roc_auc:.3f})")
    plt.plot([0, 1], [0, 1], "k--")
    plt.title("ROC Curve Comparison")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("results/roc_curves_comparison.png", dpi=300)

    # Bar Chart of Metrics
    metrics = ["accuracy", "precision", "recall", "f1"]
    x = np.arange(len(metrics))
    bar_width = 0.15
    plt.figure(figsize=(12, 6))
    for i, (name, m) in enumerate(results.items()):
        scores = [m[k] for k in metrics]
        plt.bar(x + i * bar_width, scores, width=bar_width, label=name)
    # Centre each tick under its group of bars
    plt.xticks(x + bar_width * (len(results) - 1) / 2, [m.title() for m in metrics])
    plt.ylim(0, 1)
    plt.ylabel("Score")
    plt.title("Model Metric Comparison")
    plt.legend()
    plt.grid(axis="y")
    plt.tight_layout()
    plt.savefig("results/metrics_bar_chart.png", dpi=300)

    # Save to CSV (rows = models, columns = metrics)
    df = pd.DataFrame({
        name: [results[name][k] for k in metrics] for name in results
    }, index=[m.title() for m in metrics]).T
    df.to_csv("results/model_metrics_summary.csv")
    print("📄 Saved metrics to results/model_metrics_summary.csv")

    plt.show()