File size: 3,626 Bytes
c17bef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
)

# ✅ Import the classifier from the project's models/hybrid_fusion module
from models.hybrid_fusion import EnhancedHybridFusionClassifier

# --- Dataset ---
class ImageNpyDataset(torch.utils.data.Dataset):
    """Image dataset whose file paths and labels are stored in two ``.npy`` files.

    Sample ``i`` is the image at ``paths[i]`` (loaded as RGB, resized and
    normalized) paired with ``labels[i]``.

    Args:
        paths_file: ``.npy`` file holding the image paths (cast to ``str``).
        labels_file: ``.npy`` file holding one label per path (cast to
            ``float32``).
        img_size: Target size passed to ``transforms.Resize``; a
            ``(height, width)`` pair.

    Raises:
        ValueError: If the path and label arrays differ in length.
    """

    def __init__(self, paths_file, labels_file, img_size=(224, 224)):
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code -- only load files you produced.
        self.image_paths = np.load(paths_file, allow_pickle=True).astype(str)
        self.labels = np.load(labels_file, allow_pickle=True).astype(np.float32)

        # Fail fast: a mismatch would otherwise surface as an IndexError at
        # iteration time, or silently ignore trailing labels.
        if len(self.image_paths) != len(self.labels):
            raise ValueError(
                f"{paths_file} has {len(self.image_paths)} entries but "
                f"{labels_file} has {len(self.labels)}"
            )

        self.transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            # Standard ImageNet channel mean/std; presumably the model's
            # backbone was pretrained on ImageNet -- confirm.
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        """Return the number of samples (one per stored image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the transformed RGB image tensor; ``label`` is a float32
        tensor of shape ``(1,)`` (assuming the labels array is 1-D).
        """
        # The context manager closes the file handle deterministically
        # instead of leaving it to Pillow or the garbage collector.
        with Image.open(self.image_paths[idx]) as img:
            image = self.transform(img.convert("RGB"))
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return image, label

# --- Evaluation ---
def _collect_predictions(model, loader, device, threshold):
    """Run ``model`` over ``loader`` and gather per-sample results.

    Returns:
        ``(y_true, y_pred, y_prob)`` as 1-D numpy arrays: ground-truth labels,
        hard 0/1 predictions (``score > threshold``), and raw model scores.
    """
    y_true, y_pred, y_prob = [], [], []
    with torch.no_grad():
        for x, y in loader:
            # Flatten with reshape(-1) instead of squeeze(): squeeze() turns a
            # final batch of size 1 into a 0-d tensor whose tolist() is a bare
            # float, and list.extend() then raises TypeError.
            # NOTE(review): assumes the model emits one score per sample.
            scores = model(x.to(device)).reshape(-1).cpu()
            y_prob.extend(scores.tolist())
            y_pred.extend((scores > threshold).float().tolist())
            y_true.extend(y.reshape(-1).tolist())
    return np.array(y_true), np.array(y_pred), np.array(y_prob)


def _save_metrics_bar_chart(metrics, path):
    """Save a bar chart of ``metrics`` (name -> value) to ``path``."""
    fig = plt.figure()
    try:
        sns.barplot(x=list(metrics.keys()), y=list(metrics.values()))
        plt.ylim(0, 1)
        plt.title("Enhanced Hybrid Fusion Evaluation Metrics")
        plt.savefig(path, dpi=300, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until closed; close it so repeated
        # evaluate() calls don't accumulate open figures.
        plt.close(fig)


def _save_confusion_matrix(y_true, y_pred, path):
    """Save a 2x2 confusion-matrix heatmap of ``y_true`` vs ``y_pred`` to ``path``."""
    # labels=[0, 1] pins the matrix to 2x2 even if a class never occurs, so it
    # always lines up with the two tick labels.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    # NOTE(review): tick order assumes label 0 = "Fresh", 1 = "Not Fresh".
    class_names = ["Fresh", "Not Fresh"]
    fig = plt.figure()
    try:
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=class_names,
                    yticklabels=class_names)
        plt.xlabel("Predicted")
        plt.ylabel("Actual")
        plt.title("Enhanced Hybrid Fusion Confusion Matrix")
        plt.savefig(path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)


def evaluate(model_path="results_hybrid_fusion/hybrid_fusion_best9799.pth", *,
             test_paths="test_paths.npy", test_labels="test_labels.npy",
             out_dir="results", batch_size=32, threshold=0.5):
    """Evaluate a saved EnhancedHybridFusionClassifier on the test split.

    Prints accuracy, precision, recall, F1 and ROC AUC. Saves a metrics bar
    chart and a confusion-matrix heatmap as PNG files under ``out_dir``.

    Args:
        model_path: Checkpoint holding the model's ``state_dict``.
        test_paths: ``.npy`` file of test image paths.
        test_labels: ``.npy`` file of test labels.
        out_dir: Directory for the PNG outputs (created if missing).
        batch_size: Inference batch size.
        threshold: Scores strictly above this are predicted as class 1.
            NOTE(review): raw scores are thresholded and fed to ROC AUC as-is,
            which assumes the model already outputs probabilities (e.g. via a
            sigmoid) -- confirm.

    Returns:
        Dict mapping metric name to value. AUC is NaN when the test labels
        contain only one class.

    Raises:
        ValueError: If the test set is empty.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(out_dir, exist_ok=True)

    test_ds = ImageNpyDataset(test_paths, test_labels)
    test_loader = DataLoader(test_ds, batch_size=batch_size)

    model = EnhancedHybridFusionClassifier()
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted
    # files (torch >= 1.13 accepts weights_only=True for plain state_dicts).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device).eval()

    y_true, y_pred, y_prob = _collect_predictions(
        model, test_loader, device, threshold)
    if y_true.size == 0:
        raise ValueError(f"test set loaded from {test_paths} is empty")

    # roc_auc_score raises ValueError when y_true holds a single class.
    if np.unique(y_true).size > 1:
        auc = roc_auc_score(y_true, y_prob)
    else:
        auc = float("nan")

    metrics = {
        "Accuracy": accuracy_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred),
        "Recall": recall_score(y_true, y_pred),
        "F1 Score": f1_score(y_true, y_pred),
        "AUC": auc,
    }

    print("\n📊 Enhanced Hybrid Fusion Evaluation:")
    for name, value in metrics.items():
        # Left-pad names to 9 chars so the colons line up.
        print(f"{name:<9}: {value:.4f}")

    _save_metrics_bar_chart(
        metrics,
        os.path.join(out_dir, "eval_metrics_enhanced_hybrid_fusion.png"))
    _save_confusion_matrix(
        y_true, y_pred,
        os.path.join(out_dir, "confusion_matrix_enhanced_hybrid_fusion.png"))

    return metrics

if __name__ == "__main__":
    # Script entry point: runs only when executed directly (not on import),
    # evaluating the default checkpoint with default settings.
    evaluate()