How to use Recorder for non-square images?
OlegJakushkin opened this issue
Oleg Jakushkin commented
On a sample from the README, modified for a non-square image, I get an error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-6-938730bcece3> in <module>()
22
23 img = torch.randn(1, 3, 256, 256*2)
---> 24 preds, attns = v(img)
25
26 # there is one extra patch due to the CLS token
3 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/vit_pytorch/recorder.py in forward(self, img)
50 self._register_hook()
51
---> 52 pred = self.vit(img)
53
54 # move all recordings to one device before stacking
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/vit_pytorch/vit.py in forward(self, img)
116 cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
117 x = torch.cat((cls_tokens, x), dim=1)
--> 118 x += self.pos_embedding[:, :(n + 1)]
119 x = self.dropout(x)
120
RuntimeError: The size of tensor a (129) must match the size of tensor b (65) at non-singleton dimension 1
Code sample:
import torch
from vit_pytorch.vit import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
# import Recorder and wrap the ViT
from vit_pytorch.recorder import Recorder
v = Recorder(v)
# forward pass now returns predictions and the attention maps
img = torch.randn(1, 3, 256, 256*2)
preds, attns = v(img)
# there is one extra patch due to the CLS token
attns.shape # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)
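The two sizes in the RuntimeError follow directly from the patch arithmetic: the positional embedding is built for a square image_size of 256, i.e. (256 // 32) ** 2 + 1 = 65 tokens (including the CLS token), while a 256x512 input produces (256 // 32) * (512 // 32) + 1 = 129 tokens. A quick sanity check of that arithmetic:

image_h, image_w, patch = 256, 512, 32

# tokens actually produced by the patch embedding, plus one CLS token
produced = (image_h // patch) * (image_w // patch) + 1   # 8 * 16 + 1 = 129

# tokens the positional embedding was sized for (square image_size = 256)
expected = (256 // patch) ** 2 + 1                       # 8 * 8 + 1 = 65

print(produced, expected)  # 129 65 -- the sizes reported in the error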
Oleg Jakushkin commented
I missed the point that image_size must be the largest of the image's width and height. This works:
import torch
from vit_pytorch import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
# import Recorder and wrap the ViT
from vit_pytorch.recorder import Recorder
v = Recorder(v)
# forward pass now returns predictions and the attention maps
img = torch.randn(1, 3, 256, 128)
preds, attns = v(img)
# there is one extra patch due to the CLS token
attns.shape # (1, 6, 16, 33, 33) - (batch x layers x heads x patch x patch)
Output:
torch.Size([1, 6, 16, 33, 33])
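The output matches the same arithmetic: (256 // 32) * (128 // 32) + 1 = 33 tokens per attention axis. This works because the positional embedding (sized for 65 tokens) is simply sliced down to 33; only inputs whose patch count exceeds the square grid fail. Recent versions of vit_pytorch also appear to accept a (height, width) tuple for image_size, which sizes the positional embedding for the exact rectangular grid. A minimal sketch, assuming your installed version has tuple support:

import torch
from vit_pytorch import ViT
from vit_pytorch.recorder import Recorder

v = ViT(
    image_size = (256, 512),   # (height, width) tuple; assumes tuple support
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
v = Recorder(v)

img = torch.randn(1, 3, 256, 512)
preds, attns = v(img)
attns.shape  # expected (1, 6, 16, 129, 129): 8 * 16 patches + 1 CLS token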