Paper: Ha and Schmidhuber, "World Models", 2018. https://doi.org/10.5281/zenodo.1207631
The implementation is based on Python 3 and PyTorch; see the PyTorch website for installation instructions. The remaining requirements are listed in the requirements file. To install them:
pip3 install -r requirements.txt
The model is composed of three parts:
- A Variational Auto-Encoder (VAE), whose task is to compress the input images into a compact latent representation.
- A Mixture-Density Recurrent Network (MDN-RNN), trained to predict the latent encoding of the next frame given past latent encodings and actions.
- A linear Controller (C), which takes as input both the latent encoding of the current frame and the hidden state of the MDN-RNN (given past latents and actions), and outputs an action. It is trained to maximize the cumulative reward using the Covariance Matrix Adaptation Evolution Strategy (CMA-ES) from the cma Python package.
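To make the controller's role concrete, here is a minimal sketch of a linear controller acting on a latent vector z and a hidden state h. It uses plain Python lists rather than PyTorch so that it runs without dependencies. The sizes (32, 256, 3) and the names unflatten and controller_action are assumptions made for illustration, not taken from this repository. The flat parameter vector is the kind of object an evolution strategy such as CMA-ES searches over.

```python
import random

LATENT_SIZE = 32    # assumed size of the VAE latent vector z
HIDDEN_SIZE = 256   # assumed size of the MDN-RNN hidden state h
ACTION_SIZE = 3     # assumed action dimension (steering, gas, brake)


def unflatten(params, n_in=LATENT_SIZE + HIDDEN_SIZE, n_out=ACTION_SIZE):
    """Split a flat parameter vector into a weight matrix and a bias vector."""
    expected = n_out * n_in + n_out
    if len(params) != expected:
        raise ValueError(f"expected {expected} parameters, got {len(params)}")
    weights = [params[i * n_in:(i + 1) * n_in] for i in range(n_out)]
    bias = params[n_out * n_in:]
    return weights, bias


def controller_action(params, z, h):
    """Linear controller: action = W @ [z, h] + b."""
    weights, bias = unflatten(params)
    x = list(z) + list(h)  # concatenate latent encoding and hidden state
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]


rng = random.Random(0)
n_params = ACTION_SIZE * (LATENT_SIZE + HIDDEN_SIZE) + ACTION_SIZE
params = [rng.gauss(0.0, 0.1) for _ in range(n_params)]
z = [rng.gauss(0.0, 1.0) for _ in range(LATENT_SIZE)]
h = [0.0] * HIDDEN_SIZE  # e.g. the hidden state at the start of an episode
action = controller_action(params, z, h)
print(len(action))  # prints 3
```

Because the controller is a single linear map, the whole policy is described by a few hundred numbers under these assumed sizes, which keeps it within reach of a derivative-free optimizer.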
In the given code, the three parts are trained separately, using the scripts trainvae.py, trainmdrnn.py and traincontroller.py.
The training scripts take the following arguments:
- --logdir : The directory in which the models will be stored. If the specified logdir already exists, the script loads the existing model and continues training.
- --noreload : Add this option to overwrite a model in logdir instead of reloading it.
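The interplay of the two options can be sketched as follows. This is an illustration of the behaviour described above, not the repository's actual code: the helper names and the checkpoint file name checkpoint.tar are hypothetical.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse the two options shared by the training scripts."""
    parser = argparse.ArgumentParser(description="Training script sketch")
    parser.add_argument("--logdir", type=str, required=True,
                        help="Directory in which the models are stored")
    parser.add_argument("--noreload", action="store_true",
                        help="Overwrite an existing model instead of reloading it")
    return parser.parse_args(argv)


def should_reload(logdir, noreload, checkpoint_name="checkpoint.tar"):
    """Return True if an existing checkpoint should be loaded.

    The checkpoint file name is a hypothetical placeholder.
    """
    checkpoint = os.path.join(logdir, checkpoint_name)
    return os.path.exists(checkpoint) and not noreload


args = parse_args(["--logdir", "exp_dir", "--noreload"])
print(should_reload(args.logdir, args.noreload))  # prints False
```

With --noreload set, the sketch never reloads, whether or not a checkpoint exists; without it, training resumes only when a checkpoint is found.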
Before launching the VAE and MDN-RNN training scripts, you need to generate a dataset of random rollouts and place it in the datasets/carracing folder.
Data generation is handled by the data/generation_script.py script, e.g.
python data/generation_script.py --rollouts 1000 --dir datasets/carracing --threads 8
Rollouts are generated using a Brownian random policy rather than the white-noise random action_space.sample() policy from gym, which provides more consistent rollouts.
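The difference between the two policies can be illustrated with a small, dependency-free sketch. It assumes a Brownian policy means a clipped random walk in action space (each action is the previous one plus Gaussian noise scaled by sqrt(dt)). The function names, the value of dt and the action bounds are assumptions for illustration, not the repository's implementation.

```python
import math
import random


def white_noise_actions(n_steps, low, high, rng):
    """Independent uniform samples at every step."""
    return [[rng.uniform(lo, hi) for lo, hi in zip(low, high)]
            for _ in range(n_steps)]


def brownian_actions(n_steps, low, high, dt, rng):
    """Random walk: previous action plus Gaussian noise, clipped to the bounds."""
    current = [rng.uniform(lo, hi) for lo, hi in zip(low, high)]
    actions = [current]
    for _ in range(n_steps - 1):
        current = [min(max(a + math.sqrt(dt) * rng.gauss(0.0, 1.0), lo), hi)
                   for a, lo, hi in zip(current, low, high)]
        actions.append(current)
    return actions


def mean_step(actions):
    """Average absolute change of the first action dimension between steps."""
    diffs = [abs(b[0] - a[0]) for a, b in zip(actions, actions[1:])]
    return sum(diffs) / len(diffs)


rng = random.Random(0)
low, high = [-1.0, 0.0, 0.0], [1.0, 1.0, 1.0]  # assumed bounds: steering, gas, brake
white = white_noise_actions(1000, low, high, rng)
brown = brownian_actions(1000, low, high, dt=1 / 50, rng=rng)
print(mean_step(brown) < mean_step(white))  # prints True
```

Consecutive Brownian actions stay close to each other, so the car tends to hold a steering direction for a while instead of jittering in place, which is one plausible reading of "more consistent rollouts".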
The VAE is trained using the trainvae.py file, e.g.
python trainvae.py --logdir exp_dir
The MDN-RNN is trained using the trainmdrnn.py file, e.g.
python trainmdrnn.py --logdir exp_dir
A VAE must have been trained in the same exp_dir for this script to work.
Finally, the controller is trained using CMA-ES, e.g.
python traincontroller.py --logdir exp_dir
You can test the obtained policy with test_controller.py, e.g.
python test_controller.py --logdir exp_dir
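The steps above have to be run in order, since each stage relies on the output of the previous one. A small driver along these lines could chain them. It only reuses the commands and flags shown above; the helper names and the dry_run switch are illustrative assumptions.

```python
import subprocess


def pipeline_commands(logdir, rollouts=1000, threads=8,
                      data_dir="datasets/carracing"):
    """Return the commands for the full pipeline, in execution order."""
    return [
        ["python", "data/generation_script.py", "--rollouts", str(rollouts),
         "--dir", data_dir, "--threads", str(threads)],
        ["python", "trainvae.py", "--logdir", logdir],
        ["python", "trainmdrnn.py", "--logdir", logdir],
        ["python", "traincontroller.py", "--logdir", logdir],
        ["python", "test_controller.py", "--logdir", logdir],
    ]


def run_pipeline(logdir, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in pipeline_commands(logdir):
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the pipeline as soon as one stage fails
            subprocess.run(cmd, check=True)


run_pipeline("exp_dir")  # dry run: only prints the five commands
```

Calling run_pipeline("exp_dir", dry_run=False) from the repository root would execute the stages one after another.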
- Corentin Tallec - ctallec
- Léonard Blier - leonardblier
- Diviyan Kalainathan - diviyan-kalainathan
This project is licensed under the MIT License; see the LICENSE.md file for details.