zihangJiang / TokenLabeling

PyTorch implementation of "All Tokens Matter: Token Labeling for Training Better Vision Transformers"

All Tokens Matter: Token Labeling for Training Better Vision Transformers (arxiv)

This is a PyTorch implementation of our paper.

Comparison between the proposed LV-ViT and other recent transformer-based models. Note that we only show models with fewer than 100M parameters.

Our code is based on pytorch-image-models by Ross Wightman.

Update

2021.7: Add script to generate label data.

2021.6: Support pip install tlt to use our Token Labeling Toolbox for image models.

2021.6: Release training code and segmentation model.

2021.4: Release LV-ViT models.

LV-ViT Models

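| Model | Layers | Dim | Image resolution | Params | Top-1 (%) | Download |
|---|---|---|---|---|---|---|
| LV-ViT-T | 12 | 240 | 224 | 8.53M | 79.1 | link |
| LV-ViT-S | 16 | 384 | 224 | 26.15M | 83.3 | link |
| LV-ViT-S | 16 | 384 | 384 | 26.30M | 84.4 | link |
| LV-ViT-M | 20 | 512 | 224 | 55.83M | 84.0 | link |
| LV-ViT-M | 20 | 512 | 384 | 56.03M | 85.4 | link |
| LV-ViT-M | 20 | 512 | 448 | 56.13M | 85.5 | link |
| LV-ViT-L | 24 | 768 | 448 | 150.47M | 86.2 | link |
| LV-ViT-L | 24 | 768 | 512 | 150.66M | 86.4 | link |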

Requirements

torch>=1.4.0 torchvision>=0.5.0 pyyaml scipy timm==0.4.5

Data preparation: ImageNet with the following folder structure (you can extract ImageNet with this script):

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......

Validation

Replace /path/to/imagenet/val with your ImageNet validation set path and /path/to/checkpoint with the checkpoint path:

CUDA_VISIBLE_DEVICES=0 bash eval.sh /path/to/imagenet/val /path/to/checkpoint

Label data

We provide dense label maps generated by NFNet-F6 on Google Drive and Baidu Yun (password: y6j2). Since NFNet-F6 is trained on ImageNet only, no extra training data is involved.
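A dense label map stores a class-score vector for every spatial location of the teacher's output, and keeping all 1000 ImageNet scores per location is costly. One common way to keep such maps compact is to store only the top-k (class, score) pairs per location. The sketch below illustrates that idea on a tiny score map; the function name and the nested-list layout are assumptions for illustration, not the repo's actual file format.

```python
def topk_per_location(score_map, k=5):
    """Compress an H x W x C score map to the k best (class, score) pairs per cell.

    score_map: nested lists where score_map[y][x] is a list of C class scores.
    Returns a nested list of the same H x W shape whose cells hold k
    (class_index, score) tuples sorted by descending score.
    """
    compressed = []
    for row in score_map:
        new_row = []
        for scores in row:
            # Rank class indices by score, highest first, and keep the first k.
            best = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)[:k]
            new_row.append([(c, scores[c]) for c in best])
        compressed.append(new_row)
    return compressed


# A 1 x 2 map over 4 classes, keeping the top 2 classes per location.
demo = [[[0.1, 0.7, 0.15, 0.05], [0.6, 0.1, 0.1, 0.2]]]
print(topk_per_location(demo, k=2))
# [[[(1, 0.7), (2, 0.15)], [(0, 0.6), (3, 0.2)]]]
```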

Training

Train the LV-ViT-S:

If only 4 GPUs are available:

CUDA_VISIBLE_DEVICES=0,1,2,3 ./distributed_train.sh 4 /path/to/imagenet --model lvvit_s -b 256 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema

If 8 GPUs are available:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_s -b 128 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema
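In timm's train.py, which this repo builds on, -b is the batch size per process (one process per GPU), so both LV-ViT-S commands use the same global batch of 4 × 256 = 8 × 128 = 1024. If you run on a different number of GPUs and want to keep that global batch, a hypothetical helper like this gives the matching -b value (assuming the resulting per-GPU batch still fits in memory):

```python
def per_gpu_batch(global_batch, num_gpus):
    """Per-GPU -b value that keeps the global batch size fixed."""
    if global_batch % num_gpus:
        raise ValueError(f"{global_batch} is not divisible by {num_gpus} GPUs")
    return global_batch // num_gpus


print(per_gpu_batch(1024, 2))  # 512
```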

Train the LV-ViT-M and LV-ViT-L (run on 8 GPUs):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_m -b 128 --apex-amp --img-size 224 --drop-path 0.2 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_l -b 128 --lr 1.e-3 --aa rand-n3-m9-mstd0.5-inc1 --apex-amp --img-size 224 --drop-path 0.3 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema

To train LV-ViT on images at 384x384 resolution, use --img-size 384 --token-label-size 24.
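The token-label size is the side length of the patch-token grid. The pairs used in this README (224 → 14, 384 → 24) are consistent with an overall patch stride of 16, so under that assumption the value can be derived from the image size. The helper below is hypothetical and not part of the repo:

```python
def token_label_size(img_size, patch_stride=16):
    """Side length of the token grid, i.e. a candidate value for --token-label-size.

    Assumes the patch embedding has an overall stride of 16, matching the
    224 -> 14 and 384 -> 24 pairs used in this README.
    """
    if img_size % patch_stride:
        raise ValueError(f"img_size {img_size} is not a multiple of {patch_stride}")
    return img_size // patch_stride


for size in (224, 384, 448, 512):
    print(size, token_label_size(size))
# 224 14
# 384 24
# 448 28
# 512 32
```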

Fine-tuning

To fine-tune the pre-trained LV-ViT-S on images with 384x384 resolution:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_s -b 64 --apex-amp --img-size 384 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 24 --lr 5.e-6 --min-lr 5.e-6 --weight-decay 1.e-8 --finetune /path/to/checkpoint

To fine-tune the pre-trained LV-ViT-S on other datasets without token labeling (the dense token loss is switched off with --dense-weight 0.0):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/dataset --model lvvit_s -b 64 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-size 14 --dense-weight 0.0 --num-classes $NUM_CLASSES --finetune /path/to/checkpoint
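The paper's objective adds a per-token loss on the dense labels to the usual class-token loss, scaled by a balancing weight; --dense-weight appears to set that weight, so --dense-weight 0.0 leaves only the standard classification loss. Below is a minimal stdlib sketch of that combination; the function names are hypothetical, and the repo's real loss also handles mixup/CutMix targets and may differ in details.

```python
import math


def soft_cross_entropy(logits, target):
    """Cross-entropy between softmax(logits) and a (possibly soft) target distribution."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))  # stable log-sum-exp
    return -sum(t * (v - log_z) for v, t in zip(logits, target))


def token_labeling_loss(cls_logits, cls_target, token_logits, token_targets, dense_weight):
    """Class-token loss plus dense_weight times the mean per-token loss (sketch)."""
    cls_loss = soft_cross_entropy(cls_logits, cls_target)
    token_loss = sum(
        soft_cross_entropy(lg, tg) for lg, tg in zip(token_logits, token_targets)
    ) / len(token_logits)
    return cls_loss + dense_weight * token_loss


# With dense_weight=0.0 only the class-token term remains (here ln 2).
print(token_labeling_loss([0.0, 0.0], [1.0, 0.0], [[2.0, 0.0]], [[0.0, 1.0]], 0.0))
```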

Segmentation

Our segmentation models are fully based on the MMSegmentation toolkit. The model and config files are in the seg/ folder, which follows the same folder structure as MMSegmentation, so you can simply drop these files in to get started.

git clone https://github.com/open-mmlab/mmsegmentation # and install

cp seg/mmseg/models/backbones/vit.py mmsegmentation/mmseg/models/backbones/
cp -r seg/configs/lvvit mmsegmentation/configs/

# test upernet+lvvit_s (add --aug-test to test on multi scale)
cd mmsegmentation
./tools/dist_test.sh configs/lvvit/upernet_lvvit_s_512x512_160k_ade20k.py /path/to/checkpoint 8 --eval mIoU [--aug-test]
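
| Backbone | Method | Crop size | Lr schd | mIoU | mIoU (ms) | Pixel Acc. | Params | Download |
|---|---|---|---|---|---|---|---|---|
| LV-ViT-S | UperNet | 512x512 | 160k | 47.9 | 48.6 | 83.1 | 44M | link |
| LV-ViT-M | UperNet | 512x512 | 160k | 49.4 | 50.6 | 83.5 | 77M | link |
| LV-ViT-L | UperNet | 512x512 | 160k | 50.9 | 51.8 | 84.1 | 209M | link |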
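MMSegmentation computes these metrics for you; for reference, mIoU and pixel accuracy reduce to simple ratios over a per-pixel confusion matrix, and mIoU (ms) is the same metric under multi-scale testing (--aug-test). A plain-Python sketch, not MMSegmentation's implementation:

```python
def mean_iou(confusion):
    """Mean IoU from a square confusion matrix (rows: ground truth, cols: prediction).

    IoU_c = TP_c / (TP_c + FP_c + FN_c); classes with an empty union are skipped.
    """
    n = len(confusion)
    ious = []
    for c in range(n):
        tp = confusion[c][c]
        fn = sum(confusion[c]) - tp                       # class-c pixels predicted as another class
        fp = sum(confusion[r][c] for r in range(n)) - tp  # other pixels predicted as class c
        union = tp + fp + fn
        if union:
            ious.append(tp / union)
    return sum(ious) / len(ious) if ious else float("nan")


def pixel_accuracy(confusion):
    """Fraction of all pixels whose predicted class matches the ground truth."""
    correct = sum(confusion[i][i] for i in range(len(confusion)))
    return correct / sum(sum(row) for row in confusion)


cm = [[3, 1], [0, 4]]
print(mean_iou(cm), pixel_accuracy(cm))  # IoUs 0.75 and 0.8 -> mIoU 0.775; accuracy 0.875
```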

Visualization

We apply the visualization method from this repo to visualize the parts of the image that led to a certain classification for DeiT-Base and our LV-ViT-S. The image regions the network used to make its decision are highlighted in red.

Label generation

To generate token label data for training:

python3 generate_label.py /path/to/imagenet/train /path/to/save/label_top5_train_nfnet --model dm_nfnet_f6 --pretrained --img-size 576 -b 32 --crop-pct 1.0
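In timm's evaluation transform, --crop-pct controls preprocessing: the shorter side is resized to floor(img_size / crop_pct) and the centre img_size × img_size region is cropped. Assuming generate_label.py follows that convention (not confirmed here), --img-size 576 --crop-pct 1.0 resizes the shorter side straight to 576, so only the excess along the longer side is trimmed. A hypothetical helper that reproduces the arithmetic:

```python
import math


def eval_resize_and_crop(width, height, img_size=576, crop_pct=1.0):
    """Resized (w, h) and centre-crop box (left, top, right, bottom).

    Assumption: preprocessing mirrors timm's eval transform, i.e. resize the
    shorter side to floor(img_size / crop_pct), then centre crop img_size.
    """
    scale_size = math.floor(img_size / crop_pct)
    # Resize so the shorter side equals scale_size, keeping the aspect ratio.
    if width <= height:
        new_w, new_h = scale_size, int(scale_size * height / width)
    else:
        new_w, new_h = int(scale_size * width / height), scale_size
    # Centre crop offsets, rounded the same way torchvision's CenterCrop does.
    left = int(round((new_w - img_size) / 2.0))
    top = int(round((new_h - img_size) / 2.0))
    return (new_w, new_h), (left, top, left + img_size, top + img_size)


print(eval_resize_and_crop(500, 375))  # ((768, 576), (96, 0, 672, 576))
```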

Reference

If you use this repo or find it useful, please consider citing:

@inproceedings{NEURIPS2021_9a49a25d,
 author = {Jiang, Zi-Hang and Hou, Qibin and Yuan, Li and Zhou, Daquan and Shi, Yujun and Jin, Xiaojie and Wang, Anran and Feng, Jiashi},
 booktitle = {Advances in Neural Information Processing Systems},
 editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J. Wortman Vaughan},
 pages = {18590--18602},
 publisher = {Curran Associates, Inc.},
 title = {All Tokens Matter: Token Labeling for Training Better Vision Transformers},
 url = {https://proceedings.neurips.cc/paper/2021/file/9a49a25d845a483fae4be7e341368e36-Paper.pdf},
 volume = {34},
 year = {2021}
}

Related projects

T2T-ViT, Re-labeling ImageNet, MMSegmentation, Transformer Explainability.

License: Apache License 2.0

