jacobgil / pytorch-grad-cam

Advanced AI Explainability for computer vision. Support for CNNs, Vision Transformers, Classification, Object detection, Segmentation, Image similarity and more.

Home Page: https://jacobgil.github.io/pytorch-gradcam-book



GradCAM throws an error for models that return a ClassifierOutput object instead of a tensor

AdityaDeodeshmukh opened this issue.

Using the SwinForImageClassification model with the code below:

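import torch
from transformers import AutoImageProcessor, SwinForImageClassification
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

class ModelOutputTarget:
    def __init__(self):
        pass

    def __call__(self, model_output):
        return torch.sigmoid(model_output)

# img: the input image; rgb_img: the same image as a float RGB array in [0, 1]
# (both defined earlier in the notebook).
# Assumed: the model is loaded from the same checkpoint as the processor.
model = SwinForImageClassification.from_pretrained("microsoft/swin-base-patch4-window7-224-in22k")
image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-base-patch4-window7-224-in22k")
target_layers = [model.swin.encoder.layers[-1].blocks[1].layernorm_before]
input_tensor = image_processor(img, return_tensors="pt")
cam = GradCAM(model=model, target_layers=target_layers)
targets = [ModelOutputTarget()]
grayscale_cam = cam(input_tensor['pixel_values'], targets=targets)
grayscale_cam = grayscale_cam[0, :]
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
model_outputs = cam.outputs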

This throws the following error:

TypeError                                 Traceback (most recent call last)
Cell In[13], line 16
     13 targets = [ModelOutputTarget()]
     15 # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
---> 16 grayscale_cam = cam(input_tensor['pixel_values'], targets=targets)
     18 # In this example grayscale_cam has only one image in the batch:
     19 grayscale_cam = grayscale_cam[0, :]

File /opt/conda/lib/python3.10/site-packages/pytorch_grad_cam/base_cam.py:192, in BaseCAM.__call__(self, input_tensor, targets, aug_smooth, eigen_smooth)
    188 if aug_smooth is True:
    189     return self.forward_augmentation_smoothing(
    190         input_tensor, targets, eigen_smooth)
--> 192 return self.forward(input_tensor,
    193                     targets, eigen_smooth)

File /opt/conda/lib/python3.10/site-packages/pytorch_grad_cam/base_cam.py:92, in BaseCAM.forward(self, input_tensor, targets, eigen_smooth)
     90 if self.uses_gradients:
     91     self.model.zero_grad()
---> 92     loss = sum([target(output)
     93                for target, output in zip(targets, outputs)])
     94     loss.backward(retain_graph=True)
     96 # In most of the saliency attribution papers, the saliency is
     97 # computed with a single target layer.
     98 # Commonly it is the last convolutional layer.
    103 # use all conv layers for example, all Batchnorm layers,
    104 # or something else.

File /opt/conda/lib/python3.10/site-packages/pytorch_grad_cam/base_cam.py:92, in <listcomp>(.0)
     90 if self.uses_gradients:
     91     self.model.zero_grad()
---> 92     loss = sum([target(output)
     93                for target, output in zip(targets, outputs)])
     94     loss.backward(retain_graph=True)
     96 # In most of the saliency attribution papers, the saliency is
     97 # computed with a single target layer.
     98 # Commonly it is the last convolutional layer.
    103 # use all conv layers for example, all Batchnorm layers,
    104 # or something else.

Cell In[12], line 6, in ModelOutputTarget.__call__(self, model_output)
      5 def __call__(self, model_output):
----> 6     return torch.sigmoid(model_output)

TypeError: sigmoid(): argument 'input' (position 1) must be Tensor, not str

This is caused by the following code in base_cam.py:

self.outputs = outputs = self.activations_and_grads(input_tensor)

if targets is None:
    target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
    targets = [ClassifierOutputTarget(
        category) for category in target_categories]

if self.uses_gradients:
    loss = sum([target(output)
               for target, output in zip(targets, outputs)])
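A torch-free sketch of what that zip() call does with such an output. FakeClassifierOutput and record_input are hypothetical stand-ins; the only assumption is that the real output class, like transformers' ModelOutput, subclasses OrderedDict, so iterating over it yields keys rather than tensors:

```python
from collections import OrderedDict


class FakeClassifierOutput(OrderedDict):
    """Hypothetical stand-in for a transformers classifier output."""


def record_input(model_output):
    # A target that only reports what it was given.
    return type(model_output).__name__, model_output


outputs = FakeClassifierOutput(logits=[[0.1, 0.9]])
targets = [record_input]

# Same pattern as BaseCAM.forward: zip() iterates over `outputs` itself.
received = [target(output) for target, output in zip(targets, outputs)]
print(received)  # [('str', 'logits')]
```

The target receives the key string 'logits', which matches the TypeError in the traceback.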

The SwinForImageClassification model returns a SwinImageClassifierOutput object rather than a tensor. Because transformers output classes subclass OrderedDict, iterating over one yields its keys, so the list comprehension pairs each target with the string 'logits' instead of the logits tensor, and torch.sigmoid() fails. (With targets=None the code fails even earlier, because outputs.cpu() assumes a tensor.) Is there any workaround for this issue? I have not tested other models, but the same error will probably occur with any model whose forward() returns a custom object instead of a tensor.
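One possible workaround, sketched under assumptions rather than taken from the library: wrap the Hugging Face model in a small nn.Module whose forward() returns only .logits, so GradCAM iterates over a plain tensor. LogitsOnlyWrapper, SigmoidCategoryTarget, swin_reshape_transform and ToyOutputModel are hypothetical names. Two further adjustments are probably needed for this Swin example. First, the target should return one scalar per image: torch.sigmoid over the whole logits vector makes the summed loss non-scalar, and backward() cannot create its gradient implicitly for a non-scalar tensor. Second, layernorm_before in the transformers Swin implementation emits (batch, tokens, channels) activations, while GradCAM averages gradients over spatial axes of a (batch, channels, height, width) map, so a reshape_transform is typically passed. The 7×7 grid assumes a 224×224 input and the last Swin stage.

```python
import types

import torch
from torch import nn


class LogitsOnlyWrapper(nn.Module):
    """Hypothetical wrapper: calls a model whose forward() returns an output
    object (such as a transformers ...ClassifierOutput) and returns only its
    .logits tensor."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model  # registered as a submodule, so its hooks still fire

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x).logits


class SigmoidCategoryTarget:
    """Hypothetical target: sigmoid score of one category for one image.
    It returns a scalar so that the summed loss can be backpropagated."""

    def __init__(self, category: int):
        self.category = category

    def __call__(self, model_output: torch.Tensor) -> torch.Tensor:
        # model_output is a single image's logits, shape (num_classes,).
        return torch.sigmoid(model_output[self.category])


def swin_reshape_transform(tensor: torch.Tensor, height: int = 7, width: int = 7) -> torch.Tensor:
    """(batch, height * width, channels) tokens -> (batch, channels, height, width).

    The 7x7 default assumes a 224x224 input and the final Swin stage.
    """
    batch, tokens, channels = tensor.shape
    if tokens != height * width:
        raise ValueError(f"expected {height * width} tokens, got {tokens}")
    # Tokens are row-major over the grid, so reshape recovers (B, H, W, C);
    # permute then moves channels in front of the spatial axes.
    return tensor.reshape(batch, height, width, channels).permute(0, 3, 1, 2)


# Self-check with a toy model whose forward() returns an object with .logits,
# mimicking a transformers classifier output.
class ToyOutputModel(nn.Module):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.fc = nn.Linear(4, num_classes)

    def forward(self, x: torch.Tensor):
        return types.SimpleNamespace(logits=self.fc(x))


wrapped = LogitsOnlyWrapper(ToyOutputModel())
logits = wrapped(torch.randn(2, 4))
targets = [SigmoidCategoryTarget(1), SigmoidCategoryTarget(2)]
# Same pattern as BaseCAM.forward: iterating a tensor yields one row per image.
loss = sum(target(output) for target, output in zip(targets, logits))
loss.backward()  # works because loss is a scalar tensor
```

In the issue's code this would be used roughly as GradCAM(model=LogitsOnlyWrapper(model), target_layers=target_layers, reshape_transform=swin_reshape_transform) with targets=[SigmoidCategoryTarget(category)] for a chosen class index. The target_layers can keep pointing into the original model because the wrapper holds it as a submodule. It is worth checking the GradCAM constructor arguments against the installed pytorch-grad-cam version.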