mila-iqia / SGI

Official code for "Pretraining Representations For Data-Efficient Reinforcement Learning" (NeurIPS 2021)

Home Page: https://arxiv.org/abs/2106.04799

Pretraining Representations For Data-Efficient Reinforcement Learning

Max Schwarzer, Nitarshan Rajkumar, Michael Noukhovitch, Ankesh Anand, Laurent Charlin, Devon Hjelm, Philip Bachman & Aaron Courville

This repo provides code for implementing SGI.

Install

To install the requirements, follow these steps:

# Use a UTF-8 locale
export LANG=C.UTF-8
# Install PyTorch, then the remaining requirements
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt

# Finally, install the project
pip install --user -e .
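
After installing, it can help to confirm that the pinned builds were actually picked up. The sketch below is not part of the repository: `PINS`, `installed_version` and `check_pins` are hypothetical names, and the only facts it relies on are the two version pins from the install command above.

```python
from importlib import metadata

# Version pins copied from the install command above.
PINS = {"torch": "1.8.1+cu111", "torchvision": "0.9.1+cu111"}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_pins(pins, lookup=installed_version):
    """Map each package name to a tuple (expected, found, matches)."""
    report = {}
    for package, expected in pins.items():
        found = lookup(package)
        report[package] = (expected, found, found == expected)
    return report


if __name__ == "__main__":
    for package, (expected, found, ok) in check_pins(PINS).items():
        status = "ok" if ok else "MISMATCH"
        print(f"{package}: expected {expected}, found {found} [{status}]")
```

A mismatch here usually means pip resolved a different wheel, for example a CPU-only build, in which case rerunning the pinned install command is the first thing to try.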

Usage

The default branch for the latest and stable changes is release.

To run SGI:

  1. Use the helper script to download and parse checkpoints from the DQN Replay Dataset; this requires gsutil to be installed. You may want to modify the script to download fewer checkpoints from fewer games, as otherwise this requires significant storage.
    • Or substitute your own pre-training data! The codebase expects a series of .gz files, one each for observations, actions and terminals.
bash scripts/download_replay_dataset.sh $DATA_DIR
  2. To pretrain with SGI:
python -m scripts.run public=True model_folder=./ offline.runner.save_every=2500 \
    env.game=pong seed=1 offline_model_save={your model name} \
    offline.runner.epochs=10 offline.runner.dataloader.games=[Pong] \
    offline.runner.no_eval=1 \
    +offline.algo.goal_weight=1 \
    +offline.algo.inverse_model_weight=1 \
    +offline.algo.spr_weight=1 \
    +offline.algo.target_update_tau=0.01 \
    +offline.agent.model_kwargs.momentum_tau=0.01 \
    do_online=False \
    algo.batch_size=256 \
    +offline.agent.model_kwargs.noisy_nets_std=0 \
    offline.runner.dataloader.dataset_on_disk=True \
    offline.runner.dataloader.samples=1000000 \
    offline.runner.dataloader.checkpoints='{your checkpoints}' \
    offline.runner.dataloader.num_workers=2 \
    offline.runner.dataloader.data_path={your data dir} \
    offline.runner.dataloader.tmp_data_path=./
  3. To fine-tune with SGI:
python -m scripts.run public=True env.game=pong seed=1 num_logs=10  \
    model_load={your_model_name} model_folder=./ \
    algo.encoder_lr=0.000001 algo.q_l1_lr=0.00003 algo.clip_grad_norm=-1 algo.clip_model_grad_norm=-1
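
The commands above pass every setting as a Hydra override. In Hydra, a plain `key=value` overrides a key that already exists in the config, while a leading `+` adds a key the config does not define. The sketch below assembles a subset of the pretraining command programmatically. `build_command` is a hypothetical helper, not part of the repository, and the split between `overrides` and `additions` simply mirrors which keys carry a `+` in the command above.

```python
import shlex


def build_command(overrides, additions):
    """Build the argument list for `python -m scripts.run`.

    overrides: keys that already exist in the config (passed as key=value).
    additions: keys absent from the config (passed as +key=value).
    """
    args = ["python", "-m", "scripts.run"]
    args += [f"{key}={value}" for key, value in overrides.items()]
    args += [f"+{key}={value}" for key, value in additions.items()]
    return args


# A subset of the pretraining settings shown above.
overrides = {
    "public": True,
    "env.game": "pong",
    "seed": 1,
    "do_online": False,
    "offline.runner.epochs": 10,
    "offline.runner.dataloader.games": "[Pong]",
}
additions = {
    "offline.algo.goal_weight": 1,
    "offline.algo.inverse_model_weight": 1,
    "offline.algo.spr_weight": 1,
}

command = build_command(overrides, additions)
# shlex.join quotes arguments such as games=[Pong], so a shell will not
# try to interpret the square brackets as a glob pattern.
print(shlex.join(command))
```

Passing `command` as a list to `subprocess.run` avoids shell quoting altogether; the printed form is mainly useful for logging or for pasting into a job script.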

When reporting scores, we average across 10 fine-tuning seeds.
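
One way to organize the ten fine-tuning runs is to generate one command per seed and average the final scores afterwards. This is a minimal sketch, not the authors' tooling: `finetune_command` and `mean_score` are hypothetical names, the model name `my_model` is a placeholder, the seed range 1 to 10 is an assumption (the text only says 10 seeds are used), and reading scores out of the logs is left out.

```python
import statistics


def finetune_command(game, seed, model_name):
    """Return the fine-tuning command shown above for one game and seed."""
    return [
        "python", "-m", "scripts.run",
        "public=True", f"env.game={game}", f"seed={seed}", "num_logs=10",
        f"model_load={model_name}", "model_folder=./",
        "algo.encoder_lr=0.000001", "algo.q_l1_lr=0.00003",
        "algo.clip_grad_norm=-1", "algo.clip_model_grad_norm=-1",
    ]


def mean_score(scores_by_seed):
    """Average final scores over seeds; expects a {seed: score} mapping."""
    return statistics.mean(list(scores_by_seed.values()))


# Assumption: seeds 1..10; only the number of seeds is stated in the text.
SEEDS = range(1, 11)
commands = [finetune_command("pong", seed, "my_model") for seed in SEEDS]
```

Each entry of `commands` can be handed to `subprocess.run(cmd, check=True)` or written to a job script; once every run has finished, `mean_score` reduces the per-seed results to the single number that gets reported.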

./scripts/experiments contains a number of example configurations, including for SGI-M, SGI-M/L and SGI-W, for both pre-training and fine-tuning. Each of these scripts can be launched by providing a game and seed, e.g., ./scripts/experiments/sgim_pretrain.sh pong 1. These scripts are provided primarily to illustrate the hyperparameters used for different experiments; you will likely need to modify the arguments in these scripts to point to your data and model directories.

Data for SGI-R and SGI-E is not included due to its size, but can be re-generated locally. Contact us for details.

What does each file do?

.
├── scripts
│   ├── run.py                # The main runner script to launch jobs.
│   ├── config.yaml           # The hydra configuration file, listing hyperparameters and options.
│   ├── download_replay_dataset.sh  # Helper script to download the DQN replay dataset.
│   └── experiments           # Configurations for various experiments done by SGI.
│
├── src
│   ├── agent.py              # Implements the Agent API for action selection
│   ├── algos.py              # Distributional RL loss and optimization
│   ├── models.py             # Forward passes, network initialization.
│   ├── networks.py           # Network architecture and forward passes.
│   ├── offline_dataset.py    # Dataloader for offline data.
│   ├── gcrl.py               # Utils for SGI's goal-conditioned RL objective.
│   ├── rlpyt_atari_env.py    # Slightly modified Atari env from rlpyt
│   ├── rlpyt_utils.py        # Utility methods that we use to extend rlpyt's functionality
│   └── utils.py              # Command line arguments and helper functions
│
└── requirements.txt          # Dependencies

License: MIT License

