|
import os |
|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import cv2 |
|
from PIL import Image |
|
from torchvision import transforms |
|
from transformers import SamModel, SamProcessor |
|
from models.efficientnet_b4 import EfficientNetB4Classifier |
|
from glob import glob |
|
|
|
|
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Directory that receives the generated overlay images; created at import
# time so later saves do not fail on a missing folder.
RESULTS_DIR = "results/gradcam_sam"
os.makedirs(RESULTS_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
class GradCAM:
    """Grad-CAM heatmap extractor bound to one layer of a model.

    A forward hook on ``target_layer`` records the layer's output and a full
    backward hook records the gradient with respect to that output. Calling
    the instance runs a forward and backward pass and combines the two into
    a heatmap scaled to the range [0, 1].
    """

    def __init__(self, model, target_layer):
        """Register the capture hooks.

        Args:
            model: Callable module whose output is back-propagated.
            target_layer: Module inside ``model`` whose output feeds the map.
        """
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        def forward_hook(module, inputs, output):
            # Only the values are needed later; detaching avoids keeping the
            # autograd graph alive between calls.
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        # Keep the handles so the hooks can be removed again.
        self._hook_handles = [
            target_layer.register_forward_hook(forward_hook),
            target_layer.register_full_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        """Detach the hooks that ``__init__`` registered on the target layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def __call__(self, x, class_idx=None):
        """Compute the Grad-CAM heatmap for ``x``.

        Args:
            x: Input batch; the heatmap is resized to ``x.shape[2:]``.
            class_idx: Optional index into the last dimension of the model
                output. When given, only that output is back-propagated.
                When ``None`` (default) every output element is
                back-propagated with weight 1.

        Returns:
            A NumPy array holding the heatmap, with singleton dimensions
            squeezed out and values in [0, 1]. A constant map is returned as
            all zeros.

        Raises:
            RuntimeError: If the hooks captured nothing during the pass.
        """
        # Clear captures from any previous call so stale tensors are never
        # mixed with the current input.
        self.gradients = None
        self.activations = None

        self.model.zero_grad()
        out = self.model(x)

        if class_idx is None:
            seed = torch.ones_like(out)
        else:
            seed = torch.zeros_like(out)
            seed[..., class_idx] = 1.0
        out.backward(seed)

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not fire; is target_layer part of the model's forward pass?"
            )

        # Channel weights are the spatially averaged gradients.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.nn.functional.relu(cam)
        cam = torch.nn.functional.interpolate(
            cam, size=x.shape[2:], mode="bilinear", align_corners=False
        )
        cam = cam.squeeze().detach().cpu().numpy()

        low = cam.min()
        span = cam.max() - low
        if span > 0:
            return (cam - low) / span
        # A constant map (for example all zeros after the ReLU) would divide
        # by zero and yield NaN; return a flat zero map instead.
        return cam - low
|
|
|
|
|
|
|
|
|
# Build the classifier and restore its trained weights onto DEVICE.
# NOTE(review): torch.load unpickles the checkpoint file, so it must only be
# pointed at trusted files.
model = EfficientNetB4Classifier()
model.load_state_dict(
    torch.load("efficientnetb4_best8595.pth", map_location=DEVICE)
)
model.to(DEVICE)
model.eval()

# Grad-CAM is attached to the last entry of the backbone's feature stack.
target_layer = model.base_model.features[-1]
cam_extractor = GradCAM(model=model, target_layer=target_layer)
|
|
|
|
|
|
|
|
|
# Load the SAM model (moved to DEVICE) and its matching processor from the
# same pretrained checkpoint.
sam_model = SamModel.from_pretrained("facebook/sam-vit-base")
sam_model = sam_model.to(DEVICE)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
|
|
|
|
|
|
|
|
|
def apply_gradcam_sam(image_path, save_path):
    """Overlay a Grad-CAM heatmap and a SAM mask outline on an image and save it.

    The image is classified through ``cam_extractor`` to get a heatmap, which
    is blended 50/50 with the original picture. SAM is then run on the same
    picture and the outline of its highest-scoring mask is drawn on top. The
    result is written to ``save_path`` as ``overlay_<original file name>``.

    Args:
        image_path: Path of the image to process.
        save_path: Directory that receives the output; created if missing.
    """
    # Convert inside the context manager so the file handle is released.
    with Image.open(image_path) as img:
        raw = img.convert("RGB")
    raw_rgb = np.array(raw)

    transform = transforms.Compose([
        transforms.Resize((380, 380)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_tensor = transform(raw).unsqueeze(0).to(DEVICE)

    # --- Grad-CAM heatmap ---------------------------------------------
    cam = cam_extractor(input_tensor)
    # Guard against non-finite values before the uint8 cast below.
    cam = np.clip(np.nan_to_num(cam, nan=0.0, posinf=1.0, neginf=0.0), 0.0, 1.0)
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    # PIL's size is (width, height), which is the order cv2.resize expects.
    heatmap = cv2.resize(heatmap, raw.size)
    overlay = cv2.addWeighted(raw_rgb, 0.5, heatmap, 0.5, 0)

    # --- SAM segmentation ---------------------------------------------
    # NOTE(review): SAM is run without any point/box prompt here; consider
    # prompting it (e.g. with the Grad-CAM peak) if the masks look arbitrary.
    sam_inputs = sam_processor(raw, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        sam_outputs = sam_model(**sam_inputs)

    # pred_masks holds several low-resolution mask logits per image.
    # post_process_masks upsamples them to the original image size and
    # binarises them, so the outline lines up with `overlay`.
    masks = sam_processor.image_processor.post_process_masks(
        sam_outputs.pred_masks.cpu(),
        sam_inputs["original_sizes"].cpu(),
        sam_inputs["reshaped_input_sizes"].cpu(),
    )
    # Keep the candidate mask SAM itself scores highest.
    best = int(sam_outputs.iou_scores[0, 0].argmax())
    sam_mask = masks[0][0, best].numpy().astype(np.uint8) * 255

    # --- Outline and composite ----------------------------------------
    mask_outline = cv2.Canny(sam_mask, 100, 200)
    mask_outline = cv2.cvtColor(mask_outline, cv2.COLOR_GRAY2RGB)
    final = cv2.addWeighted(overlay, 1, mask_outline, 1, 0)

    # --- Save -----------------------------------------------------------
    os.makedirs(save_path, exist_ok=True)
    fname = os.path.basename(image_path)
    out_path = os.path.join(save_path, f"overlay_{fname}")
    Image.fromarray(final).save(out_path)
    print(f"✅ Saved Grad-CAM+SAM: {out_path}")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Process a random sample of up to five JPEGs found under dataset/.
    image_paths = glob("dataset/**/*.jpg", recursive=True)
    if not image_paths:
        raise SystemExit("No .jpg images found under dataset/")

    # np.random.choice with replace=False raises ValueError when asked for
    # more items than exist, so cap the sample at the number of images.
    sample_size = min(5, len(image_paths))
    selected = np.random.choice(image_paths, size=sample_size, replace=False)

    for input_path in selected:
        apply_gradcam_sam(input_path, RESULTS_DIR)
|
|