lucidrains / vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Loading MAE pre-trained model for image classification fine tuning.

jdegange opened this issue

Hi. I really appreciate the repo; it has accelerated some of my work on experimenting with ViTs.

I have pre-trained a model using MAE and the dogs-vs-cats dataset provided in the example notebook. I use the following model definition:

import torch
from vit_pytorch import ViT
from vit_pytorch.mae import MAE

# img_size and device are defined earlier in the notebook
v = ViT(
    image_size = img_size,
    patch_size = 32,
    num_classes = 2,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mae = MAE(
    encoder = v,
    masking_ratio = 0.75,   # the paper recommended 75% masked patches
    decoder_dim = 512,      # paper showed good results with just 512
    decoder_depth = 6       # anywhere from 1 to 8
)

mae.to(device)

I then pre-train the model for a few epochs (the loop is sketched below) and save the encoder weights:

preTrainSaveOutputPath = './preTrained-vitMAE-dogsVsCats.pt'
torch.save(v.state_dict(), preTrainSaveOutputPath)
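
The pre-training loop itself is roughly the following sketch; the optimizer, learning rate, epoch count, and train_loader are placeholders standing in for my actual notebook settings. The relevant part, per the repo README, is that calling mae(images) returns the reconstruction loss to backpropagate.

opt = torch.optim.Adam(mae.parameters(), lr = 3e-4)   # placeholder optimizer / learning rate

for epoch in range(num_epochs):
    for images, _ in train_loader:        # labels are ignored during MAE pre-training
        images = images.to(device)
        loss = mae(images)                # forward pass returns the reconstruction loss
        opt.zero_grad()
        loss.backward()
        opt.step()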

Afterwards, I load the weights from disk and instantiate a binary classifier for fine-tuning on the original dog/cat labels:

preTrainedModel = ViT(
    image_size = img_size,
    patch_size = 32,
    num_classes = 2,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

preTrainedModel.load_state_dict(torch.load(preTrainSaveOutputPath))

preTrainedModel.to(device)
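
The fine-tuning step is a standard supervised loop with cross-entropy loss, roughly the sketch below (again, the optimizer, learning rate, epoch count, and train_loader are placeholders for my actual settings):

opt = torch.optim.Adam(preTrainedModel.parameters(), lr = 3e-4)   # placeholder optimizer / learning rate
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        logits = preTrainedModel(images)      # ViT forward pass returns class logits
        loss = criterion(logits, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()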

When I run this fine-tuning, however, the loss does not go down, even after several hundred epochs, and the performance is actually worse than training the model on the dog/cat labels from scratch.

Could you let me know if I am doing something incorrectly here? I have not been able to figure it out after several days on my own and would appreciate the help. It is likely something trivial I am missing.