Improving Auto-Augment via Augmentation-Wise Weight Sharing

Unofficial AWS AutoAugment implementation in PyTorch.

  • AWS AutoAugment learns augmentation policies using augmentation-wise shared model weights
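In rough terms, the expensive early training is shared across all candidate policies: one model is pre-trained once under a broad augmentation distribution, and each candidate policy is then scored by briefly fine-tuning from those shared weights and measuring validation accuracy. The sketch below only illustrates that idea; it assumes a candidate policy is a callable that augments a batch, and none of the names are this repo's API.

import copy

import torch
import torch.nn.functional as F

def policy_reward(shared_model, policy, train_loader, val_loader,
                  device="cuda", finetune_epochs=5, lr=0.01):
    """Score one candidate augmentation policy.

    `shared_model` holds the augmentation-wise shared weights trained once
    under a broad augmentation distribution; every candidate policy reuses
    them and only pays for this short fine-tuning stage.
    """
    model = copy.deepcopy(shared_model).to(device)            # start from the shared weights
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    model.train()
    for _ in range(finetune_epochs):
        for images, labels in train_loader:
            images = policy(images)                           # apply the candidate augmentations
            images, labels = images.to(device), labels.to(device)
            loss = F.cross_entropy(model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            preds = model(images.to(device)).argmax(dim=1).cpu()
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total                                    # validation accuracy as the reward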

To-do

Essentials

  • Baseline structure
  • Augmentation list
  • Shared policy
  • Augmentation-wise shared model weights
  • PPO + baseline trick (see the sketch after this list)
  • Search code
  • Training code
  • Enlarge Batch (EB)
  • CIFAR100 WRN
  • CIFAR100 Shake-Shake
  • CIFAR100 PyramidNet+ShakeDrop
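For the PPO + baseline trick item above: the policy controller can be updated with PPO's clipped surrogate objective, using a moving average of past rewards as a baseline to reduce variance. Below is a minimal sketch assuming a flat categorical controller over NUM_OPS candidate operations; the names and the operation count are illustrative, not taken from this repo.

import torch
import torch.nn.functional as F

NUM_OPS = 36            # number of candidate augmentation operations (illustrative)
CLIP_EPS = 0.2          # PPO clipping range
BASELINE_MOMENTUM = 0.9

logits = torch.zeros(NUM_OPS, requires_grad=True)        # controller parameters
optimizer = torch.optim.Adam([logits], lr=0.05)

def ppo_step(actions, rewards, old_log_probs, baseline):
    """One PPO update with a moving-average baseline.

    actions:       LongTensor of sampled operation indices
    rewards:       validation accuracies from the shared-weight evaluation
    old_log_probs: log-probabilities of `actions` under the sampling-time policy
    baseline:      running average of past rewards (the baseline trick)
    """
    baseline = BASELINE_MOMENTUM * baseline + (1 - BASELINE_MOMENTUM) * rewards.mean().item()
    advantages = rewards - baseline                       # centre the rewards to cut variance

    new_log_probs = F.log_softmax(logits, dim=-1)[actions]
    ratio = (new_log_probs - old_log_probs).exp()         # importance-sampling ratio
    clipped = ratio.clamp(1 - CLIP_EPS, 1 + CLIP_EPS)
    loss = -torch.min(ratio * advantages, clipped * advantages).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return baseline                                       # carry the baseline forward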

Possible Modifications

  • Faster action recorder with pickle (currently using txt; see the sketch after this list)
  • Distributed training (currently using nn.DataParallel)
  • Search with EB
  • Random Search
  • Policy Gradient
  • Stochastic Depth
  • CIFAR10
  • ImageNet
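For the pickle action-recorder item above: dumping the sampled actions with pickle avoids re-parsing strings when the record is loaded back. A minimal sketch, with hypothetical file names and a hypothetical action format (tuples of operation index, probability, magnitude):

import pickle

def save_actions(actions, path="actions.pkl"):
    # actions: e.g. a list of (op_index, probability, magnitude) tuples
    with open(path, "wb") as f:
        pickle.dump(actions, f)

def load_actions(path="actions.pkl"):
    with open(path, "rb") as f:
        return pickle.load(f)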

Future Work

  • Incremental Searching with Operation Embedding Sharing (CIFAR10 -> CIFAR100 -> ImageNet)
  • FastAugment + AWS
  • ProxylessNAS + AWS
  • Gradient-based NAS + AWS

Results

CIFAR-100

Search: 120 GPU hours with Wide-ResNet-28-10 on CIFAR-100

  • Searched with Cutout applied after AWSAugment
  • Numbers are top-1 error rates (%)

Model (CIFAR-100)         Baseline   Cutout   AWSAugment   AWS + EB
Wide-ResNet-28-10         20.04      19.81    20.18        20.16
Shake-Shake (26 2x32d)    20.81      19.05    20.39        20.31
PyramidNet+ShakeDrop      -          -        -            -

Run

We conducted our experiments in the following environment:

  • Python 3.7.0
  • PyTorch 1.6.0, torchvision 0.5.0, CUDA 10

Search an augmentation policy

$ python AWSAutoAugment/search.py --path ... --dataroot ...

Train a model with the found policies

Generic form:

$ python AWSAutoAugment/train.py --path ... --dataroot ... --policy_checkpoint ... 

Each model below has four variants matching the Results table: baseline (--no_aug --cutout 0), Cutout (--no_aug --cutout 16), AWSAugment (the searched policy), and AWS + EB (--enlarge_batch).

Wide-ResNet-28-10:

$ python AWSAutoAugment/train.py --path ... --dataroot ... --policy_checkpoint ... --model wresnet28_10 --no_aug --cutout 0

$ python AWSAutoAugment/train.py --path ... --dataroot ... --policy_checkpoint ... --model wresnet28_10 --no_aug --cutout 16

$ python AWSAutoAugment/train.py --path ... --dataroot ... --policy_checkpoint ... --model wresnet28_10

$ python AWSAutoAugment/train.py --path ... --dataroot ... --policy_checkpoint ... --model wresnet28_10 --enlarge_batch

Shake-Shake (26 2x32d):

$ python AWSAutoAugment/train.py --path ... --dataroot ... --policy_checkpoint ... --model shakeshake26_2x32d --batch_size 128 --n_epochs 1800 --init_lr 0.01 --weight_decay 0.001 --no_aug --cutout 0

$ python AWSAutoAugment/train.py --path ... --dataroot ... --policy_checkpoint ... --model shakeshake26_2x32d --batch_size 128 --n_epochs 1800 --init_lr 0.01 --weight_decay 0.001 --no_aug --cutout 16

$ python AWSAutoAugment/train.py --path ... --dataroot ... --policy_checkpoint ... --model shakeshake26_2x32d --batch_size 128 --n_epochs 1800 --init_lr 0.01 --weight_decay 0.001 

$ python AWSAutoAugment/train.py --path ... --dataroot ... --policy_checkpoint ... --model shakeshake26_2x32d --batch_size 128 --n_epochs 1800 --init_lr 0.01 --weight_decay 0.001 --enlarge_batch

PyramidNet+ShakeDrop:

$ python AWSAutoAugment/train.py --path ... --dataroot ... --policy_checkpoint ... --model pyramid --batch_size 64 --n_epochs 1800 --init_lr 0.05 --weight_decay 0.00005 --no_aug --cutout 0

$ python AWSAutoAugment/train.py --path ... --dataroot ... --policy_checkpoint ... --model pyramid --batch_size 64 --n_epochs 1800 --init_lr 0.05 --weight_decay 0.00005 --no_aug --cutout 16

$ python AWSAutoAugment/train.py --path ... --dataroot ... --policy_checkpoint ... --model pyramid --batch_size 64 --n_epochs 1800 --init_lr 0.05 --weight_decay 0.00005

$ python AWSAutoAugment/train.py --path ... --dataroot ... --policy_checkpoint ... --model pyramid --batch_size 64 --n_epochs 1800 --init_lr 0.05 --weight_decay 0.00005 --enlarge_batch

We increase the batch size and scale the learning rate accordingly to speed up training. Otherwise we keep hyperparameters equal to AutoAugment where possible; for hyperparameters that are not specified, we follow the original references or tune them to match baseline performance.
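The note above does not say how the learning rate is adapted to the larger batch; a common choice is the linear scaling rule, which keeps the ratio of learning rate to batch size constant. A minimal sketch under that assumption (the numbers are only an example, not values from this repo):

def scaled_lr(base_lr, base_batch_size, batch_size):
    """Linear scaling rule: the learning rate grows in proportion to the batch size."""
    return base_lr * batch_size / base_batch_size

# Example: a base rate of 0.1 at batch size 128 becomes 0.4 at batch size 512.
print(scaled_lr(0.1, 128, 512))  # 0.4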

References & Opensources
