File size: 3,882 Bytes
c17bef1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from torchvision import transforms
from transformers import SamModel, SamProcessor
from models.efficientnet_b4 import EfficientNetB4Classifier
from glob import glob
# -------------------------------
# Setup & Paths
# -------------------------------
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Directory that receives the generated overlays; created at import time
# so later saves do not have to check for it.
RESULTS_DIR = "results/gradcam_sam"
os.makedirs(name=RESULTS_DIR, exist_ok=True)
# -------------------------------
# Grad-CAM
# -------------------------------
class GradCAM:
    """Grad-CAM heat-map extractor bound to one layer of a model.

    A forward hook and a full backward hook are registered on
    ``target_layer``; they record the layer's output and the gradient that
    flows back into that output. Calling the instance runs one forward and
    one backward pass through ``model`` and turns those two recordings into
    a heat-map scaled to the range [0, 1].

    Attributes:
        model: The network that is run on every call.
        target_layer: The module whose output is visualized.
        activations: Output of ``target_layer`` from the latest forward pass
            (detached), or ``None`` before the first call.
        gradients: Gradient w.r.t. that output from the latest backward
            pass, or ``None`` before the first call.
    """

    def __init__(self, model, target_layer):
        """Store the model and attach the recording hooks to ``target_layer``."""
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        def forward_hook(module, input, output):
            # Detach so the stored tensor does not keep the whole autograd
            # graph alive between calls.
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        # Keep the handles so the hooks can be removed again.
        self._hook_handles = [
            target_layer.register_forward_hook(forward_hook),
            target_layer.register_full_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        """Detach the hooks registered by ``__init__`` from the target layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def __call__(self, x, class_idx=None):
        """Compute the heat-map for input batch ``x``.

        Args:
            x: Input tensor; its dimensions from index 2 onward are used as
                the spatial size of the returned map.
            class_idx: Optional column of the model output to explain. This
                requires the model output to be indexable as
                ``out[:, class_idx]``. When ``None`` (the default) the
                gradient of the sum of all outputs is used.

        Returns:
            A NumPy array holding the map, min-max normalized to [0, 1] and
            squeezed of singleton dimensions. A constant map (for example
            all zeros after the ReLU) is returned as all zeros.
        """
        self.model.zero_grad()
        # The backward pass needs a graph even if the caller is inside a
        # ``torch.no_grad()`` block.
        with torch.enable_grad():
            out = self.model(x)
            if class_idx is None:
                out.backward(torch.ones_like(out))
            else:
                out[:, class_idx].sum().backward()

        # Channel weights: spatial average of the gradients.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.nn.functional.relu(cam)
        cam = torch.nn.functional.interpolate(
            cam, size=x.shape[2:], mode="bilinear", align_corners=False
        )
        cam = cam.squeeze().detach().cpu().numpy()

        # Min-max normalization. When the map is constant the range is zero;
        # dividing would fill the result with NaNs, so leave it at zero.
        cam = cam - cam.min()
        value_range = cam.max()
        if value_range > 0:
            cam = cam / value_range
        return cam
# -------------------------------
# Load EfficientNetB4
# -------------------------------
# Build the classifier and restore its weights from the checkpoint file,
# mapping the stored tensors onto DEVICE while loading.
model = EfficientNetB4Classifier()
model.load_state_dict(
    torch.load("efficientnetb4_best8595.pth", map_location=DEVICE)
)
model.to(DEVICE)
model.eval()

# Grad-CAM is attached to the last entry of the backbone's ``features``.
target_layer = model.base_model.features[-1]
cam_extractor = GradCAM(model=model, target_layer=target_layer)
# -------------------------------
# Load SAM
# -------------------------------
# The model and its processor are loaded from the same checkpoint name;
# only the model is moved to DEVICE.
sam_model = SamModel.from_pretrained("facebook/sam-vit-base")
sam_model = sam_model.to(DEVICE)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# -------------------------------
# Apply Grad-CAM + SAM
# -------------------------------
def apply_gradcam_sam(image_path, save_path):
    """Blend a Grad-CAM heat-map and a SAM mask outline onto one image.

    The image is resized to 380x380 and normalized for the classifier, the
    resulting Grad-CAM map is colorized and blended 50/50 with the original
    image, and the outline of the SAM mask is added on top. The result is
    written to ``save_path`` as ``overlay_<original file name>``.

    Args:
        image_path: Path of the image to process.
        save_path: Directory that receives the output; created if missing.

    Returns:
        The path of the written overlay image.
    """
    # Close the file handle as soon as the pixels are converted.
    with Image.open(image_path) as img:
        raw = img.convert("RGB")
    raw_np = np.array(raw)

    transform = transforms.Compose([
        transforms.Resize((380, 380)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_tensor = transform(raw).unsqueeze(0).to(DEVICE)

    # Grad-CAM
    cam = cam_extractor(input_tensor)
    # A NaN/inf entry would turn into an arbitrary byte in the uint8 cast.
    cam = np.nan_to_num(cam, nan=0.0, posinf=1.0, neginf=0.0)
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = cv2.resize(heatmap, raw.size)  # PIL size is (width, height)
    overlay = cv2.addWeighted(raw_np, 0.5, heatmap, 0.5, 0)

    # SAM
    sam_inputs = sam_processor(raw, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        sam_outputs = sam_model(**sam_inputs)

    # ``pred_masks`` holds low-resolution mask logits with one entry per
    # candidate mask. Upscale them to the original image size and binarize
    # them, then keep the candidate with the highest predicted IoU.
    masks = sam_processor.image_processor.post_process_masks(
        sam_outputs.pred_masks.cpu(),
        sam_inputs["original_sizes"].cpu(),
        sam_inputs["reshaped_input_sizes"].cpu(),
    )
    best = int(sam_outputs.iou_scores[0, 0].argmax())
    sam_mask = masks[0][0, best].numpy().astype(np.uint8)

    # Mask outline and final image
    mask_outline = cv2.Canny(sam_mask * 255, 100, 200)
    mask_outline = cv2.cvtColor(mask_outline, cv2.COLOR_GRAY2RGB)
    final = cv2.addWeighted(overlay, 1, mask_outline, 1, 0)

    # Save output
    os.makedirs(save_path, exist_ok=True)
    fname = os.path.basename(image_path)
    out_path = os.path.join(save_path, f"overlay_{fname}")
    Image.fromarray(final).save(out_path)
    print(f"✅ Saved Grad-CAM+SAM: {out_path}")
    return out_path
# -------------------------------
# Run on 5 Random Images
# -------------------------------
if __name__ == "__main__":
    # Process up to five randomly chosen .jpg files found under dataset/.
    image_paths = glob("dataset/**/*.jpg", recursive=True)
    if not image_paths:
        raise SystemExit("No .jpg images found under dataset/")

    # Sampling without replacement fails when more items are requested than
    # exist, so never ask for more than were found.
    sample_size = min(5, len(image_paths))
    selected = np.random.choice(image_paths, size=sample_size, replace=False)
    for input_path in selected:
        # np.random.choice yields NumPy strings; hand on a plain str.
        apply_gradcam_sam(str(input_path), RESULTS_DIR)
|