kazuto1011 / grad-cam-pytorch

PyTorch re-implementation of Grad-CAM (+ vanilla/guided backpropagation, deconvnet, and occlusion sensitivity maps)

GradCAM with segmentation?

Karol-G opened this issue.

Hi again,

I am currently trying to apply Grad-CAM to a U-Net for image segmentation.
To do this, I replaced the forward and backward methods in _BaseWrapper with the following:

    def forward(self, image):
        self.model.zero_grad()
        self.logits = self.model(image)
        return self.logits

    def backward(self):
        self.logits.backward(gradient=self.logits, retain_graph=True)

Because this is segmentation, the model already returns a 2D float tensor containing only ones (object) and zeros (background). As I understand Grad-CAM, we want to set the probabilities for the "classes" we are interested in to one. That is already the case by default, so I should be able to run the backward pass with this tensor.
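For reference, `Tensor.backward(gradient=g)` computes a vector-Jacobian product: it yields the same parameter gradients as calling `backward()` on the scalar `(tensor * g).sum()`. Passing a binary map as `gradient` therefore means "backpropagate from the sum of the outputs at the object pixels", which only gives useful gradients if those outputs depend differentiably on the network. A small check with an arbitrary linear map (all tensors and the helper name `grads_via_both_paths` are illustrative, not from the issue):

```python
import torch


def grads_via_both_paths():
    torch.manual_seed(0)
    w = torch.randn(3, 4, requires_grad=True)  # stand-in "parameters"
    x = torch.randn(5, 3)                      # stand-in input
    mask = (torch.rand(5, 4) > 0.5).float()    # binary tensor used as upstream gradient

    # Path A: hand the mask to backward() as the upstream gradient.
    (x @ w).backward(gradient=mask)
    grad_a = w.grad.clone()

    # Path B: backpropagate from the scalar sum of the mask-weighted outputs.
    w.grad = None
    ((x @ w) * mask).sum().backward()
    grad_b = w.grad.clone()
    return grad_a, grad_b


grad_a, grad_b = grads_via_both_paths()
print(torch.allclose(grad_a, grad_b))  # True
```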

Does it make sense to apply Grad-CAM to segmentation?
Is my thinking correct so far?

The problem I am currently facing is that, for some reason, the gradients of all layers except the last one are zero, and if I generate a map from the last layer, it is simply my predicted segmentation. I visualized the computation graph with pytorchviz to check whether I had broken the gradient path somewhere, but that does not seem to be the case.

Any idea why this could be happening?

Best,
Karol

The Grad-CAM paper says in the caption of Fig. 2:

Given an image and a class of interest (e.g., ‘tiger cat’ or any other type of differentiable output) as input, we forward propagate the image through the CNN part of the model and then through task-specific computations to obtain a raw score for the category. The gradients are set to zero for all classes except the desired class (tiger cat), which is set to 1.
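For reference, Grad-CAM weights each activation map A^k of the chosen layer by the spatially averaged gradient of the class score y^c, then keeps the positive part of the weighted sum. One way to carry this over to segmentation, sketched here as an assumption rather than something stated in the caption above, is to define y^c as the sum of the raw class-c logits y^c_{uv} (taken before softmax or argmax) over a pixel set M, such as the predicted object region; Z is the number of spatial positions in A^k.

$$
\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \operatorname{ReLU}\Big(\sum_k \alpha_k^c A^k\Big),
\qquad
y^c = \sum_{(u,v)\in M} y^c_{uv}
$$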

I guess the segmentation maps you have were produced by argmax or some other non-differentiable operation. That's why you got zero gradients throughout the model. You need to compute the gradients with respect to the isolated raw scores (logits) of the desired class, before they are binarized by argmax or normalized over classes by softmax. You can do this either by setting one-hot values on the gradient of the final logit maps, as my implementation does, or by calling backward() on the sum of the logit values over the masked region.
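A minimal sketch of the second option (backward() on the sum of the logits inside the predicted region). It assumes a hypothetical toy network `ToySegNet` standing in for a U-Net, and the helpers `seg_grad_cam` and `gradients_after_rounding` are hypothetical names, not part of this repository's API. The argmax only selects the region and is never differentiated through; the channel weights are the spatially averaged gradients, as in Grad-CAM. The second helper illustrates the failure mode: with a hypothetical round(sigmoid(·)) at the end of the network, every parameter gradient comes out as exactly zero, because `torch.round` has zero gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToySegNet(nn.Module):
    """Hypothetical stand-in for a U-Net: returns raw per-pixel logits (N, C, H, W)."""

    def __init__(self, num_classes=2):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.head = nn.Conv2d(8, num_classes, kernel_size=1)

    def forward(self, x):
        return self.head(self.features(x))


def seg_grad_cam(model, layer, image, target_class):
    """Grad-CAM for one class of a segmentation model (masked-sum variant, sketch)."""
    store = {}

    def keep_activation(module, inputs, output):
        output.retain_grad()          # keep d(score)/d(activation) after backward()
        store["act"] = output

    handle = layer.register_forward_hook(keep_activation)
    try:
        model.zero_grad()
        logits = model(image)         # raw scores, still differentiable
    finally:
        handle.remove()

    # argmax only selects the region; the score itself is built from raw logits.
    mask = (logits.argmax(dim=1) == target_class).float()   # (N, H, W), no grad
    score = (logits[:, target_class] * mask).sum()          # sum of class logits in region
    score.backward()

    act = store["act"]                                      # (N, K, H, W)
    weights = act.grad.mean(dim=(2, 3), keepdim=True)       # alpha_k: averaged gradients
    cam = F.relu((weights * act).sum(dim=1))                # (N, H, W)
    return cam.detach(), mask, weights.detach()


def gradients_after_rounding(model, image):
    """Failure mode: binarize with round(sigmoid(.)) before backward (hypothetical)."""
    model.zero_grad()
    out = torch.round(torch.sigmoid(model(image)))   # 0/1 map, like a thresholded mask
    out.backward(gradient=out.detach())              # torch.round has zero gradient
    return model.head.weight.grad                    # all zeros


torch.manual_seed(0)
net = ToySegNet()
img = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    cls = int(torch.bincount(net(img).argmax(dim=1).flatten()).argmax())  # most common class
cam, mask, weights = seg_grad_cam(net, net.features, img, cls)
print(cam.shape)                                                       # torch.Size([1, 16, 16])
print(torch.count_nonzero(gradients_after_rounding(net, img)).item())  # 0
```

To explain a different layer of a real U-Net, pass that module as `layer`. For a binary model with a single sigmoid output channel, the region would come from thresholding the logit at 0 (equivalently, the sigmoid at 0.5) instead of from argmax, while the score is still summed over the raw logit.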