StomachCold / HCTransformers



HCTransformers


PyTorch implementation for "Attribute Surrogates Learning and Spectral Tokens Pooling in Transformers for Few-shot Learning".
[arxiv]

Code will be continuously updated.

[Figure: HCT network architecture]

Updates

07/07/2022

  1. Dataset descriptions and guidelines have been updated.
  2. Features of miniImageNet extracted by our pretrained models are also provided here.

07/01/2022

Added download links for the pretrained weights and the evaluation command lines.

Prerequisites

This codebase has been developed with Python version 3.8, PyTorch version 1.9.0, CUDA 11.1 and torchvision 0.10.0. It has been tested on Ubuntu 20.04.
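Because the code was only tested against the versions above, it can help to compare them with your environment before training. The following is a small sketch using only the standard library (`importlib.metadata`); the helper names are hypothetical, and it compares only major.minor versions, which is an assumption about what "compatible" means. Once PyTorch is installed, the CUDA version it was built with can be read from `torch.version.cuda`.

```python
"""Compare installed package versions with the ones this codebase was tested on."""
import sys
from importlib.metadata import PackageNotFoundError, version

# Versions listed in the Prerequisites section (Python is reported separately).
TESTED = {"torch": "1.9.0", "torchvision": "0.10.0"}


def installed_versions(packages):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def mismatches(installed, tested):
    """Return {package: (installed, tested)} where major.minor differs or the package is missing."""
    diffs = {}
    for name, want in tested.items():
        have = installed.get(name)
        # Drop local build tags such as "+cu111", then compare major.minor only.
        if have is None or have.split("+")[0].split(".")[:2] != want.split(".")[:2]:
            diffs[name] = (have, want)
    return diffs


if __name__ == "__main__":
    print("python", sys.version.split()[0], "(tested: 3.8)")
    print(mismatches(installed_versions(TESTED), TESTED) or "torch/torchvision match")
```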

Pretrained weights

Pretrained weights on π’Žπ’Šπ’π’ŠImageNet, π’•π’Šπ’†π’“π’†π’…ImageNet, CIFAR-FS and FC100 are available now. Note that for π’•π’Šπ’†π’“π’†π’…ImageNet and FC100 there are only checkpoints for the first stage (without cascaded training). Accuracy of 5-way 1-shot and 5-way 5-shot shown in the table is evaluated on the test split and for reference only.

dataset          1-shot   5-shot   download
miniImageNet     71.16%   84.60%   checkpoints_first, features_mini
tieredImageNet   79.67%   91.72%   -
FC100            48.27%   66.42%   -
CIFAR-FS         73.13%   86.36%   -
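The N-way K-shot numbers follow the usual episodic protocol: each test episode draws N = 5 classes with K = 1 or 5 labelled support images per class, and the model is scored on held-out query images from the same classes. The sketch below illustrates episode sampling only; the function name is hypothetical and the default of 15 queries per class is a common convention assumed here, not a value taken from this repository's scripts.

```python
"""Sample an N-way K-shot episode from a labelled pool (illustrative sketch)."""
import random
from collections import defaultdict


def sample_episode(labels, n_way=5, k_shot=1, n_query=15, rng=None):
    """Return (support, query) as lists of (sample_index, episode_label) pairs.

    labels[i] is the class of sample i. Each of the n_way sampled classes gives
    k_shot support and n_query query samples; classes are relabelled 0..n_way-1.
    """
    rng = rng or random.Random()
    by_class = defaultdict(list)
    for idx, c in enumerate(labels):
        by_class[c].append(idx)
    # Only classes with enough samples for both support and query can be drawn.
    eligible = sorted(c for c, idxs in by_class.items() if len(idxs) >= k_shot + n_query)
    classes = rng.sample(eligible, n_way)
    support, query = [], []
    for episode_label, c in enumerate(classes):
        picked = rng.sample(by_class[c], k_shot + n_query)  # no overlap between support and query
        support += [(i, episode_label) for i in picked[:k_shot]]
        query += [(i, episode_label) for i in picked[k_shot:]]
    return support, query
```

For example, `sample_episode(labels, n_way=5, k_shot=5)` yields 25 support and 75 query samples per episode; accuracy is then averaged over many such episodes.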

Pretrained weights for the cascaded-trained models on miniImageNet and CIFAR-FS are provided below. Note that the path to the first-stage pretrained weights must also be specified when evaluating these models (see Evaluation).

dataset          1-shot   5-shot   download
miniImageNet     74.74%   89.19%   checkpoints_pooling, features_mini
CIFAR-FS         78.89%   90.50%   -

Datasets

π’Žπ’Šπ’π’ŠImageNet

The π‘šπ‘–π‘›π‘–ImageNet dataset was proposed by Vinyals et al. for few-shot learning evaluation. Its complexity is high due to the use of ImageNet images but requires fewer resources and infrastructure than running on the full ImageNet dataset. In total, there are 100 classes with 600 samples of color images per class. These 100 classes are divided into 64, 16, and 20 classes respectively for sampling tasks for meta-training, meta-validation, and meta-test. To generate this dataset from ImageNet, you may use the repository π‘šπ‘–π‘›π‘–ImageNet tools.

Note that in our implementation images are resized to 480 × 480, because the data augmentation we use requires an image resolution greater than 224 × 224 to avoid distortion. Therefore, when generating miniImageNet, set --image_resize 0 to keep the original size, or --image_resize 480 as we did.
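A back-of-the-envelope calculation shows why the source resolution matters. Assuming DINO-style global crops (224 px output, covering as little as 40% of the image area by default; whether HCT keeps these values is an assumption), a minimal crop from an 84 × 84 image, the size miniImageNet is commonly distributed at, is only about 53 px wide and must be upsampled roughly 4.2×, while the same crop from a 480 × 480 image is about 304 px and is downsampled instead. The helper names below are hypothetical.

```python
"""Estimate how much a random crop must be rescaled to reach the crop size."""
import math

CROP_SIZE = 224  # output side of a global crop (DINO default, assumed here)


def crop_side(image_side, area_fraction):
    """Side length in pixels of a square crop covering area_fraction of a square image."""
    return image_side * math.sqrt(area_fraction)


def upsampling_factor(image_side, area_fraction, crop_size=CROP_SIZE):
    """Values above 1 mean the crop must be enlarged (interpolated) to crop_size."""
    return crop_size / crop_side(image_side, area_fraction)


for side in (84, 480):
    # 0.4 is the smallest global-crop area fraction in DINO's default settings.
    print(side, round(crop_side(side, 0.4), 1), round(upsampling_factor(side, 0.4), 2))
```

Interpolating a 53 px patch up to 224 px mostly adds blur, which is the kind of distortion the resizing step avoids.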

π’•π’Šπ’†π’“π’†π’…ImageNet

The π‘‘π‘–π‘’π‘Ÿπ‘’π‘‘ImageNet dataset is a larger subset of ILSVRC-12 with 608 classes (779,165 images) grouped into 34 higher-level nodes in the ImageNet human-curated hierarchy. To generate this dataset from ImageNet, you may use the repository π‘‘π‘–π‘’π‘Ÿπ‘’π‘‘ImageNet dataset: π‘‘π‘–π‘’π‘Ÿπ‘’π‘‘ImageNet tools.

Similar to π’Žπ’Šπ’π’ŠImageNet, you should set --image_resize 0 to keep the original size or --image_resize 480 as what we did when generating π’•π’Šπ’†π’“π’†π’…ImageNet.

Training

We provide the training code for π’Žπ’Šπ’π’ŠImageNet, π’•π’Šπ’†π’“π’†π’…ImageNet and CIFAR-FS, extending the DINO repo (link).

1 Pre-train the First Transformer

To pre-train the first Transformer with attribute surrogates learning on miniImageNet from scratch on multiple GPUs, run:

python -m torch.distributed.launch --nproc_per_node=8 main_hct_first.py --arch vit_small --data_path /path/to/mini_imagenet/train --output_dir /path/to/saving_dir
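The command above assumes 8 GPUs. To run on a different number, only `--nproc_per_node` needs to change; the sketch below builds the same command as an argument list so it can be inspected before launching. The helper name is hypothetical. If the script keeps DINO's per-GPU batch-size convention (an assumption), fewer GPUs also means a smaller global batch. Newer PyTorch releases recommend `torchrun` over `torch.distributed.launch`, but the latter matches the tested PyTorch 1.9.0.

```python
"""Build the distributed launch command for a chosen GPU count (sketch)."""
import shlex
import sys


def launch_cmd(script, nproc_per_node, **script_args):
    """Return argv for `python -m torch.distributed.launch --nproc_per_node=N script --key value ...`."""
    cmd = [sys.executable, "-m", "torch.distributed.launch",
           f"--nproc_per_node={nproc_per_node}", script]
    for key, value in script_args.items():
        cmd += [f"--{key}", str(value)]
    return cmd


if __name__ == "__main__":
    cmd = launch_cmd("main_hct_first.py", 4, arch="vit_small",
                     data_path="/path/to/mini_imagenet/train",
                     output_dir="/path/to/saving_dir")
    # Print the shell-quoted command; pass `cmd` to subprocess.run(cmd, check=True) to launch.
    print(shlex.join(cmd))
```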

2 Train the Hierarchically Cascaded Transformers

To train the Hierarchically Cascaded Transformers with spectral tokens pooling on miniImageNet, run:

python -m torch.distributed.launch --nproc_per_node=8 main_hct_pooling.py --arch vit_small --data_path /path/to/mini_imagenet/train --output_dir /path/to/saving_dir --pretrained_weights /path/to/pretrained_weights
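This README names spectral tokens pooling without describing it. As a rough intuition only, the sketch below shows a generic spectral-clustering step: tokens are split in two by the Fiedler vector of a normalized graph Laplacian (a relaxed normalized cut), and each group is averaged into a single token. This is not the HCT implementation; the affinity (cosine similarity clipped at zero), the two-way split, the NumPy dependency and the function names are all assumptions for illustration.

```python
"""Toy spectral bipartition of tokens followed by average pooling (illustrative)."""
import numpy as np


def spectral_bipartition(tokens):
    """Split N tokens into two groups using the Fiedler vector of the normalized Laplacian.

    tokens: (N, D) array with non-zero rows. Returns an (N,) array of 0/1 group labels.
    """
    x = tokens / np.linalg.norm(tokens, axis=1, keepdims=True)
    affinity = np.clip(x @ x.T, 0.0, None)            # non-negative cosine similarity
    d_inv_sqrt = 1.0 / np.sqrt(affinity.sum(axis=1))  # degrees > 0 since self-affinity is 1
    lap = np.eye(len(x)) - d_inv_sqrt[:, None] * affinity * d_inv_sqrt[None, :]
    _, vecs = np.linalg.eigh(lap)                     # eigenvalues in ascending order
    fiedler = d_inv_sqrt * vecs[:, 1]                 # relaxed normalized-cut indicator
    return (fiedler > 0).astype(int)


def pool_tokens(tokens, labels):
    """Average the tokens within each group; returns a (num_groups, D) array."""
    return np.stack([tokens[labels == g].mean(axis=0) for g in np.unique(labels)])
```

Applying the bipartition recursively to each group gives more than two pooled tokens; pooling reduces the token count seen by the next Transformer in a cascade.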

Evaluation

To evaluate the performance of the first Transformer on π’Žπ’Šπ’π’ŠImageNet 5-way 1-shot task, run:

python eval_hct_first.py --arch vit_small --server mini --partition test --checkpoint_key student --ckp_path /path/to/checkpoint_mini/ --num_shots 1
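Evaluation scores a frozen backbone on sampled episodes, but this README does not say which classifier is fitted on the support features. As one common baseline, and an assumption rather than necessarily what `eval_hct_first.py` does, a nearest-prototype classifier averages the support features of each class and assigns every query to the closest class mean. The function names below are hypothetical.

```python
"""Nearest-prototype classification of one episode from frozen features (sketch)."""


def prototypes(support_feats, support_labels):
    """Mean feature vector per episode class."""
    sums, counts = {}, {}
    for f, y in zip(support_feats, support_labels):
        acc = sums.setdefault(y, [0.0] * len(f))
        for i, v in enumerate(f):
            acc[i] += v
        counts[y] = counts.get(y, 0) + 1
    return {y: [v / counts[y] for v in s] for y, s in sums.items()}


def classify(query_feats, protos):
    """Assign each query to the prototype with the smallest squared Euclidean distance."""
    def dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return [min(protos, key=lambda y: dist(q, protos[y])) for q in query_feats]


def episode_accuracy(support_feats, support_labels, query_feats, query_labels):
    """Fraction of queries whose predicted class matches the true episode label."""
    preds = classify(query_feats, prototypes(support_feats, support_labels))
    return sum(p == t for p, t in zip(preds, query_labels)) / len(query_labels)
```

In the 1-shot case each prototype is simply the single support feature of its class.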

To evaluate the Hierarchically Cascaded Transformers on the miniImageNet 5-way 5-shot task, run:

python eval_hct_pooling.py --arch vit_small --server mini_pooling --partition val --checkpoint_key student --ckp_path /path/to/checkpoint_mini_pooling/ --pretrained_weights /path/to/pretrained_weights_of_first_stage --num_shots 5
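Few-shot accuracy is conventionally reported as the mean over many test episodes together with a 95% confidence interval; the tables above list only the mean, and the number of episodes the scripts use is not stated in this README. A minimal sketch of the aggregation with the standard library, using the normal approximation (1.96 × standard deviation / √n) and the sample standard deviation (some codebases use the population value instead); the function name is hypothetical.

```python
"""Aggregate per-episode accuracies into a mean and 95% confidence interval (sketch)."""
import math
import statistics


def mean_and_ci95(accuracies):
    """Return (mean, half-width of the normal-approximation 95% CI)."""
    mean = statistics.fmean(accuracies)
    half_width = 1.96 * statistics.stdev(accuracies) / math.sqrt(len(accuracies))
    return mean, half_width


m, h = mean_and_ci95([0.70, 0.80, 0.90])
print(f"{100 * m:.2f} +- {100 * h:.2f} %")
```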

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Citation

If you find our code or paper useful in your research, please consider citing our work with the following BibTeX:

@inproceedings{he2022attribute,
  title={Attribute surrogates learning and spectral tokens pooling in transformers for few-shot learning},
  author={He, Yangji and Liang, Weihan and Zhao, Dongyang and Zhou, Hong-Yu and Ge, Weifeng and Yu, Yizhou and Zhang, Wenqiang},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={9119--9129},
  year={2022}
}
