# Source listing header: commit c17bef1, file size 7,003 bytes.
import os, sys
import torch
import numpy as np
import cv2
from torchvision import transforms
from PIL import Image
# Add model import path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from models.efficientnet_b0 import EfficientNetB0Classifier
class GuidedBackprop:
    """Guided backpropagation saliency maps for a PyTorch model.

    A full backward hook is registered on every ``torch.nn.ReLU`` module in
    ``model``. While :meth:`generate_gradients` runs, each hook keeps only the
    positive part of the gradient a ReLU passes back to its input. Outside that
    method the hooks leave gradients untouched, so other backward passes through
    the same model (for example a Grad-CAM pass) see ordinary gradients.

    Note:
        Only ``torch.nn.ReLU`` *modules* are hooked. Activations implemented as
        other module types or as functional calls are back-propagated normally.
    """

    def __init__(self, model):
        self.model = model.eval()
        self.gradients = None  # input gradient from the latest generate_gradients call
        # Kept for backward compatibility. Nothing is recorded here any more
        # because the guided rule does not need forward outputs (see backward_hook).
        self.forward_relu_outputs = []
        self.relu_targets = []  # hooked ReLU modules, in model.modules() order
        self.hooks = []  # hook handles, released by remove_hooks()
        self._active = False  # True only while generate_gradients is running
        for module in self.model.modules():
            if isinstance(module, torch.nn.ReLU):
                self.hooks.append(module.register_full_backward_hook(self.backward_hook))
                self.relu_targets.append(module)

    def forward_hook(self, module, input, output):
        """Append ``output`` to ``forward_relu_outputs``.

        ``__init__`` no longer registers this method. It is retained so that
        external code referencing it keeps working.
        """
        self.forward_relu_outputs.append(output)

    def backward_hook(self, module, grad_in, grad_out):
        """Apply the guided-backprop rule to a ReLU's input gradient.

        ``grad_in[0]`` is ReLU's own gradient: ``grad_out[0]`` already zeroed
        wherever the forward activation was not positive. Keeping only its
        positive part therefore equals the "positive forward output AND positive
        gradient" mask. It does this without pairing the call with a recorded
        forward output, a pairing that goes wrong when one ReLU module is applied
        several times per forward pass.

        Returns:
            ``(clamped_grad,)`` while :meth:`generate_gradients` is running.
            Otherwise ``None``, which leaves the gradient unchanged.
        """
        if not self._active or grad_in[0] is None:
            return None
        return (torch.clamp(grad_in[0], min=0.0),)

    @staticmethod
    def _select_score(model_output, target_class):
        """Pick the scalar score to back-propagate from the first sample's logits."""
        if model_output.dim() == 0:
            logits = model_output.reshape(1)
        else:
            logits = model_output.reshape(model_output.shape[0], -1)[0]
        if logits.numel() == 1:
            # A single logit (e.g. a binary classifier) is explained directly,
            # matching the previous ``squeeze()`` behavior.
            return logits[0]
        index = int(logits.argmax()) if target_class is None else int(target_class)
        return logits[index]

    def generate_gradients(self, input_image, target_class=None):
        """Return the guided-backprop gradient of the selected score w.r.t. the input.

        Args:
            input_image: Batched input tensor. Only the first sample's score is
                back-propagated and returned. ``requires_grad`` is enabled in place.
            target_class: Logit index to explain for multi-logit outputs. ``None``
                selects the arg-max. Ignored for single-logit outputs.

        Returns:
            A numpy array with the first sample's input gradient, i.e. the input's
            shape without the batch axis.
        """
        self.forward_relu_outputs = []
        input_image.requires_grad_(True)
        # .backward() accumulates into .grad. Clear it so repeated calls on the
        # same tensor do not sum their results.
        input_image.grad = None
        self._active = True
        try:
            model_output = self.model(input_image)
            score = self._select_score(model_output, target_class)
            self.model.zero_grad()
            score.backward()
        finally:
            # Return the hooks to pass-through mode even if the pass fails.
            self._active = False
        self.gradients = input_image.grad.detach().cpu().numpy()[0]
        return self.gradients

    def remove_hooks(self):
        """Unregister every ReLU backward hook installed by ``__init__``."""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []
class GuidedGradCAM:
    """Guided Grad-CAM: a Grad-CAM heatmap multiplied by guided-backprop saliency.

    Hooks ``target_layer`` to capture its forward output and the gradient of the
    selected score with respect to that output. A :class:`GuidedBackprop` on the
    same model supplies the pixel-level saliency.

    Args:
        model: Classifier to explain. It is switched to eval mode.
        target_layer: Sub-module of ``model`` whose output feature map drives the
            CAM. The output is indexed as ``[batch, channel, H, W]`` below.
    """

    def __init__(self, model, target_layer):
        self.model = model.eval()
        self.target_layer = target_layer
        self.activations = None  # target_layer output from the latest forward pass
        self.gradients = None  # d(score)/d(target_layer output) from the latest backward pass
        self.guided_backprop = GuidedBackprop(model)
        self.hooks = [
            target_layer.register_forward_hook(self.save_activation),
            target_layer.register_full_backward_hook(self.save_gradient),
        ]

    def save_activation(self, module, input, output):
        """Forward hook: store the target layer's output, detached from autograd."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0].detach()

    @staticmethod
    def _select_score(model_output, target_class):
        """Pick the scalar score to back-propagate from the first sample's logits."""
        if model_output.dim() == 0:
            logits = model_output.reshape(1)
        else:
            logits = model_output.reshape(model_output.shape[0], -1)[0]
        if logits.numel() == 1:
            # A single logit (e.g. a binary classifier) is explained directly,
            # matching the previous ``squeeze()`` behavior.
            return logits[0]
        index = int(logits.argmax()) if target_class is None else int(target_class)
        return logits[index]

    def generate_cam(self, input_tensor, target_class=None):
        """Compute the Grad-CAM, guided-backprop and Guided Grad-CAM maps.

        Args:
            input_tensor: Batched ``(N, C, H, W)`` input. Only the first sample is
                explained.
            target_class: Logit index for multi-logit outputs. ``None`` selects
                the arg-max. Ignored for single-logit outputs. Also forwarded to
                the guided-backprop pass.

        Returns:
            ``(cam_np, guided_grads, cam_gb)``:
            ``cam_np`` is an ``(H, W)`` float32 map min-max scaled to [0, 1] and
            resized to the input resolution.
            ``guided_grads`` is the guided-backprop input gradient.
            ``cam_gb`` is ``cam_np`` times the channel-last guided gradient.

        Raises:
            RuntimeError: If the target-layer hooks did not fire during this call.
        """
        # Clear the previous call's results so the check below cannot pass on stale data.
        self.activations = None
        self.gradients = None

        model_output = self.model(input_tensor)
        self.model.zero_grad()
        score = self._select_score(model_output, target_class)
        score.backward()
        if self.gradients is None or self.activations is None:
            raise RuntimeError("No gradients found. The model architecture might not be compatible.")

        # Index the first sample instead of squeeze(): squeeze would also drop
        # spatial axes of size 1 and break the per-channel mean below.
        gradients = self.gradients[0]
        activations = self.activations[0]
        weights = gradients.mean(dim=(1, 2))  # one weight per channel
        cam = torch.relu((weights[:, None, None] * activations).sum(dim=0)).float()
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-7)

        guided_grads = self.guided_backprop.generate_gradients(input_tensor, target_class)

        # Upsample to the input's own resolution; cv2.resize takes (width, height).
        height, width = input_tensor.shape[-2:]
        cam_np = cv2.resize(cam.detach().cpu().numpy(), (int(width), int(height)))

        cam_gb = np.multiply(cam_np[..., np.newaxis], guided_grads.transpose(1, 2, 0))
        return cam_np, guided_grads, cam_gb

    def __call__(self, input_tensor):
        return self.generate_cam(input_tensor)

    def remove_hooks(self):
        """Unregister the target-layer hooks and, when supported, the guided-backprop hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        # The helper's ReLU hooks live on the same model. Release them too when
        # the helper exposes a way to do so.
        remove_guided = getattr(self.guided_backprop, "remove_hooks", None)
        if remove_guided is not None:
            remove_guided()
def _positive_to_uint8(values):
    """Keep the positive part of ``values`` and min-max scale it to uint8 [0, 255]."""
    positive = np.maximum(values, 0)
    scaled = (positive - positive.min()) / (positive.max() - positive.min() + 1e-8)
    return np.uint8(255 * scaled)


def preprocess_and_show_results(img_path, cam, guided_grads, cam_gb, save_path):
    """Write the original, Grad-CAM, guided-backprop and Guided Grad-CAM images.

    Args:
        img_path: Path of the original image. It is re-read with OpenCV (BGR) and
            resized to the CAM resolution.
        cam: 2-D float map, expected in [0, 1].
        guided_grads: ``(3, H, W)`` input-gradient array. Assumed to be in RGB
            channel order; the ``__main__`` script builds inputs with PIL's
            ``convert("RGB")``.
        cam_gb: ``(H, W, 3)`` Guided Grad-CAM array in the same channel order as
            ``guided_grads``.
        save_path: Base output path. The suffixes ``_gradcam``, ``_guided_bp``,
            ``_guided_gradcam`` and ``_original`` are inserted before its extension
            (``.png`` when it has none).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
        OSError: If OpenCV fails to write one of the outputs.
    """
    raw = cv2.imread(img_path)
    if raw is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    height, width = cam.shape[:2]
    raw = cv2.resize(raw, (width, height))

    # Grad-CAM heatmap blended over the original image.
    heatmap = cv2.applyColorMap(np.uint8(255 * np.clip(cam, 0, 1)), cv2.COLORMAP_JET)
    gradcam_overlay = cv2.addWeighted(raw, 0.6, heatmap, 0.4, 0)

    # cv2.imwrite expects BGR, so flip the RGB gradient channels before saving.
    guided_bp_img = np.ascontiguousarray(_positive_to_uint8(guided_grads.transpose(1, 2, 0))[..., ::-1])
    guided_gradcam_img = np.ascontiguousarray(_positive_to_uint8(cam_gb)[..., ::-1])

    # splitext only touches the real extension; str.replace('.png', ...) left the
    # path unchanged when there was no '.png', making all four writes collide.
    root, ext = os.path.splitext(save_path)
    ext = ext or ".png"
    outputs = {
        "_gradcam": gradcam_overlay,
        "_guided_bp": guided_bp_img,
        "_guided_gradcam": guided_gradcam_img,
        "_original": raw,
    }
    for suffix, image in outputs.items():
        out_path = f"{root}{suffix}{ext}"
        if not cv2.imwrite(out_path, image):
            raise OSError(f"Failed to write image: {out_path}")
def main():
    """Save Guided Grad-CAM visualizations for up to five randomly chosen test images.

    Loads the EfficientNetB0Classifier checkpoint and samples paths from
    ``test_paths.npy`` with a fixed seed. For each image it writes the original,
    Grad-CAM, guided-backprop and Guided Grad-CAM images under
    ``results/guided_gradcam/``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the trained checkpoint.
    model = EfficientNetB0Classifier()
    model.load_state_dict(
        torch.load("results_efficientnet_b0/efficientnet_best9912.pth", map_location=device)
    )
    model = model.eval()
    # Make every parameter require grad so backward passes reach the hooked
    # layers, in case the model definition froze part of the network.
    for param in model.parameters():
        param.requires_grad = True
    model.to(device)

    transform = transforms.Compose([
        transforms.Resize((380, 380)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Use the final convolutional stage of the backbone as the Grad-CAM target.
    target_layer = model.base_model.features[-1]
    guided_gradcam = GuidedGradCAM(model, target_layer)
    try:
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
        # path lists produced by this project.
        image_paths = np.load("test_paths.npy", allow_pickle=True).astype(str)
        np.random.seed(42)
        # Never request more distinct samples than there are paths.
        num_samples = min(5, len(image_paths))
        selected_indices = np.random.choice(len(image_paths), num_samples, replace=False)
        os.makedirs("results/guided_gradcam", exist_ok=True)

        for i in selected_indices:
            input_path = image_paths[i]
            output_path = f"results/guided_gradcam/visualization_{i}.png"
            # The context manager closes the file handle PIL opens.
            with Image.open(input_path) as img:
                rgb = img.convert("RGB")
            input_tensor = transform(rgb).unsqueeze(0).to(device)
            cam, guided_grads, cam_gb = guided_gradcam(input_tensor)
            preprocess_and_show_results(input_path, cam, guided_grads, cam_gb, output_path)
            print(f"✅ Saved visualizations for image {i}")
    finally:
        # Detach the hooks from the model even if a visualization fails.
        guided_gradcam.remove_hooks()


if __name__ == "__main__":
    main()