#!/usr/bin/env python
# coding: utf-8
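"""
Grad-CAM visualization for the malaria cell-image classifier.

Provides helpers to preprocess an input image, compute a Grad-CAM heatmap from
the last convolutional layer of a trained model, and overlay the heatmap on the
original image (rendered with matplotlib and Streamlit's st.pyplot).
"""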

# In[1]:


import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
import streamlit as st


# In[2]:


def preprocess_image(image_path):
    """
    Load and preprocess an image for inference.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    img = Image.open(image_path).convert('RGB')
    tensor = transform(img)
    return tensor.unsqueeze(0), img
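
# Example usage (a minimal sketch; the file name is a placeholder):
#   batch, pil_img = preprocess_image("cell_image.png")  # batch: (1, 3, 224, 224) float tensor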


# In[3]:


def get_last_conv_layer(model):
    """
    Get the last convolutional layer in the model.
    """
    # Works for ResNet-style models: walk modules in reverse and return the last Conv2d
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, nn.Conv2d):
            return name
    raise ValueError("No Conv2d layers found in the model.")


# In[4]:


def apply_gradcam(model, image_tensor, target_class=None):
    """
    Apply Grad-CAM to an image.
    """
    device = next(model.parameters()).device
    image_tensor = image_tensor.to(device)

    # Register hooks to get activations and gradients
    features = []
    gradients = []

    def forward_hook(module, input, output):
        features.append(output.detach())

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0].detach())

    last_conv_layer_name = get_last_conv_layer(model)
    last_conv_layer = dict(model.named_modules())[last_conv_layer_name]
    handle_forward = last_conv_layer.register_forward_hook(forward_hook)
    handle_backward = last_conv_layer.register_full_backward_hook(backward_hook)

    # Forward pass
    model.eval()
    output = model(image_tensor)
    if target_class is None:
        target_class = output.argmax(dim=1).item()

    # Zero out all gradients
    model.zero_grad()

    # Backward pass
    one_hot = torch.zeros_like(output)
    one_hot[0][target_class] = 1
    output.backward(gradient=one_hot)

    # Remove hooks
    handle_forward.remove()
    handle_backward.remove()

    # Get feature maps and gradients
    feature_map = features[-1].squeeze().cpu().numpy()
    gradient = gradients[-1].squeeze().cpu().numpy()

    # Global Average Pooling on gradients
    pooled_gradients = np.mean(gradient, axis=(1, 2), keepdims=True)
    cam = feature_map * pooled_gradients
    cam = np.sum(cam, axis=0)

    # Apply ReLU
    cam = np.maximum(cam, 0)

    # Normalize the CAM to [0, 1]; the small epsilon guards against an all-zero map
    cam = cam - np.min(cam)
    cam = cam / (np.max(cam) + 1e-8)

    # Resize CAM to match the original image size
    cam = cv2.resize(cam, (224, 224))

    return cam
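
# Example usage (a minimal sketch; `model` and the image path are placeholders):
#   batch, _ = preprocess_image("cell_image.png")
#   cam = apply_gradcam(model, batch)                   # heatmap for the predicted class
#   cam0 = apply_gradcam(model, batch, target_class=0)  # heatmap for a specific class index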


# In[5]:


def overlay_heatmap(original_image, heatmap, alpha=0.5):
    """
    Overlay the heatmap on the original image.

    Args:
        original_image (np.ndarray): Original image (H, W, 3), uint8
        heatmap (np.ndarray): Grad-CAM heatmap (H', W'), float between 0 and 1
        alpha (float): Weight for the heatmap

    Returns:
        np.ndarray: Overlayed image
    """
    # Ensure heatmap is 2D
    if heatmap.ndim == 3:
        heatmap = np.mean(heatmap, axis=2)

    # Resize heatmap to match original image size
    heatmap_resized = cv2.resize(heatmap, (original_image.shape[1], original_image.shape[0]))

    # Normalize heatmap to [0, 255]
    heatmap_resized = np.uint8(255 * heatmap_resized)

    # Apply colormap
    heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)

    # Convert from BGR to RGB
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    # Superimpose: blend heatmap and original image
    superimposed_img = heatmap_colored * alpha + original_image * (1 - alpha)
    return np.uint8(superimposed_img)
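
# Example usage (a minimal sketch; assumes `pil_img` and `cam` come from the helpers above):
#   overlaid = overlay_heatmap(np.array(pil_img), cam, alpha=0.5)
#   plt.imshow(overlaid); plt.axis("off")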

def visualize_gradcam(model, image_path):
    """
    Visualize Grad-CAM for a given image.
    """
    # Preprocess image
    image_tensor, original_image = preprocess_image(image_path)
    original_image_np = np.array(original_image)  # PIL -> numpy array

    # Resize a copy of the original image to a fixed size for display
    display_size = (400, 400)  # (width, height) as expected by cv2.resize
    original_image_resized = cv2.resize(original_image_np, display_size)

    # Apply Grad-CAM
    cam = apply_gradcam(model, image_tensor)

    # Resize CAM to match original image size
    heatmap_resized = cv2.resize(cam, (original_image_np.shape[1], original_image_np.shape[0]))

    # Rescale heatmap to [0, 255]; the epsilon avoids division by zero for an all-zero map
    heatmap_resized = np.uint8(255 * heatmap_resized / (np.max(heatmap_resized) + 1e-8))

    # Apply color map
    heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    # Overlay
    superimposed_img = heatmap_colored * 0.4 + original_image_np * 0.6
    superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)

    # Display results
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))  # Adjust figsize as needed
    axes[0].imshow(original_image_resized)
    axes[0].set_title("Original Image")
    axes[0].axis("off")

    axes[1].imshow(superimposed_img)
    axes[1].set_title("Grad-CAM Heatmap")
    axes[1].axis("off")

    plt.tight_layout()
    st.pyplot(fig)
    plt.close(fig)


# In[6]:


if __name__ == "__main__":

    from models.resnet_model import MalariaResNet50
    # Load your trained model
    model = MalariaResNet50(num_classes=2)
    # map_location="cpu" lets the checkpoint load on machines without a GPU
    model.load_state_dict(torch.load("models/malaria_model.pth", map_location="cpu"))
    model.eval()

    # Path to an image
    image_path = "malaria_ds/split_dataset/test/Parasitized/C33P1thinF_IMG_20150619_114756a_cell_181.png"

    # Visualize Grad-CAM
    visualize_gradcam(model, image_path)


# In[ ]:
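

# Example Streamlit wiring (a sketch only; the uploader flow, page title, and
# temporary file name below are assumptions, not part of the original script):
#
#   st.title("Malaria Cell Grad-CAM")
#   uploaded = st.file_uploader("Upload a cell image", type=["png", "jpg", "jpeg"])
#   if uploaded is not None:
#       Image.open(uploaded).convert("RGB").save("uploaded_cell.png")
#       visualize_gradcam(model, "uploaded_cell.png")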