File size: 3,882 Bytes
c17bef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from torchvision import transforms
from transformers import SamModel, SamProcessor
from models.efficientnet_b4 import EfficientNetB4Classifier
from glob import glob

# -------------------------------
# Setup & Paths
# -------------------------------
# Inference device: CUDA when PyTorch reports a usable GPU, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Output folder for the generated overlays; created at import time so later
# saves into it do not fail on a missing directory.
RESULTS_DIR = "results/gradcam_sam"
os.makedirs(RESULTS_DIR, exist_ok=True)

# -------------------------------
# Grad-CAM
# -------------------------------
class GradCAM:
    """Compute Grad-CAM heatmaps for one layer of a model.

    A forward hook and a full backward hook on ``target_layer`` capture the
    layer's output and the gradient arriving at that output. Calling the
    instance runs one forward/backward pass through ``model`` and combines
    the two captures into a heatmap resized to the input's spatial size and
    scaled to the range [0, 1].
    """

    def __init__(self, model, target_layer):
        """Store the model and attach the capturing hooks.

        Args:
            model: Callable module providing ``zero_grad()``.
            target_layer: Module whose output is inspected. ``__call__``
                averages its gradient over dims 2 and 3 and sums over dim 1,
                so the output is expected to be 4-D (N, C, H, W).
        """
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        def forward_hook(module, input, output):
            self.activations = output

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        # Keep the handles so the hooks can be detached via remove_hooks();
        # otherwise they stay on the layer for its whole lifetime.
        self._hook_handles = [
            target_layer.register_forward_hook(forward_hook),
            target_layer.register_full_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        """Detach the hooks registered by ``__init__`` (safe to call twice)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def __call__(self, x):
        """Return the Grad-CAM heatmap for input ``x`` as a NumPy array.

        Args:
            x: Input tensor; ``x.shape[2:]`` is used as the output map size.

        Returns:
            ``numpy.ndarray`` with values in [0, 1]. All size-1 dimensions
            are squeezed away, so a single-image batch yields an (H, W) map.
            A constant (e.g. all-zero) map is returned as all zeros instead
            of NaN.
        """
        self.model.zero_grad()
        out = self.model(x)
        # Every output element is seeded with gradient 1, i.e. the map
        # explains the sum of all model outputs rather than a single class.
        out.backward(torch.ones_like(out))

        # Channel weights = spatial mean of the gradient at the target layer.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.nn.functional.relu(cam)
        cam = torch.nn.functional.interpolate(cam, size=x.shape[2:], mode='bilinear', align_corners=False)
        cam = cam.squeeze().detach().cpu().numpy()

        # Min-max normalise. When the map is constant the span is 0 and the
        # plain division would produce NaN, so return zeros in that case.
        cam_min = cam.min()
        span = cam.max() - cam_min
        if span > 0:
            cam = (cam - cam_min) / span
        else:
            cam = np.zeros_like(cam)
        return cam

# -------------------------------
# Load EfficientNetB4
# -------------------------------
# Build the project classifier and restore its weights from the local
# checkpoint file, mapping stored tensors onto DEVICE.
model = EfficientNetB4Classifier()
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider weights_only=True if the installed PyTorch supports it.
model.load_state_dict(torch.load("efficientnetb4_best8595.pth", map_location=DEVICE))
# Move to the inference device and switch to evaluation mode.
model.to(DEVICE).eval()
# Grad-CAM is taken at the last entry of base_model.features
# (presumably the final feature stage of the backbone; confirm in models/efficientnet_b4.py).
target_layer = model.base_model.features[-1]
# Module-level extractor; its hooks are attached to target_layer on construction.
cam_extractor = GradCAM(model, target_layer)

# -------------------------------
# Load SAM
# -------------------------------
# A single checkpoint id feeds both loaders so model and processor always match.
_SAM_CHECKPOINT = "facebook/sam-vit-base"
sam_processor = SamProcessor.from_pretrained(_SAM_CHECKPOINT)
sam_model = SamModel.from_pretrained(_SAM_CHECKPOINT)
# Place the SAM weights on the same device as the classifier.
sam_model = sam_model.to(DEVICE)

# -------------------------------
# Apply Grad-CAM + SAM
# -------------------------------
def apply_gradcam_sam(image_path, save_path):
    """Overlay a Grad-CAM heatmap and a SAM mask outline on an image and save it.

    Uses the module-level ``cam_extractor``, ``sam_model``, ``sam_processor``
    and ``DEVICE``. The result is written to
    ``<save_path>/overlay_<basename of image_path>`` and the path is printed.

    Args:
        image_path: Path of the image file to process.
        save_path: Directory receiving the output image; created if missing.
    """
    # Convert inside the context manager so the file handle is closed promptly.
    with Image.open(image_path) as img:
        raw = img.convert("RGB")
    raw_rgb = np.array(raw)

    # Classifier preprocessing: fixed 380x380 input and per-channel
    # normalisation (the commonly used ImageNet mean/std values; presumably
    # what the classifier was trained with, confirm against the training code).
    transform = transforms.Compose([
        transforms.Resize((380, 380)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    input_tensor = transform(raw).unsqueeze(0).to(DEVICE)

    # Grad-CAM
    cam = cam_extractor(input_tensor)
    # A degenerate map may contain NaN; sanitise before the uint8 conversion.
    cam = np.clip(np.nan_to_num(cam), 0.0, 1.0)
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    # PIL's size is (width, height), which is the order cv2.resize expects.
    heatmap = cv2.resize(heatmap, raw.size)
    overlay = cv2.addWeighted(raw_rgb, 0.5, heatmap, 0.5, 0)

    # SAM
    sam_inputs = sam_processor(raw, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        sam_outputs = sam_model(**sam_inputs)

    # pred_masks holds low-resolution logits for several candidate masks.
    # Upscale and binarise them to the original image size, then keep the
    # candidate the model scores highest, so the outline matches the overlay.
    full_res_masks = sam_processor.image_processor.post_process_masks(
        sam_outputs.pred_masks.cpu(),
        sam_inputs["original_sizes"].cpu(),
        sam_inputs["reshaped_input_sizes"].cpu(),
    )
    best_idx = int(sam_outputs.iou_scores[0, 0].argmax())
    sam_mask = full_res_masks[0][0, best_idx].numpy().astype(np.uint8)

    # Mask outline and final image
    mask_outline = cv2.Canny(sam_mask * 255, 100, 200)
    mask_outline = cv2.cvtColor(mask_outline, cv2.COLOR_GRAY2RGB)
    final = cv2.addWeighted(overlay, 1, mask_outline, 1, 0)

    # Save output
    os.makedirs(save_path, exist_ok=True)
    fname = os.path.basename(image_path)
    out_path = os.path.join(save_path, f"overlay_{fname}")
    Image.fromarray(final).save(out_path)
    print(f"✅ Saved Grad-CAM+SAM: {out_path}")

# -------------------------------
# Run on 5 Random Images
# -------------------------------
def main(num_samples=5):
    """Run Grad-CAM + SAM on a random sample of dataset images.

    Searches ``dataset/`` recursively for ``.jpg`` files and processes up to
    ``num_samples`` of them, chosen without replacement, writing the results
    to ``RESULTS_DIR``.

    Args:
        num_samples: Maximum number of images to process.

    Returns:
        The number of images processed (0 when no image was found).
    """
    image_paths = glob("dataset/**/*.jpg", recursive=True)
    if not image_paths:
        print("No .jpg images found under dataset/; nothing to do.")
        return 0

    # Sampling without replacement raises ValueError when asked for more
    # items than exist, so clamp the request to what is available.
    sample_size = min(num_samples, len(image_paths))
    selected = np.random.choice(image_paths, size=sample_size, replace=False)

    for input_path in selected:
        # np.random.choice yields numpy string scalars; pass a plain str on.
        apply_gradcam_sam(str(input_path), RESULTS_DIR)
    return len(selected)


if __name__ == "__main__":
    main()