AssertionError happens when loading the model in the workflow of “finetuning-then-linprobing”.
uk9921 opened this issue · comments
I used the script main_finetune.py to finetune the pretrained model, and the process went very smoothly. However, when I tried to load the finetuned model and train a linear probe task, I got this AssertionError:
File "main_linprobe.py", line 203, in main
    assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
I printed msg.missing_keys and got an empty list: msg.missing_keys = []
So, I wonder if we need to assert the missing keys when we try to load the finetuned model?
Here are my training args:
main_linprobe.py \
--batch_size 128 \
--model vit_large_patch16 \
--global_pool \
--finetune ${PRETRAIN_CHKPT} \
--epochs 90 \
--blr 0.05 \
--weight_decay 0.0 \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_eval
The linear probing script is designed to load a pre-trained model (without the linear classification head and the norm layer before it). Therefore, we have an assertion there to make sure the loaded model does not have those parameters. A fine-tuned checkpoint already contains them, so msg.missing_keys is empty and the assertion fails. If you want to linear probe a model with the classification head, I think you can simply comment out the assertion.
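As an alternative to commenting the assertion out entirely, a relaxed check could accept both kinds of checkpoint while still catching a checkpoint that lacks backbone weights. The sketch below is only an illustration: the helper names `unexpected_missing_keys` and `check_missing_keys` are hypothetical and not part of the repository, and the key set is copied from the assertion quoted in the traceback above.

```python
# Keys that the assertion quoted in this issue expects to be absent
# from a pre-trained checkpoint.
HEAD_KEYS = {"head.weight", "head.bias", "fc_norm.weight", "fc_norm.bias"}


def unexpected_missing_keys(missing_keys, allowed=HEAD_KEYS):
    """Return, sorted, the missing keys that are not head/norm parameters."""
    return sorted(set(missing_keys) - set(allowed))


def check_missing_keys(missing_keys, allowed=HEAD_KEYS):
    """Relaxed check (hypothetical): missing keys must be a subset of `allowed`."""
    extra = unexpected_missing_keys(missing_keys, allowed)
    assert not extra, f"unexpected missing keys: {extra}"


# Pre-trained checkpoint: all four head/norm keys are missing, so the check passes.
check_missing_keys(["head.weight", "head.bias", "fc_norm.weight", "fc_norm.bias"])
# Fine-tuned checkpoint, as in this issue: nothing is missing, so the check passes too.
check_missing_keys([])
```

Assuming `msg` is the result of the `load_state_dict` call as in the traceback, the equality assertion could then be swapped for `check_missing_keys(msg.missing_keys)`. The subset test is a design choice rather than the authors' method: it trades the strict guarantee that the head is absent for the ability to load checkpoints that already include it.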
Thank you for your reply. I followed your suggestion and achieved a higher top-1 accuracy than when directly loading the pre-trained model.