lucidrains / vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Accessing last layer hidden states or embeddings for models like CrossViT, RegionViT (Extractor doesn't seem to work)

PrithivirajDamodaran opened this issue

How can I access the last layer hidden states aka embeddings of an image from models like CrossViT and RegionViT? The extractor option works only on vanilla ViT.

Please advise.
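
For context, the Extractor usage that does work on vanilla ViT looks like this (a minimal sketch following the README; the constructor values are illustrative):

import torch
from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

v = Extractor(v)

img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)  # embeddings: (1, 65, 1024) - 64 patch tokens + 1 CLS token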

@PrithivirajDamodaran Hi Prithivida! Let me know if 4e62e5f works now
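
Assuming that commit makes the Extractor compatible with CrossViT, usage would presumably look like the following sketch (constructor values as in the README; the exact shapes of the returned embeddings are something to verify on your version):

import torch
from vit_pytorch.cross_vit import CrossViT
from vit_pytorch.extractor import Extractor

model = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,               # number of multi-scale encoding blocks
    sm_dim = 192,            # high res dimension
    sm_patch_size = 16,      # high res patch size (should be smaller than lg_patch_size)
    sm_enc_depth = 2,        # high res depth
    sm_enc_heads = 8,        # high res heads
    sm_enc_mlp_dim = 2048,   # high res feedforward dimension
    lg_dim = 384,            # low res dimension
    lg_patch_size = 64,      # low res patch size
    lg_enc_depth = 3,        # low res depth
    lg_enc_heads = 8,        # low res heads
    lg_enc_mlp_dim = 2048,   # low res feedforward dimension
    cross_attn_depth = 2,    # cross attention rounds
    cross_attn_heads = 8     # cross attention heads
)

v = Extractor(model)

img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)  # embeddings from the two branches (verify shapes)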

RegionViT can also work, if you pass in a reference to the layer whose output you would like to extract

import torch
from vit_pytorch.regionvit import RegionViT

model = RegionViT(
    dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
    depth = (2, 2, 8, 2),           # depth of the region to local transformer at each stage
    window_size = 7,                # window size, which should be either 7 or 14
    num_classes = 1000,             # number of output classes
    tokenize_local_3_conv = False,  # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
    use_peg = False,                # whether to use positional generating module. they used this for object detection for a boost in performance
)

# wrap the RegionViT with the Extractor

from vit_pytorch.extractor import Extractor
v = Extractor(model, layer = model.layers[-1][-1])

# forward pass now returns predictions and the extracted embeddings

img = torch.randn(1, 3, 224, 224)
logits, embeddings = v(img)

# two tensors are returned, one from each of RegionViT's two information paths

embeddings # ((1, 512, 7, 7), (1, 512, 1, 1)) - local tokens, regional tokens
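
If only the embeddings are needed, the Extractor also accepts a return_embeddings_only flag per the README (treat the exact keyword as an assumption if your version differs):

v = Extractor(model, layer = model.layers[-1][-1], return_embeddings_only = True)

embeddings = v(img)  # embeddings only, no logits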

Thank you, will check and close. Big fan of your work.

Works fine! So just to be sure, for a single image the tuple below is

((1, 512, 7, 7) - last_layer emb
(1, 512, 1, 1)) - CLS emb

Is that the right understanding?

@PrithivirajDamodaran so RegionViT is a bit different from a conventional neural net in that it keeps two separate information paths and has them cross-attend to each other, iirc

what you are seeing are the outputs of those two separate paths: one is the normal network output, the other is the "regional" tokens

@PrithivirajDamodaran if you are doing anything downstream, i would concat those two together for a 1024-dimensional embedding

from einops import reduce

fine_embed, region_embed = embeddings  # the two paths returned by the Extractor above
embedding = torch.cat((reduce(fine_embed, 'b c h w -> b c', 'mean'), reduce(region_embed, 'b c h w -> b c', 'mean')), dim = -1)  # (1, 1024)

Excuse me, what if I need to remove the final classification layer, to get the features before classification?

Is there any help please?
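
One approach, sketched here under the assumption that the model exposes its classification head as an mlp_head attribute (vanilla ViT does; check the attribute name for other models): replace the head with nn.Identity so the forward pass returns the pooled features directly.

import torch
from torch import nn
from vit_pytorch import ViT

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# swap the classification head for an identity (assumes a `mlp_head` attribute)
model.mlp_head = nn.Identity()

img = torch.randn(1, 3, 256, 256)
features = model(img)  # (1, 1024) pooled features, before any classification

Alternatively, wrap the model with the Extractor as shown above to keep the logits and receive the embeddings alongside them.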