LTH14 / mage

A PyTorch implementation of MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis


AssertionError happens when loading the model in the workflow of “finetuning-then-linprobing”.

uk9921 opened this issue · comments

commented

I used the script main_finetune.py to fine-tune the pretrained model, and the process went smoothly. However, when I tried to load the fine-tuned model and train a linear probe, I got this AssertionError:

    File "main_linprobe.py", line 203, in main
        assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}

I printed msg.missing_keys and got msg.missing_keys = [].

So I wonder: do we need to assert on the missing keys when loading a fine-tuned model?

commented

Here are my training args:

    main_linprobe.py \
    --batch_size 128 \
    --model vit_large_patch16 \
    --global_pool \
    --finetune ${PRETRAIN_CHKPT} \
    --epochs 90 \
    --blr 0.05 \
    --weight_decay 0.0 \
    --output_dir ${OUTPUT_DIR} \
    --data_path ${IMAGENET_DIR} \
    --dist_eval

The linear probing script is designed to load a pre-trained model (without the linear classification head and the norm layer before it). Therefore, we have an assertion there to make sure the loaded model does not have those parameters. A fine-tuned checkpoint already contains them, so nothing is reported as missing and the assertion fails. If you want to linear probe a model that already has the classification head, I think you can simply comment out the assertion.
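As an alternative to commenting the assertion out, the check could be relaxed so that it accepts both kinds of checkpoint. The following is only a minimal sketch: the helper `check_missing_keys` and its `allow_finetuned` parameter are hypothetical names that do not exist in the repository, and the expected key set is copied from the assertion quoted in the traceback above.

```python
# Hypothetical helper, not part of the MAGE repository.
# The expected keys are taken from the assertion quoted in this issue.
EXPECTED_MISSING = {"head.weight", "head.bias", "fc_norm.weight", "fc_norm.bias"}


def check_missing_keys(missing_keys, allow_finetuned=False):
    """Classify a checkpoint by the keys that were missing when it was loaded.

    Returns "pretrained" if exactly the head and fc_norm parameters were
    missing, or "finetuned" if nothing was missing and allow_finetuned is set.
    Raises AssertionError for any other combination.
    """
    missing = set(missing_keys)
    if missing == EXPECTED_MISSING:
        return "pretrained"
    if allow_finetuned and not missing:
        return "finetuned"
    raise AssertionError(f"unexpected missing keys: {sorted(missing)}")


if __name__ == "__main__":
    # Pre-trained checkpoint: head and fc_norm are absent, as the script expects.
    print(check_missing_keys(sorted(EXPECTED_MISSING)))
    # Fine-tuned checkpoint: nothing is missing, accepted only when opted in.
    print(check_missing_keys([], allow_finetuned=True))
```

In main_linprobe.py this would presumably be called with `msg.missing_keys` in place of the original assert. Keeping the strict behaviour as the default preserves the safety check for the pre-trained workflow.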

commented

Thank you for your reply. I followed your suggestion and achieved a higher top-1 accuracy than when directly loading the pre-trained model.