
Fast AutoAugment (Accepted at NeurIPS 2019)

Official Fast AutoAugment implementation in PyTorch.

  • Fast AutoAugment learns augmentation policies using a more efficient search strategy based on density matching.
  • Fast AutoAugment speeds up the search by orders of magnitude while maintaining comparable performance.
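The density-matching idea can be sketched as follows: a model trained on one split of the data *without* a candidate policy is evaluated on a held-out split augmented *with* that policy, and policies under which the model still classifies well are kept, since their augmented samples look like the training distribution. This is a minimal, stdlib-only sketch under stated assumptions: the names `apply_subpolicy`, `policy_score` and `select_subpolicies` are hypothetical, each sub-policy is a list of (operation, probability, magnitude) triples, and the paper's Bayesian-optimization search is replaced here by plain exhaustive scoring of a fixed candidate list.

```python
import random
from typing import Any, Callable, Dict, List, Sequence, Tuple

# A sub-policy is an ordered list of (operation name, probability, magnitude).
SubPolicy = List[Tuple[str, float, float]]
OpTable = Dict[str, Callable[[Any, float], Any]]


def apply_subpolicy(x: Any, sub_policy: SubPolicy, ops: OpTable, rng: random.Random) -> Any:
    """Apply each operation in order, each with its own probability and magnitude."""
    for name, prob, magnitude in sub_policy:
        if rng.random() < prob:
            x = ops[name](x, magnitude)
    return x


def policy_score(model: Callable[[Any], Any], held_out: Sequence[Tuple[Any, Any]],
                 sub_policy: SubPolicy, ops: OpTable, rng: random.Random) -> float:
    """Accuracy of a model trained WITHOUT this policy on policy-augmented held-out data.

    High accuracy suggests the augmented samples still match the training distribution.
    """
    correct = sum(model(apply_subpolicy(x, sub_policy, ops, rng)) == y for x, y in held_out)
    return correct / len(held_out)


def select_subpolicies(model: Callable[[Any], Any], held_out: Sequence[Tuple[Any, Any]],
                       candidates: List[SubPolicy], ops: OpTable,
                       top_n: int, seed: int = 0) -> List[SubPolicy]:
    """Keep the top_n candidates by score (exhaustive scoring, NOT Bayesian optimization)."""
    rng = random.Random(seed)
    ranked = sorted(candidates,
                    key=lambda p: policy_score(model, held_out, p, ops, rng),
                    reverse=True)
    return ranked[:top_n]


# Toy usage: "images" are numbers and the "model" predicts whether they are non-negative.
toy_ops: OpTable = {"Negate": lambda x, m: -x, "Scale": lambda x, m: x * m}
toy_model = lambda x: x >= 0
toy_data = [(1, True), (-1, False), (2, True)]
print(select_subpolicies(toy_model, toy_data,
                         [[("Negate", 1.0, 0.0)], [("Scale", 1.0, 2.0)]], toy_ops, top_n=1))
```

In the toy run, negating flips every label so that sub-policy scores 0, while positive scaling preserves the class and scores 1, so only the scaling sub-policy survives.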

Results

CIFAR-10 / 100

Search: 3.5 GPU hours (1428x faster than AutoAugment), WResNet-40x2 on reduced CIFAR-10
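As a sanity check on the speedup figure, multiplying the search cost by the stated factor gives the GPU hours it implies for AutoAugment. A minimal sketch, assuming the speedup is defined as (AutoAugment hours) / (Fast AutoAugment hours); the function name is hypothetical.

```python
def implied_baseline_gpu_hours(search_gpu_hours: float, speedup: float) -> float:
    """GPU hours implied for the slower method, given our search cost and the speedup factor."""
    return search_gpu_hours * speedup


# CIFAR-10 figures quoted above: 3.5 GPU hours at 1428x.
print(implied_baseline_gpu_hours(3.5, 1428))  # 4998.0
```

The result, about 5,000 GPU hours, is consistent with the cost usually cited for AutoAugment's CIFAR-10 search.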

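Values are test error rates (%); lower is better.

| Model (CIFAR-10)       | Baseline | Cutout | AutoAugment | Fast AutoAugment (transfer / direct) | Checkpoint |
|------------------------|----------|--------|-------------|--------------------------------------|------------|
| Wide-ResNet-40-2       | 5.3      | 4.1    | 3.7         | 3.6 / 3.7                            | Download   |
| Wide-ResNet-28-10      | 3.9      | 3.1    | 2.6         | 2.7 / 2.7                            | Download   |
| Shake-Shake(26 2x32d)  | 3.6      | 3.0    | 2.5         | 2.7 / 2.5                            | Download   |
| Shake-Shake(26 2x96d)  | 2.9      | 2.6    | 2.0         | 2.0 / 2.0                            | Download   |
| Shake-Shake(26 2x112d) | 2.8      | 2.6    | 1.9         | 2.0 / 1.9                            | Download   |
| PyramidNet+ShakeDrop   | 2.7      | 2.3    | 1.5         | 1.8 / 1.7                            |            |

| Model (CIFAR-100)      | Baseline | Cutout | AutoAugment | Fast AutoAugment (transfer / direct) | Checkpoint |
|------------------------|----------|--------|-------------|--------------------------------------|------------|
| Wide-ResNet-40-2       | 26.0     | 25.2   | 20.7        | 20.7 / 20.6                          | Download   |
| Wide-ResNet-28-10      | 18.8     | 18.4   | 17.1        | 17.3 / 17.3                          | Download   |
| Shake-Shake(26 2x96d)  | 17.1     | 16.0   | 14.3        | 14.9 / 14.6                          | Download   |
| PyramidNet+ShakeDrop   | 14.0     | 12.2   | 10.7        | 11.9 / 11.7                          |            |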
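Each Fast AutoAugment cell packs two error rates (transfer / direct). A small sketch for splitting such a cell and expressing the gain over the baseline as a relative error reduction; the helper names are hypothetical, and relative reduction is a metric chosen here for illustration, not one reported in the tables.

```python
from typing import Tuple


def parse_transfer_direct(cell: str) -> Tuple[float, float]:
    """Split a 'transfer / direct' cell such as '3.6 / 3.7' into two floats."""
    transfer, direct = (float(part) for part in cell.split("/"))
    return transfer, direct


def relative_reduction(baseline_err: float, new_err: float) -> float:
    """Fraction of the baseline error removed by the new method."""
    return (baseline_err - new_err) / baseline_err


# Wide-ResNet-40-2 on CIFAR-10: baseline 5.3, Fast AutoAugment 3.6 / 3.7.
transfer, direct = parse_transfer_direct("3.6 / 3.7")
print(round(relative_reduction(5.3, transfer), 3))  # 0.321
```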

ImageNet

Search: 450 GPU hours (33x faster than AutoAugment), ResNet-50 on reduced ImageNet

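Values are top-1 / top-5 error rates (%).

| Model      | Baseline   | AutoAugment | Fast AutoAugment | Checkpoint |
|------------|------------|-------------|------------------|------------|
| ResNet-50  | 23.7 / 6.9 | 22.4 / 6.2  | 22.4 / 6.3       | Download   |
| ResNet-200 | 21.5 / 5.8 | 20.0 / 5.0  | 19.4 / 4.7       |            |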

SVHN Test

Search: 1.5 GPU hours

| Model             | Baseline | AutoAug / Our | Fast AutoAugment |
|-------------------|----------|---------------|------------------|
| Wide-ResNet-28-10 | 1.5      | 1.1           | 1.1              |

Run

We conducted experiments with:

  • Python 3.6.9
  • PyTorch 1.2.0, torchvision 0.4.0, CUDA 10

Search an augmentation policy

Please read Ray's documentation (https://github.com/ray-project/ray) to set up a Ray cluster, then run search.py with the Redis address of the cluster's head (master) node.

$ python search.py -c confs/wresnet40x2_cifar10_b512.yaml --dataroot ... --redis ...
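When scripting the search, it can help to validate the Redis address before launching. A minimal sketch, assuming `--redis` expects a `host:port` string; the function name, dataset path and address below are hypothetical placeholders, while `-c`, `--dataroot` and `--redis` are the flags shown in the command above.

```python
import shlex
from typing import List


def search_command(conf: str, dataroot: str, redis_address: str) -> List[str]:
    """Assemble a search.py invocation (assumes redis_address has the form host:port)."""
    host, sep, port = redis_address.rpartition(":")
    if not sep or not host or not port.isdigit():
        raise ValueError(f"expected host:port, got {redis_address!r}")
    return ["python", "search.py", "-c", conf, "--dataroot", dataroot, "--redis", redis_address]


# Hypothetical values: substitute your dataset path and your Ray head node's Redis address.
cmd = search_command("confs/wresnet40x2_cifar10_b512.yaml", "/data/cifar", "10.0.0.1:6379")
print(" ".join(shlex.quote(arg) for arg in cmd))
```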

Train a model with the found policies

You can train models on CIFAR-10 / 100 and ImageNet with the policies we found:

  • fa_reduced_cifar10: policy searched on reduced CIFAR-10 (4k images) with WResNet-40x2
  • fa_reduced_imagenet: policy searched on reduced ImageNet (50k images, 120 classes) with ResNet-50

$ export PYTHONPATH=$PYTHONPATH:$PWD
$ python FastAutoAugment/train.py -c confs/wresnet40x2_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar10
$ python FastAutoAugment/train.py -c confs/wresnet40x2_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar100
$ python FastAutoAugment/train.py -c confs/wresnet28x10_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar10
$ python FastAutoAugment/train.py -c confs/wresnet28x10_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar100
...
$ python FastAutoAugment/train.py -c confs/resnet50_b512.yaml --aug fa_reduced_imagenet
$ python FastAutoAugment/train.py -c confs/resnet200_b512.yaml --aug fa_reduced_imagenet

By adding the --only-eval and --save arguments, you can evaluate trained models without training them again.
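To run several of the training or evaluation commands listed above from a script, the argument lists can be assembled programmatically. A minimal sketch: the flags come from the commands above, but the helper name is hypothetical, and treating `--save` as taking a checkpoint path value is an assumption.

```python
import shlex
from typing import List, Optional


def train_command(conf: str, aug: str, dataset: Optional[str] = None,
                  only_eval: bool = False, save: Optional[str] = None) -> List[str]:
    """Assemble a FastAutoAugment/train.py invocation as an argv list."""
    cmd = ["python", "FastAutoAugment/train.py", "-c", conf, "--aug", aug]
    if dataset is not None:
        cmd += ["--dataset", dataset]
    if save is not None:
        # Assumption: --save takes the checkpoint path as its value.
        cmd += ["--save", save]
    if only_eval:
        cmd.append("--only-eval")
    return cmd


# Reproduce the four CIFAR runs listed above.
for conf in ("confs/wresnet40x2_cifar10_b512.yaml", "confs/wresnet28x10_cifar10_b512.yaml"):
    for dataset in ("cifar10", "cifar100"):
        print(" ".join(shlex.quote(arg) for arg in train_command(conf, "fa_reduced_cifar10", dataset)))
```

Returning a list rather than a single string keeps the arguments safe to pass directly to `subprocess.run` without shell quoting issues.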

Citation

If you use any part of this code in your research, please cite our paper.

@inproceedings{lim2019fast,
  title={Fast AutoAugment},
  author={Lim, Sungbin and Kim, Ildoo and Kim, Taesup and Kim, Chiheon and Kim, Sungwoong},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2019}
}


References & Open Source

We increase the batch size and adjust the learning rate accordingly to speed up training. Otherwise, we use the same hyperparameters as AutoAugment where possible. For hyperparameters that are not reported, we follow the values in the original references or tune them to match the baseline performance.
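The text does not say exactly how the learning rate is adapted to the larger batch. One common choice is the linear scaling rule, sketched below as an assumption rather than the repository's confirmed method; the reference setting (lr 0.1 at batch size 128) is hypothetical, while the target batch size of 512 matches the `b512` suffix in the config file names above.

```python
def linearly_scaled_lr(base_lr: float, base_batch_size: int, batch_size: int) -> float:
    """Scale the learning rate in proportion to the batch size (linear scaling rule)."""
    return base_lr * batch_size / base_batch_size


# Hypothetical reference setting: lr 0.1 at batch size 128, moved to batch size 512.
print(linearly_scaled_lr(0.1, 128, 512))  # 0.4
```

Large-batch recipes often pair this rule with a short learning-rate warmup; whether the configs here do so should be checked in the YAML files.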


License: MIT License

