Transformers are Sample Efficient World Models
Vincent Micheli*, Eloi Alonso*, François Fleuret
* Denotes equal contribution
tl;dr
- IRIS is a data-efficient agent trained over millions of imagined trajectories in a world model.
- The world model is composed of a discrete autoencoder and an autoregressive Transformer.
- Our approach casts dynamics learning as a sequence modeling problem, where the autoencoder builds a language of image tokens and the Transformer composes that language over time.
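To make the "language of image tokens" idea concrete, here is a minimal sketch under stated assumptions. It maps each frame feature vector to the index of its nearest codebook entry (VQ-VAE-style quantization). It then interleaves each frame's tokens with that step's action into one flat sequence that an autoregressive model could consume. The helpers `quantize` and `interleave` are hypothetical and are not taken from this repository. Offsetting action ids by `vocab_size` is only an illustrative way to keep actions and image tokens apart.

```python
# Hypothetical sketch: discretize frame features with a codebook, then interleave
# frame tokens and action tokens into a single sequence.
from typing import List, Sequence

Vector = Sequence[float]


def quantize(features: Sequence[Vector], codebook: Sequence[Vector]) -> List[int]:
    """Map each feature vector to the index of its nearest codebook entry (squared L2)."""
    tokens = []
    for f in features:
        distances = [sum((fi - ci) ** 2 for fi, ci in zip(f, c)) for c in codebook]
        tokens.append(min(range(len(codebook)), key=distances.__getitem__))
    return tokens


def interleave(frame_tokens: Sequence[Sequence[int]], actions: Sequence[int],
               vocab_size: int) -> List[int]:
    """Build [z_1^1..z_1^K, a_1, z_2^1..z_2^K, a_2, ...].

    Action ids are shifted by vocab_size so they never collide with image tokens
    (an illustrative choice, not necessarily what IRIS does internally).
    """
    if len(frame_tokens) != len(actions):
        raise ValueError("need exactly one action per frame")
    sequence: List[int] = []
    for tokens, action in zip(frame_tokens, actions):
        sequence.extend(tokens)
        sequence.append(vocab_size + action)
    return sequence
```

With a two-entry codebook `[(0.0, 0.0), (1.0, 1.0)]`, a frame whose features are `[(0.1, 0.2), (0.9, 1.1)]` becomes the tokens `[0, 1]`. Two such frames with actions `3` and `0` then form a single sequence of length 2 × (2 + 1).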
If you find this code or paper useful, please use the following reference:
@article{iris2022,
title={Transformers are Sample Efficient World Models},
author={Micheli, Vincent and Alonso, Eloi and Fleuret, François},
journal={arXiv preprint arXiv:2209.00588},
year={2022}
}
- Install PyTorch (torch and torchvision). The code was developed with torch==1.11.0 and torchvision==0.12.0.
- Install other dependencies:
pip install -r requirements.txt
- Warning: Atari ROMs will be downloaded with the dependencies; by installing them, you acknowledge that you have the license to use them.
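A quick way to see whether your environment matches the versions listed above is to query the installed package metadata. This sketch uses the standard library's `importlib.metadata`, and the helper name `check_versions` is hypothetical. A mismatch does not necessarily mean the code will fail. It only flags a difference from the setup the code was developed with.

```python
# Sketch: compare installed package versions with the versions the code was developed with.
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Tuple

# Versions stated in the setup instructions.
DEVELOPED_WITH = {"torch": "1.11.0", "torchvision": "0.12.0"}


def check_versions(expected: Dict[str, str]) -> Dict[str, Tuple[Optional[str], bool]]:
    """Return {package: (installed_version_or_None, matches_expected)}."""
    report = {}
    for name, wanted in expected.items():
        try:
            installed: Optional[str] = version(name)
        except PackageNotFoundError:
            installed = None
        # Local build suffixes such as "+cu113" are ignored for the comparison.
        matches = installed is not None and installed.split("+")[0] == wanted
        report[name] = (installed, matches)
    return report


if __name__ == "__main__":
    for pkg, (found, ok) in check_versions(DEVELOPED_WITH).items():
        print(f"{pkg}: installed={found} expected={DEVELOPED_WITH[pkg]} match={ok}")
```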
To launch a training run:
python src/main.py env.train.id=BreakoutNoFrameskip-v4 common.device=cuda:0 wandb.mode=online
By default, the logs are synced to Weights & Biases; set wandb.mode=disabled to turn this off.
- All configuration files are located in config/; the main configuration file is config/trainer.yaml.
- The simplest way to customize the configuration is to edit these files directly.
- Please refer to Hydra for more details regarding configuration management.
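Command-line arguments such as env.train.id=BreakoutNoFrameskip-v4 are Hydra overrides. Each dotted key walks down the nested configuration and replaces one leaf value. The sketch below only illustrates that addressing scheme; `apply_overrides` is a hypothetical helper, not Hydra's implementation. Hydra also parses value types and composes config groups. Like this sketch, it rejects overrides of keys that do not exist, unless they are prefixed with `+`.

```python
# Hypothetical illustration of how an override "a.b.c=value" addresses a nested key.
# NOT Hydra's implementation: value parsing (ints, bools, lists, ...) is skipped here.
import copy
from typing import Any, Dict, Iterable


def apply_overrides(config: Dict[str, Any], overrides: Iterable[str]) -> Dict[str, Any]:
    result = copy.deepcopy(config)  # leave the input config untouched
    for override in overrides:
        dotted_key, sep, value = override.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {override!r}")
        *parents, leaf = dotted_key.split(".")
        node = result
        for part in parents:
            node = node[part]  # KeyError for an unknown section
        if leaf not in node:
            raise KeyError(f"unknown key {dotted_key!r}")
        node[leaf] = value  # values stay strings in this sketch
    return result
```

For example, applying `["env.train.id=BreakoutNoFrameskip-v4", "wandb.mode=disabled"]` to a nested dict with `env`, `common`, and `wandb` sections changes only those two leaves. A typo such as `wandb.modee=disabled` raises an error instead of silently adding a new key.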
Each new run is located at outputs/YYYY-MM-DD/hh-mm-ss/. This folder is structured as:
outputs/YYYY-MM-DD/hh-mm-ss/
│
└─── checkpoints
│   │   last.pt
│   │   optimizer.pt
│   │   ...
│   │
│   └─── dataset
│       │   0.pt
│       │   1.pt
│       │   ...
│
└─── config
│   │   trainer.yaml
│
└─── media
│   │
│   └─── episodes
│   │   │   ...
│   │
│   └─── reconstructions
│       │   ...
│
└─── scripts
│   │   eval.py
│   │   play.sh
│   │   resume.sh
│   │   ...
│
└─── src
│   │   ...
│
└─── wandb
    │   ...
- checkpoints: contains the last checkpoint of the model, its optimizer and the dataset.
- media:
  - episodes: contains train / test / imagination episodes for visualization purposes.
  - reconstructions: contains original frames alongside their reconstructions with the autoencoder.
- scripts: from the run folder, you can use the following three scripts.
  - eval.py: Launch python ./scripts/eval.py to evaluate the run.
  - resume.sh: Launch ./scripts/resume.sh to resume a training run that crashed.
  - play.sh: A tool to visualize some interesting aspects of the run.
    - Launch ./scripts/play.sh to watch the agent play live in the environment. If you add the flag -r, the left panel displays the original frame, the center panel displays the same frame downscaled to the input resolution of the discrete autoencoder, and the right panel shows the output of the autoencoder (what the agent actually sees).
    - Launch ./scripts/play.sh -w to unroll live trajectories with your keyboard inputs (i.e. to play in the world model). Note that for faster interaction, the memory of the Transformer is flushed every 20 frames.
    - Launch ./scripts/play.sh -a to watch the agent play live in the world model. Note that for faster interaction, the memory of the Transformer is flushed every 20 frames.
    - Launch ./scripts/play.sh -e to visualize the episodes contained in media/episodes.
    - Add the flag -h to display a header with additional information.
    - Press ',' to start and stop recording. The corresponding segment is saved in media/recordings in mp4 and numpy formats.
    - Add the flag -s to enter 'save mode', where the user is prompted to save trajectories upon completion.
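Flushing the Transformer's memory every 20 frames keeps interaction fast because an autoregressive model attends over everything it remembers. The cost of each new prediction therefore grows with the length of the context. Capping the context bounds that cost, at the price of the world model forgetting frames older than the last flush. The sketch below shows the bookkeeping only. `FlushingContext` is a hypothetical class, not the repository's implementation, and it stores plain token lists rather than a real attention cache.

```python
# Hypothetical sketch of a context that is flushed every `max_frames` frames.
from typing import List, Sequence


class FlushingContext:
    """Keeps the tokens of at most `max_frames` frames; starts over once full."""

    def __init__(self, max_frames: int = 20) -> None:
        self.max_frames = max_frames
        self.frames: List[List[int]] = []

    def push(self, frame_tokens: Sequence[int]) -> List[int]:
        if len(self.frames) == self.max_frames:
            self.frames.clear()  # flush: drop every remembered frame
        self.frames.append(list(frame_tokens))
        # Flattened context the model would attend over for its next prediction.
        return [tok for frame in self.frames for tok in frame]
```

After 45 frames are pushed with `max_frames=20`, the context holds only the last 5 frames, namely those pushed since the second flush.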
The folder results/data/ contains raw scores (for each game and each training run) for IRIS and the baselines.
Use the notebook results/results_iris.ipynb to reproduce the figures from the paper.
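Raw Atari scores are not comparable across games, so they are commonly converted to human-normalized scores before aggregation. The normalized score is (score − random) / (human − random), so 0 matches a random policy and 1 matches the human reference. The sketch below computes this and reports the mean and median. It is an assumption that the notebook aggregates exactly this way. The file layout of results/data/ is not assumed, and the helpers take plain dictionaries keyed by game name.

```python
# Sketch: human-normalized scores and simple aggregates over games.
from statistics import mean, median
from typing import Dict


def human_normalized(score: float, random_score: float, human_score: float) -> float:
    """0 corresponds to a random policy, 1 to the human reference score."""
    return (score - random_score) / (human_score - random_score)


def aggregate(scores: Dict[str, float], random_scores: Dict[str, float],
              human_scores: Dict[str, float]) -> Dict[str, float]:
    """Mean and median human-normalized score over the games present in `scores`."""
    normalized = [human_normalized(scores[g], random_scores[g], human_scores[g]) for g in scores]
    return {"mean": mean(normalized), "median": median(normalized)}
```

The median is often reported alongside the mean because a few games with very large normalized scores can dominate the mean.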
Pretrained models are available here.
- To start a training run from one of these checkpoints, go to the initialization section of config/trainer.yaml. Set path_to_checkpoint to the corresponding path, and set load_tokenizer, load_world_model, and load_actor_critic to True.
- To visualize one of these checkpoints, set train.id to the corresponding game in config/env/default.yaml, create a checkpoints directory, and copy the checkpoint to checkpoints/last.pt. You can then visualize the agent with ./scripts/play.sh as described above.
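Both uses of a pretrained checkpoint can be scripted. The sketch below assumes two things. First, the keys listed above sit directly under the initialization section of config/trainer.yaml, so they can be passed as Hydra overrides instead of editing the file. Second, `run_dir` is the directory from which you launch ./scripts/play.sh. The helpers `resume_overrides` and `install_for_play` are hypothetical.

```python
# Hypothetical helpers for the two ways of using a pretrained checkpoint.
import shutil
from pathlib import Path
from typing import List


def resume_overrides(checkpoint_path: str) -> List[str]:
    """Hydra overrides equivalent to editing the `initialization` section of
    config/trainer.yaml (assumes these keys sit directly under that section)."""
    return [
        f"initialization.path_to_checkpoint={checkpoint_path}",
        "initialization.load_tokenizer=True",
        "initialization.load_world_model=True",
        "initialization.load_actor_critic=True",
    ]


def install_for_play(checkpoint_path: str, run_dir: str = ".") -> Path:
    """Create <run_dir>/checkpoints and copy the checkpoint to checkpoints/last.pt."""
    target_dir = Path(run_dir) / "checkpoints"
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / "last.pt"
    shutil.copyfile(checkpoint_path, target)
    return target
```

For example, `["python", "src/main.py", "env.train.id=BreakoutNoFrameskip-v4", *resume_overrides("/path/to/checkpoint.pt")]` could be passed to `subprocess.run`. The checkpoint path shown here is a placeholder.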