This repository is the official implementation of Graph Masked Autoencoder Enhanced Predictor for Neural Architecture Search. The supplementary materials are available here.
The overview of GMAE.
The architecture space and its representation.
If you use any part of this code in your research, please cite our paper:
@inproceedings{jing2022gmae,
title={Graph Masked Autoencoder Enhanced Predictor for Neural Architecture Search},
author={Kun Jing and Jungang Xu and Pengfei Li},
booktitle={Proc. of IJCAI},
year={2022}
}
Python>=3.8, PyTorch==1.9.0, torchvision==0.10.0, torch-cluster==1.5.9, torch-geometric==1.7.2, torch-scatter==2.0.8, torch-sparse==0.6.12
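A quick way to confirm that an environment matches the pinned versions above is to compare them against the installed package metadata. The following is a minimal sketch, not part of the repository: the helper name `find_mismatches` is hypothetical, and the distribution names (for example `torch` for PyTorch) are assumed to match the names used by pip.

```python
from importlib import metadata

# Pinned versions copied from the requirements line above.
# Distribution names are an assumption (pip names, e.g. "torch" for PyTorch).
PINNED = {
    "torch": "1.9.0",
    "torchvision": "0.10.0",
    "torch-cluster": "1.5.9",
    "torch-geometric": "1.7.2",
    "torch-scatter": "2.0.8",
    "torch-sparse": "0.6.12",
}


def find_mismatches(pinned, lookup=metadata.version):
    """Return {package: (expected, found)} for each package whose installed
    version differs from the pinned one; found is None if it is not installed."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            found = lookup(name)
        except metadata.PackageNotFoundError:
            found = None
        # Ignore local build suffixes such as "+cu111" when comparing.
        if found is None or found.split("+")[0] != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    for name, (expected, found) in find_mismatches(PINNED).items():
        print(f"{name}: expected {expected}, found {found}")
```

The script prints nothing when every pinned package is installed at the expected version.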
By default, the NAS benchmark data and the CIFAR-10 dataset are expected in ~/datasets.
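Before launching a run, it can help to confirm that the data directory resolves to the expected location. The snippet below is a small sketch under that assumption; `dataset_root` is a hypothetical helper, not a function from this repository, and it only expands the path without checking for specific benchmark files.

```python
from pathlib import Path


def dataset_root(override=None):
    """Return the data directory: the override if given, else ~/datasets."""
    if override:
        # Expand a leading "~" so user-relative paths work as expected.
        return Path(override).expanduser()
    return Path.home() / "datasets"


root = dataset_root()
if not root.is_dir():
    print(f"Data directory not found: {root}")
```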
Use gen_json_301.py to generate 100,000 architectures in NAS-Bench-301 for pre-training.
To pre-train the GMAE, run this command:
python pretrain.py --space [space_name] --save [save path] [--finetune_fixed_encoder] --param local:[json filename of params] --finetune_train_samples [number of queries] --finetune_batch_size [recommend to set it to one tenth of number of queries] [...]
To fine-tune or train the predictor, run this command:
python finetune.py --space [space_name] --save [save path] [--finetune_fixed_encoder] --param local:[json filename of params] --finetune_train_samples [number of queries] --finetune_batch_size [recommend to set it to one tenth of number of queries] --model_state_dict [model state dict name; ignore it when training from scratch] [...]
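The recommendation to set the fine-tuning batch size to one tenth of the number of queries can be applied programmatically when launching several runs. The sketch below assembles the fine-tuning command from the flags shown above; `finetune_command` is a hypothetical helper, and every value passed to it (space name, paths, file names) is a placeholder rather than a name taken from the repository.

```python
import shlex


def finetune_command(space, save_dir, param_json, num_queries, state_dict=None):
    """Build the argument list for finetune.py using the flags shown above."""
    # One tenth of the query budget, but never fewer than one sample per batch.
    batch_size = max(1, num_queries // 10)
    cmd = [
        "python", "finetune.py",
        "--space", space,
        "--save", save_dir,
        "--param", f"local:{param_json}",
        "--finetune_train_samples", str(num_queries),
        "--finetune_batch_size", str(batch_size),
    ]
    # Omit the state dict when training the predictor from scratch.
    if state_dict is not None:
        cmd += ["--model_state_dict", state_dict]
    return cmd


# Placeholder values for illustration only.
cmd = finetune_command("my_space", "runs/demo", "params.json", num_queries=150)
print(shlex.join(cmd))
```

The returned list can be passed directly to `subprocess.run`, which avoids shell quoting issues with paths that contain spaces.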
To search for architectures, run this command:
python search.py --search_space [space_name] --search_algo [algorithm name] --dataset [dataset name] --outdir [output dir] --encoder_model_state_dict [state dict path] [...]
To train the architecture discovered in the paper on CIFAR-10, run this command:
python train.py --arch [your architecture] [...]
Script | Description |
---|---|
nni4nasbenchxxxyyy.sh/.py | Searches hyper-parameters for yyy (pre-training, fine-tuning, or training from scratch) on NAS-Bench-xxx. Note: the DARTS space uses the same hyper-parameters as NAS-Bench-301. |
searchxxx.sh | Runs the architecture search on NAS-Bench-xxx. |
pretrain_task.sh | Ablation study for choice of pre-training methods. |
masking_ratio.sh | Ablation study for choice of masking ratios. |
pretrain_target.sh | Ablation study for choice of objective functions. |
encoder_decoder.sh | Ablation study for choice of models. |
finetune_target.sh | Ablation study for choice of fine-tuning targets. |
finetune_mode.sh | Ablation study for choice of fine-tuning modes. |
The search process on NAS-Bench-101 (left) and NAS-Bench-301 (right).
The architecture discovered by GMAE-NAS (AE).
The architecture discovered by GMAE-NAS (BO).
Architecture | Test Error (%) | #Params (M) | Search Cost (GPU days) | Search Type |
---|---|---|---|---|
GMAE-NAS (AE) | 2.56 ± 0.04 | 4.0 | 3.3 | AE+neural predictor |
GMAE-NAS (BO) | 2.50 ± 0.03 | 3.8 | 3.3 | BO+neural predictor |