eminorhan / mae

Masked Autoencoders for image-based SSL

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Masked Autoencoders (MAE)

This is my personal copy of Facebook's Masked Autoencoders (MAE) repository for image-based SSL customized for my own purposes. The code here can be used to train and evaluate MAEs.

Usage examples

  • Training: To train an MAE model with a ViT-S/16 architecture from scratch on your data, use train_mae.py:
python -u train_mae.py \
	--model 'mae_vit_small_patch16' \
	--batch_size_per_gpu 512 \
	--num_workers 16 \
	--lr 0.0003 \
	--min_lr 0.0003 \
	--output_dir OUTPUT_DIR \
	--data_path DATA_PATH \
	--save_prefix INFORMATIVE_SAVE_PREFIX

This version uses the webdataset interface to feed the data into the model. There's a separete training file that uses the standard torch-torchvision data loading interface instead, if you'd prefer to use that: train_mae_nowds.py.

  • Linear evaluation: To evaluate an MAE model with the linear probing approach, use eval_linear.py:
python -u eval_linear.py \
	--model 'vit_small_patch16' \
	--resume MODEL_PATH \
	--save_prefix INFORMATIVE_SAVE_PREFIX \
	--batch_size 1024 \
	--epochs 50 \
	--num_workers 16 \
	--lr 0.0003 \
	--output_dir OUTPUT_DIR \
	--train_data_path TRAIN_DATA_PATH \
	--val_data_path VAL_DATA_PATH \
	--num_labels 1000
  • Finetuning evaluation: To evaluate an MAE model with the finetuning approach, use eval_finetune.py:
python -u eval_finetune.py \
	--model 'vit_small_patch16' \
	--resume MODEL_PATH \
	--save_prefix INFORMATIVE_SAVE_PREFIX \
	--batch_size 128 \
	--epochs 50 \
	--num_workers 16 \
	--lr 0.0001 \
	--output_dir OUTPUT_DIR \
	--train_data_path TRAIN_DATA_PATH \
	--val_data_path VAL_DATA_PATH \
	--frac_retained 0.010147 \
	--num_labels 1000

Here frac_retained is the fraction of the training set used for finetuning and can be set to do few-shot finetuning evals (e.g. --frac_retained 0.01 corresponds to finetuning with 1% of the training data, i.e. 12-13 examples per class in the case of ImageNet).

About

Masked Autoencoders for image-based SSL

License:Other


Languages

Language:Python 74.8%Language:Shell 25.2%