Shape mismatch for timm ViT model
Hammad-Mir opened this issue
While applying attention rollout to a fine-tuned timm ViT model (base_patch32_224), I am getting the following error with an input tensor of shape torch.Size([1, 3, 224, 224]):
```
RuntimeError Traceback (most recent call last)
in ()
----> 1 mask_1 = attention_rollout(test_image_1_tensor)
8 frames
in reshape_transform(tensor, height, width)
1 def reshape_transform(tensor, height=7, width=7):
2 result = tensor[:, 1 : , :].reshape(tensor.size(0),
----> 3 height, width, tensor.size(2))
4
5 # Bring the channels to the first dimension,
RuntimeError: shape '[1, 7, 7, 7]' is invalid for input of size 37583
```
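The error message can be decoded as follows. The target shape `[1, 7, 7, 7]` is `(tensor.size(0), height, width, tensor.size(2))`, which means the tensor reaching `reshape_transform` had a last dimension of 7. For timm's `vit_base_patch32_224` at 224×224 input, there are (224/32)² = 49 patch tokens plus one CLS token, or 50 tokens in total. The token output of a ViT-Base block has a last dimension of 768, so `tensor[:, 1:, :]` would hold 49 × 768 = 37632 elements, not 37583. This suggests the tensor comes from some layer other than a transformer block's token sequence. That is an inference from the message, not a confirmed diagnosis.

Below is a minimal sketch of a more defensive `reshape_transform`. It assumes a `(B, N, C)` token tensor with one leading CLS token. Instead of hard-coding `height=7, width=7`, it infers the grid side from the token count and raises a readable error for the wrong input. The `num_prefix_tokens` parameter is an illustrative addition, not part of pytorch-grad-cam.

```python
import math

import torch


def reshape_transform(tensor: torch.Tensor, num_prefix_tokens: int = 1) -> torch.Tensor:
    """Turn ViT tokens (B, N, C) into a CNN-style activation map (B, C, H, W).

    Assumption: the first `num_prefix_tokens` tokens (the CLS token for
    vit_base_patch32_224) have no spatial position, and the remaining
    tokens form a square grid. `num_prefix_tokens` is a hypothetical
    parameter added for illustration.
    """
    if tensor.dim() != 3:
        raise ValueError(f"expected a (B, N, C) token tensor, got {tuple(tensor.shape)}")
    batch, num_tokens, channels = tensor.shape
    num_patches = num_tokens - num_prefix_tokens
    side = math.isqrt(max(num_patches, 0))
    if num_patches <= 0 or side * side != num_patches:
        raise ValueError(
            f"{num_patches} patch tokens do not form a square grid "
            f"(input shape {tuple(tensor.shape)}); check which layer feeds this transform"
        )
    # Drop the prefix (CLS) token(s) and lay the patch tokens out on the grid.
    result = tensor[:, num_prefix_tokens:, :].reshape(batch, side, side, channels)
    # Bring the channels to the first dimension: (B, H, W, C) -> (B, C, H, W).
    return result.permute(0, 3, 1, 2)


tokens = torch.randn(1, 50, 768)  # ViT-B/32 at 224x224: 1 CLS + 7*7 patches
print(reshape_transform(tokens).shape)  # torch.Size([1, 768, 7, 7])
```

You can pass the callable as `reshape_transform=` to a pytorch-grad-cam CAM object. The library's own ViT example targets a layer inside the last block (`model.blocks[-1].norm1`), whose output has this `(B, N, C)` layout. Check this against the version you have installed.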
Kindly advise on how to apply this properly to the model; I'm facing the same issue with FullGrad in [pytorch-grad-cam] on the same model.
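The attention rollout mentioned at the start also ends in a 7×7 grid for this model. The CLS row of the rolled-out attention has one entry per token, and dropping the CLS column leaves 49 patch entries, which reshape to 7×7.

Below is a minimal NumPy sketch of that step. It assumes the per-layer attention maps, each of shape `(heads, N, N)`, have already been collected, for example with forward hooks. Averaging the heads and using `(A + I) / 2` to account for the residual path is one common formulation. Other implementations fuse heads with max or min, or discard the weakest attentions. The function name here is illustrative and is not the `attention_rollout` object called in the traceback.

```python
import math

import numpy as np


def attention_rollout(attentions, num_prefix_tokens=1):
    """Hypothetical sketch: roll attention out across layers into a patch mask.

    attentions: sequence of per-layer arrays, each (heads, N, N) with rows
    that sum to 1, where N = num_prefix_tokens + number of patches.
    Returns a (side, side) mask scaled so its maximum is 1.
    """
    num_tokens = attentions[0].shape[-1]
    identity = np.eye(num_tokens)
    rollout = identity.copy()
    for layer_attn in attentions:
        fused = layer_attn.mean(axis=0)           # average over heads
        a = (fused + identity) / 2.0              # account for the residual connection
        a = a / a.sum(axis=-1, keepdims=True)     # keep each row summing to 1
        rollout = a @ rollout                     # accumulate A_l @ ... @ A_1
    mask = rollout[0, num_prefix_tokens:]         # CLS row, patch columns only
    side = math.isqrt(mask.size)
    if side * side != mask.size:
        raise ValueError(f"{mask.size} patch tokens do not form a square grid")
    mask = mask.reshape(side, side)
    return mask / mask.max()


# 12 layers, 12 heads, 50 tokens (1 CLS + 49 patches) of random softmax attention.
rng = np.random.default_rng(0)
logits = rng.normal(size=(12, 12, 50, 50))
attn = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
print(attention_rollout(list(attn)).shape)  # (7, 7)
```

The grid side is derived from the token count rather than passed in. A mismatch between the collected tensors and the model's patch layout therefore surfaces as an explicit error instead of a reshape failure.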