RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([1, 1, 80, 80]) and output[0] has a shape of torch.Size([1, 1, 96, 96]).
bugsuse opened this issue
Yang Li commented
I encountered the problem below when applying the LRP method with Zennit to a U-Net model with a ResNet-based backbone. The input shape is (1, 32, 96, 96), i.e. (batch size, channels, height, width), and the output shape is (1, 1, 80, 80). Do the input and output shapes have to be the same? If not, is there a way to solve this problem? Any help would be appreciated!
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-30-5d15a2db92a9> in <module>
55 output_relevance[:, -1, :, :] = output[:, -1, :, :]
56 # this will compute the modified gradient of model, with the on
---> 57 output, relevance = attributor(data.cuda(), output_relevance.cuda())
58
59 # sum over the color channel for visualization
~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/zennit/attribution.py in __call__(self, input, attr_output)
130
131 if self.composite is None or self.composite.handles:
--> 132 return self.forward(input, attr_output_fn)
133
134 with self:
~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/zennit/attribution.py in forward(self, input, attr_output_fn)
175 input = input.detach().requires_grad_(True)
176 output = self.model(input)
--> 177 gradient, = torch.autograd.grad((output,), (input,), grad_outputs=(attr_output_fn(output.detach()),))
178 return output, gradient
179
~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/autograd/__init__.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused)
195
196 grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
--> 197 grad_outputs_ = _make_grads(outputs, grad_outputs_)
198
199 if retain_graph is None:
~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/autograd/__init__.py in _make_grads(outputs, grads)
31 if isinstance(grad, torch.Tensor):
32 if not out.shape == grad.shape:
---> 33 raise RuntimeError("Mismatch in shape: grad_output["
34 + str(grads.index(grad)) + "] has a shape of "
35 + str(grad.shape) + " and output["
RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([1, 1, 80, 80]) and output[0] has a shape of torch.Size([1, 1, 96, 96]).
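The check that fails is visible in the last frame: `_make_grads` compares each `grad_output` (here, the relevance tensor passed to the attributor) with the corresponding `output` of the model and raises when their shapes differ. The input shape never enters this comparison. The following pure-Python sketch mirrors that check on plain shape tuples, using the shapes from the error message. The helper `check_grad_output_shapes` is hypothetical and is not part of PyTorch or Zennit, and its message format is simplified.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def check_grad_output_shapes(outputs: Sequence[Shape], grads: Sequence[Shape]) -> None:
    """Hypothetical mirror of the shape check in torch.autograd's _make_grads.

    Each grad_output must have exactly the same shape as the matching output.
    The shape of the model *input* plays no role in this check.
    """
    for index, (out_shape, grad_shape) in enumerate(zip(outputs, grads)):
        if tuple(out_shape) != tuple(grad_shape):
            raise RuntimeError(
                f"Mismatch in shape: grad_output[{index}] has a shape of "
                f"{tuple(grad_shape)} and output[{index}] has a shape of "
                f"{tuple(out_shape)}."
            )


# Shapes from the error message: the model returns (1, 1, 96, 96) for an
# input of (1, 32, 96, 96), but the relevance tensor is (1, 1, 80, 80).
model_output = (1, 1, 96, 96)
relevance = (1, 1, 80, 80)

try:
    check_grad_output_shapes([model_output], [relevance])
except RuntimeError as err:
    print(err)

# A relevance tensor with the model's output shape passes, even though it
# differs from the input shape (1, 32, 96, 96).
check_grad_output_shapes([model_output], [model_output])
print("ok")
```

Under these assumptions, the input and output do not need to match each other. What must match is the relevance tensor and the tensor the model actually returns for `data`, which is (1, 1, 96, 96) according to the error rather than the expected (1, 1, 80, 80). In the notebook, this presumably means allocating `output_relevance` before line 57 from the model's real output for the same input (for example with `torch.zeros_like` on that output) instead of from an 80×80 tensor.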