import sys, os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, roc_curve, auc
)

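# Test-set dataset: reads image file paths and binary labels from .npy arrays
# and applies the standard ImageNet preprocessing (resize, tensor, normalize).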
class ImageNpyDataset(torch.utils.data.Dataset):
    def __init__(self, paths_file, labels_file, img_size=(224, 224)):
        self.image_paths = np.load(paths_file, allow_pickle=True).astype(str)
        self.labels = np.load(labels_file, allow_pickle=True).astype(np.float32)
        self.transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        image = self.transform(image)
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return image, label

def evaluate_model(model, dataloader, device):
    model.eval()
    y_true, y_pred, y_prob = [], [], []
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            # Models are assumed to output a sigmoid probability per sample,
            # so 0.5 is used as the decision threshold.
            # view(-1) keeps a 1-D tensor even when the final batch holds a
            # single sample (squeeze would collapse it to 0-dim and break extend).
            out = model(x).view(-1)
            pred = (out > 0.5).float()
            y_true.extend(y.view(-1).cpu().numpy().tolist())
            y_pred.extend(pred.cpu().numpy().tolist())
            y_prob.extend(out.cpu().numpy().tolist())
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        "y_true": y_true,
        "y_pred": y_pred,
        "y_prob": y_prob,
    }

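# Evaluate each trained checkpoint on the held-out test set and produce
# comparison plots (confusion matrices, ROC curves, metric bars) plus a CSV summary.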
if __name__ == "__main__":
    import models.efficientnet_b0 as b0
    import models.efficientnet_b4 as b4
    import models.mobilenetv2 as mv2
    import models.resnet50 as rsn
    from models.hybrid_fusion import EnhancedHybridFusionClassifier

    model_classes = {
        "EfficientNetB0": (b0.EfficientNetB0Classifier, "results_efficientnet_b0/efficientnet_best9912.pth"),
        "EfficientNetB4": (b4.EfficientNetB4Classifier, "results_efficientnet_b4/efficientnetb4_best9799.pth"),
        "MobileNetV2": (mv2.MobileNetV2Classifier, "results_mobilenetv2/mobilenetv2_best9598.pth"),
        "ResNet50": (rsn.ResNet50Classifier, "results_resnet50/resnet50_best9849.pth"),
        "HybridFusion": (EnhancedHybridFusionClassifier, "results_hybrid_fusion/hybrid_fusion_best9799.pth"),
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    test_ds = ImageNpyDataset("test_paths.npy", "test_labels.npy")
    test_loader = DataLoader(test_ds, batch_size=32)

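    # Load each checkpoint and collect its test-set predictions and metrics.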
    results = {}
    for name, (cls, path) in model_classes.items():
        print(f"Evaluating {name}...")
        model = cls()
        model.load_state_dict(torch.load(path, map_location=device))
        model.to(device)
        results[name] = evaluate_model(model, test_loader, device)

    os.makedirs("results", exist_ok=True)

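    # Confusion matrices: one heatmap per model on a shared figure.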
    fig, axes = plt.subplots(3, 2, figsize=(13, 15))
    for ax, (name, m) in zip(axes.flat, results.items()):
        cm = confusion_matrix(m["y_true"], m["y_pred"])
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=["Fresh", "Not Fresh"],
                    yticklabels=["Fresh", "Not Fresh"],
                    ax=ax)
        ax.set_title(name)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
    # Hide any axes left empty when there are fewer models than subplots.
    for ax in axes.flat[len(results):]:
        ax.axis("off")
    fig.tight_layout()
    fig.savefig("results/confusion_matrices_comparison.png", dpi=300)

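    # ROC curves for all models on one set of axes, with per-model AUC.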
    plt.figure(figsize=(8, 6))
    for name, m in results.items():
        fpr, tpr, _ = roc_curve(m["y_true"], m["y_prob"])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f"{name} (AUC={roc_auc:.3f})")
    plt.plot([0, 1], [0, 1], "k--")
    plt.title("ROC Curve Comparison")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("results/roc_curves_comparison.png", dpi=300)

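    # Grouped bar chart of accuracy, precision, recall, and F1 across models.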
    metrics = ["accuracy", "precision", "recall", "f1"]
    x = np.arange(len(metrics))
    bar_width = 0.15

    plt.figure(figsize=(12, 6))
    for i, (name, m) in enumerate(results.items()):
        scores = [m[k] for k in metrics]
        plt.bar(x + i * bar_width, scores, width=bar_width, label=name)

    # Center the tick labels under each group of bars.
    plt.xticks(x + bar_width * (len(results) - 1) / 2, [m.title() for m in metrics])
    plt.ylim(0, 1)
    plt.ylabel("Score")
    plt.title("Model Metric Comparison")
    plt.legend()
    plt.grid(axis="y")
    plt.tight_layout()
    plt.savefig("results/metrics_bar_chart.png", dpi=300)

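    # Per-model metrics table (models as rows, metrics as columns) exported to CSV.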
    df = pd.DataFrame({
        name: [results[name][k] for k in metrics]
        for name in results
    }, index=[m.title() for m in metrics]).T
    df.to_csv("results/model_metrics_summary.csv")
    print("Saved metrics to results/model_metrics_summary.csv")

    plt.show()