kazuto1011 / grad-cam-pytorch

PyTorch re-implementation of Grad-CAM (+ vanilla/guided backpropagation, deconvnet, and occlusion sensitivity maps)

gcam return zero after F.relu(gcam)

baogiadoan opened this issue

Is it possible for a feature map to contain no positive information related to the targeted class? In my case, gcam at that feature map is all negative values, which leads to a zero output for Grad-CAM after F.relu(gcam).

First of all, for models with ReLU, negative values in the intermediate feature maps carry meaningless/ambiguous information for the final scores. Grad-CAM generally takes the feature maps "rectified" by ReLU, so the negative values of the final weighted map depend purely on the top-down gradients. Regions that are "negative" for the task are removed by the final ReLU.
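
For reference, here is a minimal sketch of the Grad-CAM computation being discussed (the function name is hypothetical; fmaps and grads are assumed to have been captured with forward/backward hooks on the target layer, as this repo does):

import torch
import torch.nn.functional as F

def grad_cam_map(fmaps, grads):
    # fmaps: (B, C, H, W) target-layer activations (post-ReLU, hence >= 0)
    # grads: (B, C, H, W) gradient of the class score w.r.t. fmaps
    weights = F.adaptive_avg_pool2d(grads, 1)  # (B, C, 1, 1) channel weights via global average pooling
    gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)  # weighted sum over channels
    return F.relu(gcam)  # final ReLU drops regions that argue against the class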

If you want to visualize the negative evidence instead, why don't you use the modified weights described in the paper's Section 7, "Counterfactual Explanations"?

# Just flip the gradient's signs
weights = F.adaptive_avg_pool2d(-grads, 1)
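
(Here, as in the repo's code, F is torch.nn.functional and grads holds the gradients of the class score with respect to the chosen feature maps. Flipping the sign weights each channel by how strongly it suppresses the target score, so the resulting map highlights regions whose removal would tend to increase the class confidence.)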

The layer looks correct. However, the gradients with respect to the features.23 layer can have negative values, and those are what weight the (positive) feature maps. Moreover, after global average pooling, positive gradients that occur only locally may be swamped by spatially dominant negative gradients. Please check the gradients, not the feature maps.
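
As a quick sanity check, something like the following (a hypothetical helper, assuming grads is the gradient tensor captured at features.23) reports how the gradient signs are distributed before and after pooling:

import torch.nn.functional as F

def inspect_gradients(grads):
    # grads: (B, C, H, W) gradients at the target layer
    frac_pos = (grads > 0).float().mean().item()  # fraction of positive gradient entries
    weights = F.adaptive_avg_pool2d(grads, 1)  # per-channel weights after global average pooling
    frac_pos_w = (weights > 0).float().mean().item()  # fraction of positively weighted channels
    print(f"positive gradients: {frac_pos:.3f}, positive channel weights: {frac_pos_w:.3f}")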

I found that weights = self._compute_grad_weights(grads) contained both positive and negative values, while fmaps contained only positive values, so multiplying them obviously gives both positive and negative values. But after summing over the channels the result is all negative, so the final gcam shows nothing after F.relu(gcam):
gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)

If I add one more line, weights = F.relu(weights), before gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True), i.e. I remove all the negative values, it can now show the heatmap for the object I want. I'm not sure whether that is the right thing to do, though...
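
For reference, that modification would look like this (a sketch using this repo's variable names, not the original code):

weights = F.relu(self._compute_grad_weights(grads))  # keep only positively weighted channels
gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)

Clamping the weights means channels that vote against the class no longer cancel out the positive ones, which is why the heatmap reappears.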

The Grad-CAM heatmap represents the average contribution over the channels. Removing the negative weights can push up the positive regions; however, those regions may be overestimated even when they actually weaken the target score. I would say the current image contributes little to the specified class, or the score is actually derived from other regions.

Can someone explain what it means for Grad-CAM to be all negative?

It's weird to think that all channels in the last convolution make a "negative" contribution to the true output class.