
Efficient World Models with Context-Aware Tokenization (Δ-IRIS)

Efficient World Models with Context-Aware Tokenization (ICML 2024)
Vincent Micheli*, Eloi Alonso*, François Fleuret
https://arxiv.org/abs/2406.19320

TL;DR: Δ-IRIS is a reinforcement learning agent trained in the imagination of its world model.

Δ-IRIS agent alternately playing in the environment and in its world model (delta-iris.mp4, download here)

Setup

  • Pin the pip version: pip install pip==23.0
  • Install dependencies: pip install -r requirements.txt
  • Warning: Atari ROMs are downloaded with the Atari dependencies; by installing them, you acknowledge that you have a license to use them.
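
Putting the steps together, a typical setup in a fresh virtual environment (the environment name below is arbitrary) might look like:

python -m venv .venv
source .venv/bin/activate
pip install pip==23.0
pip install -r requirements.txt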

Launch a training run

Crafter:

python src/main.py

The run will be located in outputs/YYYY-MM-DD/hh-mm-ss/.

By default, logs are synced to Weights & Biases; set wandb.mode=disabled to turn logging off.
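
For instance, to launch a run with logging turned off:

python src/main.py wandb.mode=disabled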

Atari:

python src/main.py env=atari params=atari env.train.id=BreakoutNoFrameskip-v4

Note that this Atari configuration achieves slightly higher aggregate metrics than those reported in the paper. Here is the updated table of results.

Configuration

  • All configuration files are located in config/; the main configuration file is config/trainer.yaml.
  • The simplest way to customize the configuration is to edit these files directly.
  • Please refer to Hydra for more details regarding configuration management.
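
Since configuration is managed with Hydra, any value in these files can also be overridden from the command line instead of editing them, e.g. combining the overrides shown above:

python src/main.py env=atari params=atari env.train.id=BreakoutNoFrameskip-v4 wandb.mode=disabled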

Run folder

Each new run is located in outputs/YYYY-MM-DD/hh-mm-ss/. This folder is structured as:

outputs/YYYY-MM-DD/hh-mm-ss/
│
└─── checkpoints
│   │   last.pt
│   │   optimizer.pt
│   │   ...
│   │
│   └─── dataset
│       │
│       └─── train
│       │   │   info.pt
│       │   │   ...
│       │
│       └─── test
│           │   info.pt
│           │   ...
│
└─── config
│   │   trainer.yaml
│   │   ...
│
└─── media
│   │
│   └─── episodes
│   │   │   ...
│   │
│   └─── reconstructions
│       │   ...
│
└─── scripts
│   │   resume.sh
│   │   play.sh
│
└─── src
│   │   main.py
│   │   ...
│
└─── wandb
    │   ...

  • checkpoints: contains the last checkpoint of the model, its optimizer state, and the dataset.
  • media:
    • episodes: contains train / test episodes for visualization purposes.
    • reconstructions: contains original frames alongside their reconstructions with the autoencoder.
  • scripts: from the run folder, you can use the following scripts (usage examples follow this list).
    • resume.sh: Launch ./scripts/resume.sh to resume a training run that crashed.
    • play.sh: Tool to visualize the agent and interact with the world model.
      • Launch ./scripts/play.sh to watch the agent play live in the environment.
      • Launch ./scripts/play.sh -w to play live in the world model. Note that for faster interaction, the memory of the world model is flushed after a few seconds.
      • Launch ./scripts/play.sh -a to watch the agent play live in the world model. Note that for faster interaction, the memory of the world model is flushed after a few seconds.
      • Launch ./scripts/play.sh -e to visualize the episodes contained in media/episodes.
      • Add the flag -h to display a header with additional information.
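
For example, from inside the run folder:

./scripts/resume.sh      # resume a crashed training run
./scripts/play.sh -a -h  # watch the agent play in the world model, with the info header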

Pretrained agent

An agent checkpoint (Crafter 5M frames) is available here.

To visualize the agent or play in its world model:

  • Create a checkpoints directory
  • Copy the checkpoint to checkpoints/last.pt
  • Run ./scripts/play.sh with the flags of your choice as described above.
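
Concretely, assuming the checkpoint was downloaded to ~/Downloads/last.pt (a placeholder path), the steps amount to:

mkdir checkpoints
cp ~/Downloads/last.pt checkpoints/last.pt
./scripts/play.sh -w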

Cite

If you find this code or paper useful, please use the following reference:

@inproceedings{micheli2024efficient,
  title={Efficient World Models with Context-Aware Tokenization},
  author={Vincent Micheli and Eloi Alonso and François Fleuret},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024},
  url={https://openreview.net/forum?id=BiWIERWBFX}
}

Credits

License

GNU General Public License v3.0

