jacobgil / vit-explain

Explainability for Vision Transformers

Repository from Github https://github.com/jacobgil/vit-explain

Cannot use block.attn.fused_attn = False in another ViT model

kristosh opened this issue

I am trying to run the code for another ViT model, and more specifically:

    import torch
    import torchvision

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Get pretrained weights for ViT-Base (requires torchvision >= 0.13; "DEFAULT" means best available)
    pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)

    # pretrained_vit_1 = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True).to(device)

In this model I have noticed that I cannot use the following code:

    for block in model.blocks:
        block.attn.fused_attn = False

This fails because the model does not have the same structure as deit_tiny_patch16_224. I am also not sure how to do the equivalent of the fused_attn setting in this model. Can you explain a bit what this code does and why it produces different results when I comment it out?
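
For context, a minimal sketch of a structure-independent version of that loop. It rests on some assumptions. In timm-style ViTs, `fused_attn` appears to select a fused scaled-dot-product attention kernel that never materializes the attention matrix, so the `attn_drop` module that attention hooks rely on is bypassed while the flag is set. That would explain the different results. The torchvision `vit_b_16` model, as far as I can tell, keeps its blocks under `encoder.layers` and uses `nn.MultiheadAttention`, which has no such flag. The helper name `disable_fused_attn` and the stub classes are hypothetical and only stand in for a real model so the snippet runs without any dependencies.

```python
class StubAttention:
    """Hypothetical stand-in for an attention layer exposing a fused_attn flag."""

    def __init__(self):
        self.fused_attn = True


class StubModel:
    """Hypothetical stand-in for a model; torch.nn.Module provides modules() itself."""

    def __init__(self, layers):
        self.layers = layers

    def modules(self):
        # Mimic nn.Module.modules(): yield the model, then its submodules.
        yield self
        yield from self.layers


def disable_fused_attn(model):
    """Set fused_attn = False on every submodule that has the flag.

    Returns the number of submodules changed. A result of 0 means the
    model exposes no such flag, so there is nothing to switch off this way.
    """
    changed = 0
    for module in model.modules():
        if hasattr(module, "fused_attn"):
            module.fused_attn = False
            changed += 1
    return changed


with_flag = StubModel([StubAttention(), StubAttention()])
without_flag = StubModel([object(), object()])

print(disable_fused_attn(with_flag))     # 2
print(disable_fused_attn(without_flag))  # 0
```

With a real model, the same function can be called as `disable_fused_attn(pretrained_vit)`. A return value of 0 would indicate that the attention weights have to be obtained some other way for that architecture.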