Code for Google's ViT and complete example
sayakpaul opened this issue
Hi @jacobgil!
Thank you for this amazing piece of work. I was wondering if you plan to open-source the code to try out your experiments on Google's ViT (An Image is Worth ...) as well. If it's already there inside the repo, could you point me to it?
Update: I was able to use timm and make use of the ViT model it comes with:
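```python
import timm
from vit_grad_rollout import VITAttentionGradRollout  # vit_grad_rollout.py in this repo

# x: preprocessed image tensor of shape (3, 384, 384); label_idx: target ImageNet class index
timm_vit_model = timm.create_model('vit_large_patch16_384', pretrained=True)
timm_vit_model.eval()
roller = VITAttentionGradRollout(timm_vit_model, discard_ratio=0.9)
mask = roller(x.unsqueeze(0), label_idx)
```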
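For readers wondering what the rollout object computes, here is a minimal NumPy sketch of gradient-weighted attention rollout. It is not the code in vit_grad_rollout.py. The function name `grad_rollout_sketch`, the input layout (one `(heads, tokens, tokens)` array per block for both attentions and gradients, CLS token at index 0), and the choice to protect column 0 during the discard step are all assumptions for illustration. Capturing those arrays from the model (for example with PyTorch hooks) is outside the sketch.

```python
import numpy as np


def grad_rollout_sketch(attentions, gradients, discard_ratio=0.9):
    """Hypothetical gradient-weighted attention rollout (illustrative sketch).

    attentions, gradients: lists with one array per transformer block,
    each of shape (heads, tokens, tokens); token 0 is the CLS token.
    Returns a (grid, grid) mask scaled so its maximum is 1.
    """
    num_tokens = attentions[0].shape[-1]
    eye = np.eye(num_tokens)
    result = eye.copy()
    for attn, grad in zip(attentions, gradients):
        # Weight each head's attention by its gradient, average over heads,
        # and keep only positive contributions.
        fused = np.maximum((attn * grad).mean(axis=0), 0.0)

        # Zero the weakest `discard_ratio` fraction of entries, but never
        # drop attention *to* the CLS token (column 0): an assumed choice.
        flat = fused.flatten()
        n_drop = int(flat.size * discard_ratio)
        drop = np.argsort(flat)[:n_drop]
        drop = drop[drop % num_tokens != 0]
        flat[drop] = 0.0
        fused = flat.reshape(num_tokens, num_tokens)

        # Model the residual connection, then re-normalise each row.
        a = (fused + eye) / 2.0
        a = a / a.sum(axis=-1, keepdims=True)
        result = a @ result

    # Accumulated attention flowing from the CLS token to each image patch.
    mask = result[0, 1:]
    grid = int(round(np.sqrt(mask.size)))
    mask = mask.reshape(grid, grid)
    peak = mask.max()
    return mask / peak if peak > 0 else mask
```

For the model in this thread, `vit_large_patch16_384` splits a 384×384 input into 384 / 16 = 24 patches per side, so there are 24 × 24 + 1 = 577 tokens and the sketch would return a 24×24 mask.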
However, I am still a bit unsure as to how to actually visualize the mask. Could you help?
Hi,
There is an example in vit_explain.py. Once you have the mask, you can do:
```python
mask = show_mask_on_image(img, mask)
```
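As a rough illustration of the overlay step, here is a NumPy-only sketch. It is not the repository's `show_mask_on_image`, which relies on OpenCV. The hypothetical `overlay_mask_sketch` first brings the patch-level mask up to the image resolution by repeating each cell over its patch (nearest-neighbour upsampling), then uses a plain red channel (assuming RGB order) in place of a full colour map, blends additively, and rescales so the brightest value maps to 255.

```python
import numpy as np


def overlay_mask_sketch(img, mask, patch_size=16):
    """Hypothetical NumPy-only overlay of a patch-level mask on an image.

    img:  uint8 array of shape (H, W, 3) in RGB order, values in [0, 255].
    mask: float array of shape (H // patch_size, W // patch_size), values in [0, 1].
    Returns a uint8 array of shape (H, W, 3).
    """
    # Nearest-neighbour upsampling: repeat each mask cell over its patch.
    up = np.kron(mask, np.ones((patch_size, patch_size)))
    if up.shape != img.shape[:2]:
        raise ValueError(f"mask upsamples to {up.shape}, image is {img.shape[:2]}")

    # Plain red heatmap standing in for a full colour map.
    heatmap = np.zeros(img.shape, dtype=np.float32)
    heatmap[..., 0] = up

    # Additive blend, then rescale so the brightest value becomes 255.
    cam = heatmap + img.astype(np.float32) / 255.0
    peak = cam.max()
    if peak > 0:
        cam = cam / peak
    return (255 * cam).astype(np.uint8)
```

With `vit_large_patch16_384`, a 24×24 mask and a 384×384 RGB image line up exactly at `patch_size=16`.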
Did it work out?
It does, thanks.