nicklashansen / svea-vit

Code for the paper "Stabilizing Deep Q-Learning with ConvNets and Vision Transformers under Data Augmentation"

SVEA: Stabilized Q-Value Estimation under Data Augmentation

This is a PyTorch implementation of SVEA using Vision Transformers from

Stabilizing Deep Q-Learning with ConvNets and Vision Transformers under Data Augmentation
Nicklas Hansen, Hao Su, Xiaolong Wang.

[Paper] [Webpage]

See this repository for SVEA implemented using ConvNets, as well as additional baselines and test environments.

Citation

If you find our work useful in your research, please consider citing it as follows:

@article{hansen2021stabilizing,
  title={Stabilizing Deep Q-Learning with ConvNets and Vision Transformers under Data Augmentation},
  author={Nicklas Hansen and Hao Su and Xiaolong Wang},
  year={2021},
  eprint={2107.00644},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}

Setup

We assume that you have access to a GPU with CUDA >=9.2 support. All dependencies can then be installed with the following commands:

conda env create -f setup/conda.yaml
conda activate svea
sh setup/install_envs.sh

Training & Evaluation

The scripts directory contains bash scripts for both SVEA and a DrQ baseline, which can be run with sh scripts/svea.sh and sh scripts/drq.sh, respectively.

Alternatively, you can call the Python scripts directly, e.g. for training, call

python3 src/train.py

to run SVEA on the default task, walker_walk, using the default hyperparameters. This should give you output of the form:

Working directory: logs/walker_walk/svea/default/1
Evaluating: logs/walker_walk/svea/default/1
| eval | S: 0 | ER: 26.2285 | ERTEST: 25.3730
| train | E: 1 | S: 250 | D: 328.1 s | R: 0.0000 | ALOSS: 0.0000 | CLOSS: 0.0000 | AUXLOSS: 0.0000

where ER and ERTEST correspond to the average return in the training and test environments, respectively. You can select the test environment used in evaluation with the --eval_mode argument, which accepts one of (original, color_easy, color_hard, video_easy, video_hard). Note that training time is approximately 9 days for 500k frames.
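
If you want to track ER and ERTEST over a long run, the log lines can be parsed into dictionaries. The following is a minimal sketch that assumes the output keeps exactly the format of the sample above; the function name parse_log_line is hypothetical and not part of this repository.

```python
import re

# Matches one "KEY: value" field such as "ER: 26.2285" or "D: 328.1 s".
# Assumption: keys are upper-case letters and values are plain decimals,
# as in the sample output shown above.
_FIELD = re.compile(r"^([A-Z]+):\s*(-?\d+(?:\.\d+)?)")


def parse_log_line(line):
    """Parse one '| eval | ...' or '| train | ...' line into a dict.

    Returns None for any line that is not in that format
    (hypothetical helper, not part of the repository).
    """
    parts = [p.strip() for p in line.strip().split("|") if p.strip()]
    if len(parts) < 2 or parts[0] not in ("eval", "train"):
        return None
    record = {"kind": parts[0]}
    for field in parts[1:]:
        match = _FIELD.match(field)
        if match:
            record[match.group(1)] = float(match.group(2))
    return record


# The sample output from above, used here as test input.
sample = """\
Working directory: logs/walker_walk/svea/default/1
Evaluating: logs/walker_walk/svea/default/1
| eval | S: 0 | ER: 26.2285 | ERTEST: 25.3730
| train | E: 1 | S: 250 | D: 328.1 s | R: 0.0000 | ALOSS: 0.0000 | CLOSS: 0.0000 | AUXLOSS: 0.0000
"""

records = [r for r in map(parse_log_line, sample.splitlines()) if r is not None]
for r in records:
    print(r)
```

Each eval record then exposes the training-environment return under "ER" and the test-environment return under "ERTEST", which makes it straightforward to compare the two across evaluation steps.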

Acknowledgements

We want to thank the numerous researchers whose work this implementation is based on. Our DrQ implementation is based on this repository, the original DMControl is available here, and the gym wrapper for it is available here. Data augmentations are based on this repository and kornia, which is available here. Our ViT encoder is based on this implementation by Ross Wightman.

License: MIT License

