lucidrains / vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

How to use Recorder for non-square images?

OlegJakushkin opened this issue

On a sample from the README, modified for a non-square image, I get an error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-938730bcece3> in <module>()
     22 
     23 img = torch.randn(1, 3, 256, 256*2)
---> 24 preds, attns = v(img)
     25 
     26 # there is one extra patch due to the CLS token

3 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/vit_pytorch/recorder.py in forward(self, img)
     50             self._register_hook()
     51 
---> 52         pred = self.vit(img)
     53 
     54         # move all recordings to one device before stacking

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/vit_pytorch/vit.py in forward(self, img)
    116         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
    117         x = torch.cat((cls_tokens, x), dim=1)
--> 118         x += self.pos_embedding[:, :(n + 1)]
    119         x = self.dropout(x)
    120 

RuntimeError: The size of tensor a (129) must match the size of tensor b (65) at non-singleton dimension 1

Code sample:

import torch
from vit_pytorch.vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# import Recorder and wrap the ViT

from vit_pytorch.recorder import Recorder
v = Recorder(v)

# forward pass now returns predictions and the attention maps

img = torch.randn(1, 3, 256, 256*2)
preds, attns = v(img)

# there is one extra patch due to the CLS token

attns.shape # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)
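
For context on the two sizes in the error: the positional embedding is built from image_size alone, which assumes a square input, so 256 / 32 gives an 8 x 8 grid of 64 patches plus one CLS token = 65 positions, while the 256 x 512 input produces an 8 x 16 grid of 128 patches plus the CLS token = 129 tokens. A minimal sketch of that arithmetic (plain Python, no extra API assumed):

image_size = 256                                          # value passed to ViT above
patch_size = 32

# the positional embedding is sized for a square image_size x image_size input
pos_tokens = (image_size // patch_size) ** 2 + 1          # 64 patches + 1 CLS token = 65

# the actual input is 256 x 512, so the patch embedding yields more tokens
h, w = 256, 256 * 2
input_tokens = (h // patch_size) * (w // patch_size) + 1  # 8 * 16 patches + 1 CLS token = 129

print(pos_tokens, input_tokens)                           # 65 129 -> the mismatch in the traceback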

I missed the point that image_size must be the larger of the image's width and height. This works (a note on the resulting attention shape follows the output below):

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# import Recorder and wrap the ViT

from vit_pytorch.recorder import Recorder
v = Recorder(v)

# forward pass now returns predictions and the attention maps

img = torch.randn(1, 3, 256, 128)
preds, attns = v(img)

# there is one extra patch due to the CLS token

attns.shape # (1, 6, 16, 33, 33) - (batch x layers x heads x token x token), i.e. 8 * 4 patches + 1 CLS token

Outputs:

torch.Size([1, 6, 16, 33, 33])
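
The 33 tokens in that output are 1 CLS token plus an 8 x 4 grid of 32-pixel patches for the 256 x 128 input. A minimal sketch of mapping the attention back onto that grid, assuming the row-major patch ordering used by the patch embedding and using a random tensor as a stand-in for the recorded attns above:

import torch

# attns: (batch, layers, heads, tokens, tokens) with tokens = 1 CLS + 8 * 4 patches = 33
attns = torch.randn(1, 6, 16, 33, 33)                  # stand-in for the Recorder output

h_patches, w_patches = 256 // 32, 128 // 32            # 8 x 4 patch grid

cls_to_patch = attns[:, :, :, 0, 1:]                   # attention from the CLS token to each patch
attn_grid = cls_to_patch.reshape(1, 6, 16, h_patches, w_patches)

print(attn_grid.shape)                                 # torch.Size([1, 6, 16, 8, 4])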