import os
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import SamModel, SamProcessor

from models.efficientnet_b4 import EfficientNetB4Classifier

# -------------------------------
# Setup & Paths
# -------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
RESULTS_DIR = "results/gradcam_sam"
os.makedirs(RESULTS_DIR, exist_ok=True)


# -------------------------------
# Grad-CAM
# -------------------------------
class GradCAM:
    """Compute Grad-CAM heatmaps from one layer of a model.

    Forward and full-backward hooks are registered on ``target_layer`` at
    construction time; they record the layer's output and the gradient that
    flows back into that output. Calling the instance runs a forward and a
    backward pass and combines the two recordings into a heatmap.
    """

    def __init__(self, model, target_layer):
        """Store the model and attach the recording hooks to ``target_layer``."""
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        def forward_hook(module, inputs, output):
            # Only the values are needed later, so drop the autograd graph.
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        target_layer.register_forward_hook(forward_hook)
        target_layer.register_full_backward_hook(backward_hook)

    def __call__(self, x, class_idx=None):
        """Return the Grad-CAM map for ``x`` as a NumPy array scaled to [0, 1].

        Args:
            x: Input batch; its trailing two dimensions give the size the map
                is upsampled to.
            class_idx: Optional class index to explain. Only used when the
                model output is 2-D with more than one column; in that case
                ``None`` selects each sample's highest-scoring class. For any
                other output shape the whole output is back-propagated.

        Returns:
            The heatmap after ``squeeze()``, min-max normalised over the whole
            array. A constant map (for example all zeros after the ReLU) is
            returned as zeros instead of NaNs.

        Raises:
            RuntimeError: If the hooks did not record both activations and
                gradients during this call.
        """
        # Clear recordings from any previous call so a hook that fails to
        # fire is detected rather than silently reusing stale tensors.
        self.gradients = None
        self.activations = None

        self.model.zero_grad()
        out = self.model(x)

        if out.dim() == 2 and out.shape[1] > 1:
            # Grad-CAM explains a single class score; back-propagating the
            # sum of every logit would blend all classes into one map.
            if class_idx is None:
                target = out.argmax(dim=1, keepdim=True)
            else:
                target = torch.full(
                    (out.shape[0], 1),
                    int(class_idx),
                    dtype=torch.long,
                    device=out.device,
                )
            score = out.gather(1, target).sum()
        else:
            # Single-score (or non-standard) output: same gradient as
            # out.backward(torch.ones_like(out)).
            score = out.sum()
        score.backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "Grad-CAM hooks did not fire; make sure the target layer is "
                "used in the forward pass and its output requires grad."
            )

        # Channel weights are the spatially averaged gradients.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.nn.functional.relu(cam)
        cam = torch.nn.functional.interpolate(
            cam, size=x.shape[2:], mode="bilinear", align_corners=False
        )
        cam = cam.squeeze().cpu().numpy()

        cam_min = cam.min()
        span = cam.max() - cam_min
        if span > 0:
            cam = (cam - cam_min) / span
        else:
            # Avoid 0/0 -> NaN when the map is constant.
            cam = np.zeros_like(cam)
        return cam


# -------------------------------
# Load EfficientNetB4
# -------------------------------
model = EfficientNetB4Classifier()
model.load_state_dict(torch.load("efficientnetb4_best8595.pth", map_location=DEVICE))
model.to(DEVICE).eval()
target_layer = model.base_model.features[-1]
cam_extractor = GradCAM(model, target_layer)

# -------------------------------
# Load SAM
# -------------------------------
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(DEVICE)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# -------------------------------
# Apply Grad-CAM + SAM
# -------------------------------
def apply_gradcam_sam(image_path, save_path):
    """Overlay a Grad-CAM heatmap and a SAM mask outline on one image.

    The image is loaded as RGB, passed through the module-level
    ``cam_extractor`` and ``sam_model``, and the blended result is written to
    ``save_path`` as ``overlay_<original file name>``.

    Args:
        image_path: Path of the image to visualise.
        save_path: Directory that receives the output image; it is created
            if it does not exist.
    """
    # Close the file handle promptly; convert() returns an independent copy.
    with Image.open(image_path) as img:
        raw = img.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((380, 380)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_tensor = transform(raw).unsqueeze(0).to(DEVICE)

    # Grad-CAM
    cam = cam_extractor(input_tensor)
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    # PIL's size and cv2.resize's dsize are both (width, height).
    heatmap = cv2.resize(heatmap, raw.size)
    overlay = cv2.addWeighted(np.array(raw), 0.5, heatmap, 0.5, 0)

    # SAM
    sam_inputs = sam_processor(raw, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        sam_outputs = sam_model(**sam_inputs)

    # pred_masks holds several low-resolution mask logits per prompt, not one
    # binary mask at image size. Upscale and binarise them to the original
    # resolution, then keep the candidate with the best predicted IoU so the
    # outline lines up with `overlay`.
    full_res_masks = sam_processor.image_processor.post_process_masks(
        sam_outputs.pred_masks.cpu(),
        sam_inputs["original_sizes"].cpu(),
        sam_inputs["reshaped_input_sizes"].cpu(),
    )
    best_idx = int(sam_outputs.iou_scores[0, 0].argmax().item())
    sam_mask = full_res_masks[0][0, best_idx].numpy()

    # Mask outline and final image
    mask_outline = cv2.Canny(sam_mask.astype(np.uint8) * 255, 100, 200)
    mask_outline = cv2.cvtColor(mask_outline, cv2.COLOR_GRAY2RGB)
    final = cv2.addWeighted(overlay, 1, mask_outline, 1, 0)

    # Save output
    os.makedirs(save_path, exist_ok=True)
    fname = os.path.basename(image_path)
    out_path = os.path.join(save_path, f"overlay_{fname}")
    Image.fromarray(final).save(out_path)
    print(f"✅ Saved Grad-CAM+SAM: {out_path}")


# -------------------------------
# Run on 5 Random Images
# -------------------------------
if __name__ == "__main__":
    image_paths = glob("dataset/**/*.jpg", recursive=True)
    if not image_paths:
        print("No .jpg images found under dataset/; nothing to do.")
    else:
        # Sampling without replacement fails if fewer than 5 images exist.
        sample_size = min(5, len(image_paths))
        selected = np.random.choice(image_paths, size=sample_size, replace=False)
        for input_path in selected:
            # np.random.choice yields numpy strings; pass a plain str on.
            apply_gradcam_sam(str(input_path), RESULTS_DIR)