# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F


class GradCAM:
    """GradCAM class helps create visualization results.



    Visualization results are blended by heatmaps and input images.

    This class is modified from

    https://github.com/facebookresearch/SlowFast/blob/master/slowfast/visualization/gradcam_utils.py # noqa

    For more information about GradCAM, please visit:

    https://arxiv.org/pdf/1610.02391.pdf



    Args:

        model (nn.Module): the recognizer model to be used.

        target_layer_name (str): name of convolutional layer to

            be used to get gradients and feature maps from for creating

            localization maps.

        colormap (str): matplotlib colormap used to create

            heatmap. Defaults to 'viridis'. For more information, please visit

            https://matplotlib.org/3.3.0/tutorials/colors/colormaps.html

    """

    def __init__(self,

                 model: nn.Module,

                 target_layer_name: str,

                 colormap: str = 'viridis') -> None:
        from ..models.recognizers import Recognizer2D, Recognizer3D
        if isinstance(model, Recognizer2D):
            self.is_recognizer2d = True
        elif isinstance(model, Recognizer3D):
            self.is_recognizer2d = False
        else:
            raise ValueError(
                'GradCAM utils only support Recognizer2D & Recognizer3D.')

        self.model = model
        self.model.eval()
        self.target_gradients = None
        self.target_activations = None

        import matplotlib.pyplot as plt
        self.colormap = plt.get_cmap(colormap)
        self._register_hooks(target_layer_name)

    def _register_hooks(self, layer_name: str) -> None:
        """Register forward and backward hook to a layer, given layer_name, to

        obtain gradients and activations.



        Args:

            layer_name (str): name of the layer.

        """

        def get_gradients(module, grad_input, grad_output):
            # Keep the gradients w.r.t. the target layer's output.
            self.target_gradients = grad_output[0].detach()

        def get_activations(module, input, output):
            # Keep the target layer's forward feature maps.
            self.target_activations = output.clone().detach()

        # Resolve the target layer from a '/'-separated module path,
        # e.g. 'backbone/layer4'.
        layer_ls = layer_name.split('/')
        prev_module = self.model
        for layer in layer_ls:
            prev_module = prev_module._modules[layer]

        target_layer = prev_module
        target_layer.register_forward_hook(get_activations)
        # Note: `register_backward_hook` is deprecated in recent PyTorch
        # releases in favor of `register_full_backward_hook`.
        target_layer.register_backward_hook(get_gradients)

    def _calculate_localization_map(self,
                                    data: dict,
                                    use_labels: bool,
                                    delta: float = 1e-20) -> tuple:
        """Calculate localization maps for all inputs with Grad-CAM.

        Args:
            data (dict): model inputs, generated by test pipeline.
            use_labels (bool): Whether to use given labels to generate
                localization map.
            delta (float): used in localization map normalization,
                must be small enough. Please make sure
                `localization_map_max - localization_map_min >> delta`.

        Returns:
            localization_map (torch.Tensor): the localization map for
                input imgs.
            preds (torch.Tensor): Model predictions with shape
                (batch_size, num_classes).
        """
        inputs = data['inputs']

        # use score before softmax
        self.model.cls_head.average_clips = 'score'
        # model forward & backward
        results = self.model.test_step(data)
        preds = [result.pred_score for result in results]
        preds = torch.stack(preds)

        if use_labels:
            labels = [result.gt_label for result in results]
            labels = torch.stack(labels)
            score = torch.gather(preds, dim=1, index=labels)
        else:
            score = torch.max(preds, dim=-1)[0]
        self.model.zero_grad()
        score = torch.sum(score)
        score.backward()

        imgs = torch.stack(inputs)
        if self.is_recognizer2d:
            # [batch_size, num_segments, 3, H, W]
            b, t, _, h, w = imgs.size()
        else:
            # [batch_size, num_crops*num_clips, 3, clip_len, H, W]
            b1, b2, _, t, h, w = imgs.size()
            b = b1 * b2

        gradients = self.target_gradients
        activations = self.target_activations
        if self.is_recognizer2d:
            # [B*Tg, C', H', W']
            b_tg, c, _, _ = gradients.size()
            tg = b_tg // b
        else:
            # source shape: [B, C', Tg, H', W']
            _, c, tg, _, _ = gradients.size()
            # target shape: [B, Tg, C', H', W']
            gradients = gradients.permute(0, 2, 1, 3, 4)
            activations = activations.permute(0, 2, 1, 3, 4)

        # Grad-CAM: channel weights are the spatial average of the gradients,
        # and the localization map is ReLU(sum_c w_c * A_c), resized to
        # [B, 1, T, H, W].
        weights = torch.mean(gradients.view(b, tg, c, -1), dim=3)
        weights = weights.view(b, tg, c, 1, 1)
        activations = activations.view([b, tg, c] +
                                       list(activations.size()[-2:]))
        localization_map = torch.sum(
            weights * activations, dim=2, keepdim=True)
        localization_map = F.relu(localization_map)
        localization_map = localization_map.permute(0, 2, 1, 3, 4)
        localization_map = F.interpolate(
            localization_map,
            size=(t, h, w),
            mode='trilinear',
            align_corners=False)

        # Normalize the localization map.
        localization_map_min, localization_map_max = (
            torch.min(localization_map.view(b, -1), dim=-1, keepdim=True)[0],
            torch.max(localization_map.view(b, -1), dim=-1, keepdim=True)[0])
        localization_map_min = torch.reshape(
            localization_map_min, shape=(b, 1, 1, 1, 1))
        localization_map_max = torch.reshape(
            localization_map_max, shape=(b, 1, 1, 1, 1))
        localization_map = (localization_map - localization_map_min) / (
            localization_map_max - localization_map_min + delta)
        localization_map = localization_map.data

        return localization_map.squeeze(dim=1), preds

    def _alpha_blending(self, localization_map: torch.Tensor,
                        input_imgs: torch.Tensor,
                        alpha: float) -> torch.Tensor:
        """Blend heatmaps and model input images to get visualization results.

        Args:
            localization_map (torch.Tensor): localization map for all inputs,
                generated with Grad-CAM.
            input_imgs (torch.Tensor): model inputs, raw images.
            alpha (float): transparency level of the heatmap,
                in the range [0, 1].

        Returns:
            torch.Tensor: blending results for localization map and input
                images, with shape [B, T, H, W, 3] and pixel values in
                RGB order within range [0, 1].
        """
        # localization_map shape [B, T, H, W]
        localization_map = localization_map.cpu()

        # heatmap shape [B, T, H, W, 3] in RGB order
        heatmap = self.colormap(localization_map.detach().numpy())
        heatmap = heatmap[..., :3]
        heatmap = torch.from_numpy(heatmap)
        input_imgs = torch.stack(input_imgs)
        # Permute input imgs to [B, T, H, W, 3], like heatmap
        if self.is_recognizer2d:
            # Recognizer2D input (B, T, C, H, W)
            curr_inp = input_imgs.permute(0, 1, 3, 4, 2)
        else:
            # Recognizer3D input (B', num_clips*num_crops, C, T, H, W)
            # B = B' * num_clips * num_crops
            curr_inp = input_imgs.view([-1] + list(input_imgs.size()[2:]))
            curr_inp = curr_inp.permute(0, 2, 3, 4, 1)

        # renormalize input imgs to [0, 1]
        curr_inp = curr_inp.cpu().float()
        curr_inp /= 255.

        # alpha blending
        blended_imgs = alpha * heatmap + (1 - alpha) * curr_inp

        return blended_imgs

    def __call__(self,
                 data: dict,
                 use_labels: bool = False,
                 alpha: float = 0.5) -> tuple:
        """Visualize the localization maps on their corresponding inputs as
        heatmaps, using Grad-CAM.

        Generate visualization results for **ALL CROPS**.
        For example, for an I3D model with `clip_len=32, num_clips=10` and
        `ThreeCrop` in the test pipeline, 960 (32*10*3) images are generated
        for every model input.

        Args:
            data (dict): model inputs, generated by test pipeline.
            use_labels (bool): Whether to use given labels to generate
                localization map.
            alpha (float): transparency level of the heatmap,
                in the range [0, 1].

        Returns:
            blended_imgs (torch.Tensor): Visualization results, blended from
                localization maps and model inputs.
            preds (torch.Tensor): Model predictions for inputs.
        """

        # localization_map shape [B, T, H, W]
        # preds shape [batch_size, num_classes]
        localization_map, preds = self._calculate_localization_map(
            data, use_labels=use_labels)

        # blended_imgs shape [B, T, H, W, 3]
        blended_imgs = self._alpha_blending(localization_map, data['inputs'],
                                            alpha)

        # blended_imgs shape [B, T, H, W, 3]
        # preds shape [batch_size, num_classes]
        # Recognizer2D: B = batch_size, T = num_segments
        # Recognizer3D: B = batch_size * num_crops * num_clips, T = clip_len
        return blended_imgs, preds
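

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API). It
# assumes an MMAction2-style setup: `init_recognizer` from `mmaction.apis`,
# a config/checkpoint pair, and a `data` dict built the same way the test
# dataloader would build it (an 'inputs' key plus packed data samples), here
# via `Compose(model.cfg.test_pipeline)` and `pseudo_collate`. All paths and
# the target layer name below are placeholders; adapt them to your setup.
if __name__ == '__main__':
    from mmengine.dataset import Compose, pseudo_collate

    from mmaction.apis import init_recognizer

    # Hypothetical paths; replace with a real config, checkpoint and video.
    config_file = 'configs/recognition/tsn/my_tsn_config.py'
    checkpoint_file = 'checkpoints/my_tsn_checkpoint.pth'
    video_file = 'demo/demo.mp4'

    model = init_recognizer(config_file, checkpoint_file, device='cpu')

    # Build one test sample the same way the test dataloader would.
    test_pipeline = Compose(model.cfg.test_pipeline)
    sample = test_pipeline(
        dict(filename=video_file, label=-1, start_index=0, modality='RGB'))
    data = pseudo_collate([sample])

    # 'backbone/layer4' is a common choice for ResNet-style backbones; pick
    # a layer that actually exists in your model.
    gradcam = GradCAM(model, target_layer_name='backbone/layer4')
    blended_imgs, preds = gradcam(data, use_labels=False, alpha=0.5)

    # blended_imgs: [B, T, H, W, 3] float tensor in [0, 1], ready to be
    # saved or displayed frame by frame.
    print(blended_imgs.shape, preds.shape)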