Luffy03 / SatMAE

Official code repository for NeurIPS 2022 paper "SatMAE: Pretraining Transformers for Temporal and Multi-Spectral Satellite Imagery"

SatMAE (NeurIPS 2022)

Project | Paper | Video

This is the official repository for the NeurIPS 2022 paper "SatMAE: Pre-training Transformers for Temporal and Multi-Spectral Satellite Imagery".

Authors: Yezhen Cong¹, Samar Khanna¹, Chenlin Meng, Patrick Liu, Erik Rozi, Yutong He, Marshall Burke, David B. Lobell, Stefano Ermon.

¹ Equal contribution, order determined via coin flip.

Temporal SatMAE

Pre-training and finetuning on fMoW-Temporal are MEMORY-HEAVY. Please make sure you have enough memory. For context, we ran our experiments on 8 NVIDIA V100 GPUs.


You can download the fMoW dataset here. Then follow this piece of code to preprocess it. We will also upload the pre-processed dataset soon. The metadata files are here.

After you download the dataset and metadata files, your directory should look like:

--- train_62classes.csv
--- val_62classes.csv
--- fmow
------- train
---------- airport
---------- ...
------- val
---------- airport
---------- ...
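
Before launching a long job, it can help to confirm that the layout above is in place. The following is a minimal sketch, not part of the repository: the helper name `missing_entries` is hypothetical, and the expected entries are taken only from the directory listing above.

```python
from pathlib import Path

# Entries taken from the directory listing above.
EXPECTED = [
    "train_62classes.csv",
    "val_62classes.csv",
    "fmow/train",
    "fmow/val",
]


def missing_entries(root):
    """Return the expected files/folders that are absent under `root`."""
    root = Path(root)
    return [rel for rel in EXPECTED if not (root / rel).exists()]
```

Calling `missing_entries("<PATH_TO_DATASET_ROOT_FOLDER>")` returns an empty list when every expected entry exists, and otherwise lists what is missing.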


For pretraining, this is the default command:

python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=1 --master_port=1234 \
    --batch_size 8 --accum_iter 16 \
    --norm_pix_loss --epochs 100 \
    --blr 1.5e-4 --mask_ratio 0.75 \
    --input_size 224 --patch_size 16 \
    --model mae_vit_large_patch16 \
    --model_type temporal \
    --dataset_type temporal \
    --train_path <PATH_TO_DATASET_ROOT_FOLDER>/train_62classes.csv \
    --output_dir <PATH_TO_YOUR_OUTPUT_FOLDER> \
    --log_dir <PATH_TO_YOUR_OUTPUT_FOLDER> \
    --num_workers 8
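
If you change the number of GPUs, `--batch_size`, or `--accum_iter`, the effective batch size changes with them. The sketch below shows the arithmetic, assuming the training script follows the convention of the upstream MAE codebase (effective batch = per-GPU batch × accumulation steps × number of GPUs, and absolute learning rate = `blr` × effective batch / 256). That convention is an assumption here, not something this README states, and the function names are illustrative.

```python
def effective_batch_size(batch_size, accum_iter, world_size):
    """Samples contributing to one optimizer step across all GPUs."""
    return batch_size * accum_iter * world_size


def absolute_lr(blr, eff_batch):
    """Scale the base learning rate linearly with the effective batch size
    (MAE-style rule, assumed rather than confirmed for this repository)."""
    return blr * eff_batch / 256


# Values from the default temporal pretraining command above.
eff = effective_batch_size(batch_size=8, accum_iter=16, world_size=8)
print(eff)                       # 1024
print(absolute_lr(1.5e-4, eff))  # roughly 6e-4
```

Under these assumptions, halving the number of GPUs while doubling `--accum_iter` keeps the effective batch size, and therefore the learning rate, unchanged.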

Note: to use wandb, run wandb login in the shell after activating your conda environment and before running the code, then add --wandb <YOUR_WANDB_PROJECT_NAME> to the command above. This applies to all following commands. You will also have to edit the entity name in the code.


To finetune, the basic command is:

python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=1 --master_port=1234 \
    --output_dir <PATH_TO_YOUR_OUTPUT_FOLDER> \
    --log_dir <PATH_TO_YOUR_OUTPUT_FOLDER> \
    --batch_size 4 --accum_iter 4 \
    --model vit_large_patch16 --epochs 50 --blr 1e-3 --layer_decay 0.75 \
    --weight_decay 0.05 --drop_path 0.2 --reprob 0.25 \
    --mixup 0.8 --cutmix 1.0 --model_type temporal \
    --dist_eval --num_workers 8 --dataset temporal \
    --train_path <PATH_TO_DATASET_ROOT_FOLDER>/train_62classes.csv \
    --test_path <PATH_TO_DATASET_ROOT_FOLDER>/val_62classes.csv \
    --finetune <PATH_TO_YOUR_PRETRAIN_CHECKPOINT>

Note: If you are using our provided checkpoint, please add --nb_classes 1000. This is a legacy issue that does not affect model performance, since the actual number of classes is less than 1000. To resume a finetuning job, replace --finetune <PATH_TO_YOUR_PRETRAIN_CHECKPOINT> with --resume <PATH_TO_YOUR_FINETUNE_CHECKPOINT> in the command above. To finetune from scratch, simply omit the --finetune argument.


To evaluate, the basic command is:

python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=1 --master_port=1234 \
    --output_dir <PATH_TO_YOUR_OUTPUT_FOLDER> \
    --log_dir <PATH_TO_YOUR_OUTPUT_FOLDER> \
    --batch_size 16 \
    --model vit_large_patch16 \
    --model_type temporal \
    --resume <PATH_TO_YOUR_FINETUNE_CHECKPOINT> \
    --dist_eval --eval --num_workers 8 --dataset fmow_temporal \
    --train_path <PATH_TO_DATASET_ROOT_FOLDER>/train_62classes.csv \
    --test_path <PATH_TO_DATASET_ROOT_FOLDER>/val_62classes.csv

Similarly, if you are using our provided checkpoint, please add --nb_classes 1000.

Model Weights

You can download model weights pre-trained on fMoW-temporal and weights finetuned on fMoW-temporal here.

Multi-Spectral SatMAE

Training multi-spectral SatMAE is similar to training temporal SatMAE.

fMoW-Sentinel Dataset

You can access and download the fMoW-Sentinel dataset we collected here. Try this link if the previous one doesn't display correctly.

Note that when loading the train.csv or val.csv files, you may have to preprocess a column called image_path. The image_path for any row can be constructed like this:



For pretraining, this is the default command:

python -m torch.distributed.launch --nproc_per_node=8 \
--wandb satmae_pretrain \
--batch_size 16 --accum_iter 32 --blr 0.0001 \
--epochs 200 --warmup_epochs 20 --num_workers 16 \
--input_size 96 --patch_size 8 \
--model_type group_c \
--dataset_type sentinel --dropped_bands 0 9 10 \
--train_path /home/fmow-sentinel-filtered-csv/train.csv \
--output_dir /home/experiments/pretrain \
--log_dir /home/experiments/pretrain
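
The `--input_size` and `--patch_size` flags determine how many patches the encoder sees per image, and the mask ratio determines how many of them stay visible. The sketch below works through that arithmetic, assuming square inputs, non-overlapping patches, and MAE-style masking that keeps `int(N * (1 - mask_ratio))` patches. The function names are illustrative, and how the multi-spectral model counts patches across band groups is not covered here.

```python
def num_patches(input_size, patch_size):
    """Number of non-overlapping square patches in a square image."""
    if input_size % patch_size != 0:
        raise ValueError("input_size must be divisible by patch_size")
    side = input_size // patch_size
    return side * side


def num_visible(n_patches, mask_ratio):
    """Patches left unmasked (MAE-style truncation, assumed here)."""
    return int(n_patches * (1 - mask_ratio))


print(num_patches(96, 8))       # 144 (multi-spectral settings above)
print(num_patches(224, 16))     # 196 (temporal settings)
print(num_visible(196, 0.75))   # 49
```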

To pass additional arguments (for example a different mask ratio, custom band groups, or a checkpoint to resume from), extend the command like so:

python -m torch.distributed.launch --nproc_per_node=8 \
--wandb satmae_pretrain \
--batch_size 16 --accum_iter 32 --blr 0.0001 \
--epochs 200 --warmup_epochs 20 --num_workers 16 \
--input_size 96 --patch_size 8 \
--mask_ratio 0.9 --spatial_mask \
--norm_pix_loss \
--model_type group_c \
--dataset_type sentinel --dropped_bands 0 9 10 \
--grouped_bands 0 1 2 6 --grouped_bands 3 4 5 7 --grouped_bands 8 9 \
--train_path /home/fmow-sentinel-filtered-csv/train.csv \
--output_dir /home/experiments/pretrain \
--log_dir /home/experiments/pretrain \
--resume /home/experiments/pretrain/checkpoint-175.pth
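
The `--dropped_bands` and `--grouped_bands` values in the command above need to agree with each other. The sketch below checks that, under two assumptions that should be verified against the data-loading code: the input has 13 Sentinel-2 bands, and the group indices refer to channel positions after the dropped bands are removed. The helper name `check_band_groups` is hypothetical.

```python
def check_band_groups(total_bands, dropped_bands, grouped_bands):
    """Return the number of kept channels.

    Raises ValueError unless the groups cover every kept channel
    position exactly once.
    """
    dropped = set(dropped_bands)
    kept = [b for b in range(total_bands) if b not in dropped]
    flat = [i for group in grouped_bands for i in group]
    if sorted(flat) != list(range(len(kept))):
        raise ValueError(
            f"groups {grouped_bands} do not partition {len(kept)} kept channels"
        )
    return len(kept)


# Values from the command above: drop bands 0, 9, 10, then form three groups.
n = check_band_groups(13, [0, 9, 10], [[0, 1, 2, 6], [3, 4, 5, 7], [8, 9]])
print(n)  # 10
```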


To finetune, the basic command is:

python -m torch.distributed.launch --nproc_per_node=8 \
--wandb satmae_finetune \
--batch_size 8 --accum_iter 16 --blr 0.0002 \
--epochs 30 --num_workers 16 \
--input_size 96 --patch_size 8  \
--weight_decay 0.05 --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
--model_type group_c  \
--dataset_type sentinel --dropped_bands 0 9 10 \
--train_path /home/fmow-sentinel-filtered-csv/train.csv \
--test_path /home/fmow-sentinel-filtered-csv/val.csv \
--output_dir /home/experiments/finetune \
--log_dir /home/experiments/finetune \
--finetune /home/experiments/pretrain/checkpoint-199.pth

To finetune from scratch, simply omit the --finetune argument. To resume a finetuning job, you can replace --finetune path/to/pretrained_checkpoint.pth with --resume path/to/finetune_checkpoint.pth in the command above.
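
The three ways of starting a finetuning run differ only in which checkpoint flag is passed. If you launch jobs from a script, a small helper like the one below can keep that choice explicit. This is only a sketch: the flag names come from the commands above, while the helper name and mode labels are made up for illustration.

```python
def checkpoint_args(mode, checkpoint=None):
    """Return the extra CLI arguments for a finetuning run.

    mode: "scratch"  -> no checkpoint flag
          "finetune" -> --finetune <pretrained checkpoint>
          "resume"   -> --resume <finetuning checkpoint>
    """
    if mode == "scratch":
        return []
    if mode in ("finetune", "resume"):
        if not checkpoint:
            raise ValueError(f"mode '{mode}' requires a checkpoint path")
        return [f"--{mode}", checkpoint]
    raise ValueError(f"unknown mode: {mode}")


print(checkpoint_args("finetune", "path/to/pretrained_checkpoint.pth"))
print(checkpoint_args("resume", "path/to/finetune_checkpoint.pth"))
print(checkpoint_args("scratch"))
```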

Model Weights



Code from this repository is inspired by the Masked Autoencoders (MAE) repository (link).


If you found our project helpful, please cite our paper:

@inproceedings{satmae2022,
    title={Sat{MAE}: Pre-training Transformers for Temporal and Multi-Spectral Satellite Imagery},
    author={Yezhen Cong and Samar Khanna and Chenlin Meng and Patrick Liu and Erik Rozi and Yutong He and Marshall Burke and David B. Lobell and Stefano Ermon},
    booktitle={Thirty-Sixth Conference on Neural Information Processing Systems},
    year={2022}
}

