Spaces:
Sleeping
Sleeping
File size: 5,766 Bytes
55459f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
import streamlit as st
# In[2]:
def preprocess_image(image_path):
"""
Load and preprocess an image for inference.
"""
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = Image.open(image_path).convert('RGB')
tensor = transform(img)
return tensor.unsqueeze(0), img
# In[3]:
def get_last_conv_layer(model):
"""
Get the last convolutional layer in the model.
"""
# For ResNet architecture
for name, module in reversed(list(model.named_modules())):
if isinstance(module, nn.Conv2d):
return name
raise ValueError("No Conv2d layers found in the model.")
# In[4]:
def apply_gradcam(model, image_tensor, target_class=None):
"""
Apply Grad-CAM to an image.
"""
device = next(model.parameters()).device
image_tensor = image_tensor.to(device)
# Register hooks to get activations and gradients
features = []
gradients = []
def forward_hook(module, input, output):
features.append(output.detach())
def backward_hook(module, grad_input, grad_output):
gradients.append(grad_output[0].detach())
last_conv_layer_name = get_last_conv_layer(model)
last_conv_layer = dict(model.named_modules())[last_conv_layer_name]
handle_forward = last_conv_layer.register_forward_hook(forward_hook)
handle_backward = last_conv_layer.register_full_backward_hook(backward_hook)
# Forward pass
model.eval()
output = model(image_tensor)
if target_class is None:
target_class = output.argmax(dim=1).item()
# Zero out all gradients
model.zero_grad()
# Backward pass
one_hot = torch.zeros_like(output)
one_hot[0][target_class] = 1
output.backward(gradient=one_hot)
# Remove hooks
handle_forward.remove()
handle_backward.remove()
# Get feature maps and gradients
feature_map = features[-1].squeeze().cpu().numpy()
gradient = gradients[-1].squeeze().cpu().numpy()
# Global Average Pooling on gradients
pooled_gradients = np.mean(gradient, axis=(1, 2), keepdims=True)
cam = feature_map * pooled_gradients
cam = np.sum(cam, axis=0)
# Apply ReLU
cam = np.maximum(cam, 0)
# Normalize the CAM
cam = cam - np.min(cam)
cam = cam / np.max(cam)
# Resize CAM to match the original image size
cam = cv2.resize(cam, (224, 224))
return cam
# In[5]:
def overlay_heatmap(original_image, heatmap, alpha=0.5):
"""
Overlay the heatmap on the original image.
Args:
original_image (np.ndarray): Original image (H, W, 3), uint8
heatmap (np.ndarray): Grad-CAM heatmap (H', W'), float between 0 and 1
alpha (float): Weight for the heatmap
Returns:
np.ndarray: Overlayed image
"""
# Ensure heatmap is 2D
if heatmap.ndim == 3:
heatmap = np.mean(heatmap, axis=2)
# Resize heatmap to match original image size
heatmap_resized = cv2.resize(heatmap, (original_image.shape[1], original_image.shape[0]))
# Normalize heatmap to [0, 255]
heatmap_resized = np.uint8(255 * heatmap_resized)
# Apply colormap
heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
# Convert from BGR to RGB
heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
# Superimpose: blend heatmap and original image
superimposed_img = heatmap_colored * alpha + original_image * (1 - alpha)
return np.uint8(superimposed_img)
def visualize_gradcam(model, image_path):
"""
Visualize Grad-CAM for a given image.
"""
# Preprocess image
image_tensor, original_image = preprocess_image(image_path)
original_image_np = np.array(original_image) # PIL -> numpy array
# Resize original image for better display
max_size = (400, 400) # Max width and height
original_image_resized = cv2.resize(original_image_np, max_size)
# Apply Grad-CAM
cam = apply_gradcam(model, image_tensor)
# Resize CAM to match original image size
heatmap_resized = cv2.resize(cam, (original_image_np.shape[1], original_image_np.shape[0]))
# Normalize heatmap to [0, 255]
heatmap_resized = np.uint8(255 * heatmap_resized / np.max(heatmap_resized))
# Apply color map
heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
# Overlay
superimposed_img = heatmap_colored * 0.4 + original_image_np * 0.6
superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
# Display results
fig, axes = plt.subplots(1, 2, figsize=(8, 4)) # Adjust figsize as needed
axes[0].imshow(original_image_resized)
axes[0].set_title("Original Image")
axes[0].axis("off")
axes[1].imshow(superimposed_img)
axes[1].set_title("Grad-CAM Heatmap")
axes[1].axis("off")
plt.tight_layout()
st.pyplot(fig)
plt.close(fig)
# In[6]:
if __name__ == "__main__":
from models.resnet_model import MalariaResNet50
# Load your trained model
model = MalariaResNet50(num_classes=2)
model.load_state_dict(torch.load("models/malaria_model.pth"))
model.eval()
# Path to an image
image_path = "malaria_ds/split_dataset/test/Parasitized/C33P1thinF_IMG_20150619_114756a_cell_181.png"
# Visualize Grad-CAM
visualize_gradcam(model, image_path)
# In[ ]:
|