
Vision Transformer - Pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in Yannic Kilcher's video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.

For a Pytorch implementation with pretrained models, please see Ross Wightman's pytorch-image-models (timm) repository.

The official JAX repository is google-research/vision_transformer.


Install

$ pip install vit-pytorch


Usage

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patches to attend to

preds = v(img, mask = mask) # (1, 1000)
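The optional mask above holds one boolean per patch (8 x 8 = 64 patches for a 256-pixel image with 32-pixel patches). A minimal sketch of how such a patch mask could be turned into an attention mask, assuming the cls token is always attendable and disallowed pairs are filled with a very negative value before the softmax. The helpers `patch_mask_to_attn_mask` and `masked_attention` are hypothetical and not necessarily how ViT handles the mask internally.

```python
import torch

def patch_mask_to_attn_mask(mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (b, h, w) bool patch mask -> (b, n + 1, n + 1) bool attention mask."""
    b = mask.shape[0]
    flat = mask.reshape(b, -1)  # (b, n): one entry per patch, row-major order
    cls = torch.ones(b, 1, dtype=torch.bool, device=mask.device)
    flat = torch.cat([cls, flat], dim=1)  # assumption: the cls token (position 0) is never masked
    # a (query, key) pair may interact only if both positions are unmasked
    return flat[:, :, None] & flat[:, None, :]

def masked_attention(q, k, v, attn_mask):
    """Single-head scaled dot-product attention; q, k, v: (b, n + 1, d)."""
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    # disallowed pairs get the most negative representable value, so they receive
    # (near) zero weight whenever the row has at least one allowed key
    scores = scores.masked_fill(~attn_mask, torch.finfo(scores.dtype).min)
    return scores.softmax(dim=-1) @ v
```

With the mask from the example, `patch_mask_to_attn_mask(torch.ones(1, 8, 8).bool())` yields a (1, 65, 65) mask: 64 patches plus the cls token.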


Parameters

  • image_size: int.
    Size of the (square) input image, in pixels.
  • patch_size: int.
    Size of each (square) patch, in pixels. image_size must be divisible by patch_size.
    The number of patches is n = (image_size // patch_size) ** 2, and n must be greater than 16.
  • num_classes: int.
    Number of classes to classify.
  • dim: int.
    Embedding dimension: the last dimension of the output tensor after the linear patch projection nn.Linear(..., dim).
  • depth: int.
    Number of Transformer blocks.
  • heads: int.
    Number of heads in the Multi-head Attention layer.
  • mlp_dim: int.
    Dimension of the MLP (FeedForward) layer.
  • channels: int, default 3.
    Number of image channels.
  • dropout: float between [0, 1], default 0.
    Dropout rate.
  • emb_dropout: float between [0, 1], default 0.
    Embedding dropout rate.
  • pool: string, either 'cls' (cls token pooling) or 'mean' (mean pooling), default 'cls'.
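To make the relationship between image_size, patch_size, channels, dim and pool concrete, here is a minimal sketch of splitting an image into non-overlapping patches, projecting each with nn.Linear(..., dim), prepending a cls token and pooling. It uses plain reshapes rather than whatever the library does internally, and it omits positional embeddings and the transformer itself; `patchify` is a hypothetical name.

```python
import torch
from torch import nn

def patchify(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Hypothetical helper: (b, c, H, W) -> (b, n, p * p * c), n = (H // p) * (W // p)."""
    b, c, h, w = img.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, 'image dimensions must be divisible by the patch size'
    x = img.reshape(b, c, h // p, p, w // p, p)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (b, h // p, w // p, p, p, c)
    return x.reshape(b, (h // p) * (w // p), p * p * c)

image_size, patch_size, channels, dim = 256, 32, 3, 1024
num_patches = (image_size // patch_size) ** 2  # 8 * 8 = 64
patch_dim = channels * patch_size ** 2         # 3 * 32 * 32 = 3072

to_embedding = nn.Linear(patch_dim, dim)
cls_token = nn.Parameter(torch.randn(1, 1, dim))

img = torch.randn(2, channels, image_size, image_size)
tokens = to_embedding(patchify(img, patch_size))                   # (2, 64, 1024)
tokens = torch.cat([cls_token.expand(2, -1, -1), tokens], dim=1)  # (2, 65, 1024)

# ... positional embeddings and the transformer would be applied here ...

cls_pooled = tokens[:, 0]         # pool = 'cls': take the cls token
mean_pooled = tokens.mean(dim=1)  # pool = 'mean': average over all tokens
```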


Distillation

A recent paper (DeiT) has shown that using a distillation token to distill knowledge from convolutional nets into a vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.

For example, distilling from ResNet50 (or any teacher) into a vision transformer:

import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,           # temperature of distillation
    alpha = 0.5                # trade-off between main loss and distillation loss
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
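The temperature and alpha arguments map onto the usual soft-distillation objective. A minimal sketch, assuming the Hinton-style formulation: the KL divergence between temperature-softened student and teacher distributions, scaled by temperature ** 2, mixed with the cross-entropy on the labels via alpha. In the DeiT setup the soft term is computed on the distillation token's head while the hard term uses the cls token's head. `distillation_loss` is a hypothetical name, and DistillWrapper's exact internals may differ.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, distill_logits, teacher_logits, labels,
                      temperature: float = 3.0, alpha: float = 0.5):
    # hard-label loss on the cls-token head
    ce = F.cross_entropy(student_logits, labels)
    # soft loss (assumed Hinton-style): the distillation-token head matches the
    # teacher's temperature-softened distribution; the teacher gets no gradient
    t = temperature
    kl = F.kl_div(
        F.log_softmax(distill_logits / t, dim=-1),
        F.softmax(teacher_logits / t, dim=-1).detach(),
        reduction='batchmean',
    ) * t ** 2  # t ** 2 keeps gradient magnitudes comparable across temperatures
    # alpha trades off the main loss against the distillation loss
    return (1 - alpha) * ce + alpha * kl
```

When the distillation head already matches the teacher exactly, the KL term vanishes and the loss reduces to (1 - alpha) times the cross-entropy.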

The DistillableViT class is identical to ViT except for how the forward pass is handled, so you should be able to load the parameters back into ViT after you have completed distillation training.

You can also use the handy .to_vit method on the DistillableViT instance to get back a ViT instance.

v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
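If you prefer to move the weights yourself rather than call .to_vit, the standard PyTorch pattern is load_state_dict with strict=False, which copies every matching key and reports the rest. The sketch below uses two tiny hypothetical modules standing in for ViT and DistillableViT (the student carries an extra distillation token); it is not the library's own conversion code.

```python
import torch
from torch import nn

class TinyViT(nn.Module):
    """Hypothetical stand-in for ViT: a cls token and a classification head."""
    def __init__(self, dim: int = 8, num_classes: int = 3):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)

class TinyDistillableViT(TinyViT):
    """Hypothetical stand-in for DistillableViT: same parameters plus a distillation token."""
    def __init__(self, dim: int = 8, num_classes: int = 3):
        super().__init__(dim, num_classes)
        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))

student = TinyDistillableViT()
vit = TinyViT()

# strict=False copies every matching key and reports mismatches instead of raising
result = vit.load_state_dict(student.state_dict(), strict=False)
print(result.unexpected_keys)  # ['distillation_token']
print(result.missing_keys)     # []
```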

Research Ideas

Self Supervised Training

You can train this with a near-SOTA self-supervised learning technique, BYOL, with the following code.


$ pip install byol-pytorch


import torch
from vit_pytorch import ViT
from byol_pytorch import BYOL

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

learner = BYOL(
    model,                      # the ViT defined above
    image_size = 256,
    hidden_layer = 'to_latent'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')

A PyTorch Lightning script is ready for you to use in the byol-pytorch repository.
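learner.update_moving_average() in the loop above refers to BYOL's target encoder, whose weights track an exponential moving average of the online encoder. A minimal sketch of that update, assuming a decay of 0.99 (a common choice, not necessarily byol-pytorch's default); `ema_update` is a hypothetical name.

```python
import copy
import torch
from torch import nn

@torch.no_grad()
def ema_update(target: nn.Module, online: nn.Module, decay: float = 0.99) -> None:
    # target <- decay * target + (1 - decay) * online, parameter by parameter
    for t_param, o_param in zip(target.parameters(), online.parameters()):
        t_param.mul_(decay).add_(o_param, alpha=1 - decay)

online = nn.Linear(4, 4)
target = copy.deepcopy(online)
for p in target.parameters():
    p.requires_grad_(False)  # the target encoder receives no gradients
```

In a training loop, ema_update(target, online) would be called after each optimizer step on the online network, which is the role update_moving_average plays above.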

Efficient Attention

There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plug in your own sparse attention transformer.

An example with Linformer:

$ pip install linformer

import torch
from vit_pytorch.efficient import ViT
from linformer import Linformer

efficient_transformer = Linformer(
    dim = 512,
    seq_len = 4096 + 1,  # 64 x 64 patches + 1 cls token
    depth = 12,
    heads = 8,
    k = 256
)

v = ViT(
    dim = 512,
    image_size = 2048,
    patch_size = 32,
    num_classes = 1000,
    transformer = efficient_transformer
)

img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
v(img) # (1, 1000)

Other sparse attention frameworks I would highly recommend are Routing Transformer and Sinkhorn Transformer.
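In the Linformer example above, full self-attention over 4097 tokens would build a 4097 x 4097 score matrix per head (about 16.8 million entries). Linformer instead projects keys and values along the sequence axis down to k = 256 rows, giving a 4097 x 256 score matrix (about 1.05 million entries). A minimal single-head sketch of that idea, assuming learned projection matrices of shape (seq_len, k); `LowRankSelfAttention` is a hypothetical name and this is not the linformer package's implementation.

```python
import torch
from torch import nn

class LowRankSelfAttention(nn.Module):
    """Hypothetical single-head Linformer-style attention: O(n * k) scores instead of O(n^2)."""
    def __init__(self, dim: int, seq_len: int, k: int):
        super().__init__()
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        # learned projections compressing the sequence axis from seq_len to k
        # (simple random initialisation, chosen for illustration only)
        self.proj_k = nn.Parameter(torch.randn(seq_len, k) / k ** 0.5)
        self.proj_v = nn.Parameter(torch.randn(seq_len, k) / k ** 0.5)
        self.scale = dim ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (b, n, dim) with n == seq_len
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        k = torch.einsum('bnd,nk->bkd', k, self.proj_k)  # (b, k, dim)
        v = torch.einsum('bnd,nk->bkd', v, self.proj_v)  # (b, k, dim)
        scores = torch.einsum('bnd,bkd->bnk', q, k) * self.scale  # (b, n, k)
        return scores.softmax(dim=-1) @ v  # (b, n, dim)
```

Because the projections are tied to seq_len, this style of attention expects a fixed number of tokens, which is why the Linformer constructor above takes seq_len = 4096 + 1.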

Combining with other Transformer improvements

This paper purposely used the most vanilla of attention networks to make a statement. If you would like to use some of the latest improvements for attention nets, please use the Encoder from the x-transformers repository.


$ pip install x-transformers

import torch
from vit_pytorch.efficient import ViT
from x_transformers import Encoder

v = ViT(
    dim = 512,
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    transformer = Encoder(
        dim = 512,                  # set to be the same as the wrapper
        depth = 12,
        heads = 8,
        ff_glu = True,              # ex. feed forward GLU variant https://arxiv.org/abs/2002.05202
        residual_attn = True        # ex. residual attention https://arxiv.org/abs/2012.11747
    )
)

img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)
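ff_glu = True swaps the plain feed-forward for a gated linear unit variant from the linked paper. A minimal sketch of one such variant, GEGLU (a GELU-gated linear unit), assuming a hidden width of 4 * dim; `GEGLUFeedForward` is a hypothetical name and x-transformers' own feed-forward may differ in detail.

```python
import torch
import torch.nn.functional as F
from torch import nn

class GEGLUFeedForward(nn.Module):
    """Hypothetical GEGLU feed-forward: proj_out(value * GELU(gate))."""
    def __init__(self, dim: int, mult: int = 4):
        super().__init__()
        hidden = dim * mult
        self.proj_in = nn.Linear(dim, hidden * 2)  # value and gate computed in one matmul
        self.proj_out = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = self.proj_in(x).chunk(2, dim=-1)
        return self.proj_out(value * F.gelu(gate))
```

Compared with the vanilla Linear, GELU, Linear block, the gate lets each hidden unit scale its own value elementwise, at the cost of a wider input projection.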


Resources

Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.

  1. Illustrated Transformer - Jay Alammar

  2. Transformers from Scratch - Peter Bloem

  3. The Annotated Transformer - Harvard NLP


I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines. — Claude Shannon


