balloch / cr-dv2

DreamerV2 implementation of Curious Replay, a method for prioritizing experience replay that is tailored to model-based reinforcement learning agents.

Curious Replay for Model-based Adaptation

Implementation of Curious Replay with the DreamerV2 agent in TensorFlow 2.

If you find this code useful, please cite the following in your paper:

@article{kauvar2023curious,
  title={Curious Replay for Model-Based Adaptation},
  author={Kauvar, Isaac and Doyle, Chris and Zhou, Linqi and Haber, Nick},
  journal={International Conference on Machine Learning},
  year={2023}
}

Method

Curious Replay prioritizes which past experiences are sampled to train the agent's world model, favoring the experiences that are most interesting to the agent, whether because they are unfamiliar or surprising. Curiosity is often used as an intrinsic reward to guide action selection; here, curiosity signals instead guide the selection of which experiences the agent learns from (i.e., trains its model on).

Curious Replay is a simple modification to existing agents that use experience replay, and it adds minimal computational overhead: it relies only on a count of how many times each experience has been sampled and on the model losses that are already computed for each training batch.
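As a rough illustration of the two signals described above, the sketch below scores each stored experience with a count term that decays every time the experience is sampled plus a loss term that grows with its latest world-model loss, and then samples in proportion to that score. The functional form follows our reading of the Curious Replay paper (priority = c·β^visits + (|loss| + ε)^α); the constant values and the function names are illustrative assumptions, not taken from this repository's replay.py.

```python
import numpy as np

def curious_replay_priority(visit_counts, model_losses,
                            c=1e4, beta=0.7, alpha=0.7, eps=0.01):
    """Priority = count-based term + loss-based (adversarial) term.

    The default constants are illustrative assumptions, not read from configs.yaml.
    """
    visit_counts = np.asarray(visit_counts, dtype=np.float64)
    model_losses = np.asarray(model_losses, dtype=np.float64)
    count_term = c * beta ** visit_counts               # shrinks each time an item is sampled
    loss_term = (np.abs(model_losses) + eps) ** alpha   # high model loss keeps an item likely
    return count_term + loss_term

def sample_indices(priorities, batch_size, rng):
    """Draw indices with probability proportional to priority (with replacement)."""
    probs = priorities / priorities.sum()
    return rng.choice(len(priorities), size=batch_size, p=probs)

rng = np.random.default_rng(0)
counts = np.array([0, 5, 20])       # never sampled, sampled a few times, sampled often
losses = np.array([0.0, 2.0, 0.1])  # latest world-model loss for each item
prio = curious_replay_priority(counts, losses)
batch = sample_indices(prio, batch_size=4, rng=rng)
```

Because new experiences start with zero visits, the count term makes them likely to be sampled soon after they are added; once that term has decayed, experiences the model still predicts poorly keep a higher chance of being replayed.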

This prioritization is especially helpful in changing environments, where adaptation is necessary. Curious Replay helps keep the world model up to date as the environment changes, which is a prerequisite for effective action selection in model-based architectures.

Installation Instructions

To install the DreamerV2 agent with Curious Replay, clone the repository, enter the directory, and follow the instructions below.

conda create -n curiousreplay-dv2 python=3.8 -y
conda activate curiousreplay-dv2
pip3 install tensorflow==2.6.0 tensorflow_probability==0.14.0 \
             protobuf==3.20.3 ruamel.yaml \
             'gym[atari]' dm_control crafter \
             keras==2.6 matplotlib pandas numpy==1.19.5 \
             starr==0.2.1 elements==0.3.2 moviepy

Experiment Instructions


Train on Crafter with Curious Replay (using the optimized hyperparameter value):

python3 dreamerv2/train.py --logdir ~/logdir/crafter/dreamerv2_cr/1 \
  --configs crafter --task crafter_reward  \
  --prioritize_until -1 --priority_params.metric CR \
  --priority_params.a 0.5
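The dotted flags in this command, such as --priority_params.a, override entries nested inside a configuration section. The sketch below is a hypothetical illustration of that idea, not this repository's flag parser: it walks the dotted path and converts the string value to the type of the existing default. The default values shown are made up, and --logdir, --configs, and --task are not handled.

```python
import copy

def apply_dotted_overrides(config, argv):
    """Return a copy of `config` with `--a.b value` overrides applied.

    Hypothetical illustration of dotted flags; not this repo's parser.
    Each value is converted to the type of the default it replaces.
    """
    config = copy.deepcopy(config)  # leave the caller's defaults untouched
    if len(argv) % 2:
        raise ValueError("expected --key value pairs")
    for flag, raw in zip(argv[0::2], argv[1::2]):
        if not flag.startswith("--"):
            raise ValueError(f"unexpected token: {flag}")
        *parents, leaf = flag[2:].split(".")
        node = config
        for key in parents:
            node = node[key]  # descend into the nested section
        if leaf not in node:
            raise KeyError(f"unknown config key: {flag[2:]}")
        default = node[leaf]
        if isinstance(default, bool):
            node[leaf] = raw.lower() in ("true", "1", "yes")  # bool("False") would be True
        else:
            node[leaf] = type(default)(raw)
    return config

# Made-up defaults, only to show where the README's flags would land.
defaults = {"prioritize_until": 0,
            "priority_params": {"metric": "uniform", "a": 1.0}}
cfg = apply_dotted_overrides(defaults, [
    "--prioritize_until", "-1",
    "--priority_params.metric", "CR",
    "--priority_params.a", "0.5",
])
# cfg == {"prioritize_until": -1, "priority_params": {"metric": "CR", "a": 0.5}}
```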

Monitor results:

tensorboard --logdir ~/logdir

Summarize a Crafter run:

python3 dreamerv2/plot_crafter.py

Train on Crafter with Temporal Difference (TD) Prioritized Experience Replay:

python3 dreamerv2/train.py --logdir ~/logdir/crafter/dreamerv2_td/1 \
  --configs crafter --task crafter_reward  \
  --prioritize_until -1 --priority_params.metric TD
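For comparison, TD-based prioritized experience replay (Schaul et al.) ranks transitions by the magnitude of their temporal-difference error rather than by world-model loss and visit count. The sketch below shows the standard proportional variant with importance-sampling correction; the exponents are commonly used example values, and how this repository computes TD errors for the TD metric is not described here, so treat this as background rather than the repository's code.

```python
import numpy as np

def td_priorities(td_errors, alpha=0.6, eps=1e-6):
    """Sampling probabilities P(i) = p_i**alpha / sum_k p_k**alpha, with p_i = |delta_i| + eps."""
    p = (np.abs(np.asarray(td_errors, dtype=np.float64)) + eps) ** alpha
    return p / p.sum()

def importance_weights(probs, indices, beta=0.4):
    """Importance-sampling weights (N * P(i))**-beta, scaled so the batch maximum is 1."""
    w = (len(probs) * probs[indices]) ** (-beta)
    return w / w.max()

rng = np.random.default_rng(0)
probs = td_priorities([0.0, 1.0, -3.0])     # larger |TD error| -> sampled more often
batch = rng.choice(len(probs), size=2, p=probs)
weights = importance_weights(probs, batch)  # down-weights frequently sampled items in the loss
```

The weights scale each sampled transition's loss to correct for the bias introduced by non-uniform sampling. This prioritization depends on value estimates, whereas Curious Replay's signal comes from the world model.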

Train on Crafter with 8x increased training frequency (this takes a long time):

python3 dreamerv2/train.py --logdir ~/logdir/crafter/dreamerv2_cr/1 \
  --configs crafter --task crafter_reward \
  --prioritize_until -1 \
  --priority_params.metric CR \
  --train_steps 8
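Assuming the training loop follows upstream DreamerV2, where the agent performs train_steps gradient updates once every train_every environment steps, raising --train_steps from 1 to 8 multiplies the number of updates per environment step by 8, which is why this run takes much longer. The helper below only illustrates that arithmetic; train_every=5 is an example value, not read from the Crafter config.

```python
def gradient_updates(env_steps, train_every, train_steps):
    """Gradient updates performed after `env_steps` environment steps.

    Assumes (as in upstream DreamerV2) that `train_steps` updates are run
    once every `train_every` environment steps.
    """
    return (env_steps // train_every) * train_steps

# train_every=5 is only an example value.
baseline = gradient_updates(1_000_000, train_every=5, train_steps=1)  # 200000
faster = gradient_updates(1_000_000, train_every=5, train_steps=8)    # 1600000
```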

Train on DM Control:

python3 dreamerv2/train.py --logdir ~/logdir/dmc_walker_walk/dreamerv2_cr/1 \
  --configs dmc_vision --task dmc_walker_walk \
  --prioritize_until -1  --priority_params.metric CR

Note: if you are running headless and DM Control raises ImportError: Cannot initialize a headless EGL display, you can run:

sudo killall Xorg
sudo /usr/bin/X :0 &

and, if needed,

export DISPLAY=:0

and, if needed,

sudo nvidia-xconfig -a --use-display-device=none

Train on the Object Interaction Assay (requires installing the adaptgym library, which can add some complexity):

# MuJoCo may need to be installed at this point. We are working on improving adaptgym to reduce this complexity.
pip install numpy==1.22
pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple adaptgym==0.1.34 --no-cache-dir
python3 dreamerv2/train.py --logdir ~/logdir/admc_sphero_novel_object/dreamerv2_cr/1 \
  --configs adaptgym --task admc_sphero_novel_object \
  --prioritize_until -1  --priority_params.metric CR
  
python3 dreamerv2/plot_object_interaction.py \
  --logdir ~/logdir/admc_sphero_novel_object/dreamerv2_cr/1 \
  --outdir ~/logdir/admc_sphero_novel_object/dreamerv2_cr/1/plots  

Train on the Unchanging Object Interaction Assay:

python3 dreamerv2/train.py --logdir ~/logdir/admc_sphero_novel_object_unchanging/dreamerv2_cr/1 \
  --configs adaptgym --task admc_sphero_novel_object_unchanging \
  --prioritize_until -1  --priority_params.metric CR
  
python3 dreamerv2/plot_object_interaction.py \
  --logdir ~/logdir/admc_sphero_novel_object_unchanging/dreamerv2_cr/1 \
  --outdir ~/logdir/admc_sphero_novel_object_unchanging/dreamerv2_cr/1/plots  

Acknowledgments

This repository is largely based on the TensorFlow 2 implementation of DreamerV2. We would like to thank Danijar Hafner for releasing and updating his clean implementation.

Files modified for incorporating Curious Replay

Those wishing to incorporate Curious Replay into other agents may find it helpful to know which files were modified from the original DreamerV2:

replay.py, train.py, agent.py, configs.yaml, expl.py
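The list of modified files suggests that the change is concentrated in the replay buffer and the training step. A minimal, hypothetical way to picture the interface: the buffer stores a visit count and the latest model loss per item, increments counts when it samples, and receives per-sample model losses back from the training step. The class and method names below are invented for illustration, and the priority uses assumed constants for a count term c·β^visits plus a loss term (|loss| + ε)^α; a real DreamerV2 buffer samples sequences from stored episodes and computes losses in TensorFlow.

```python
import numpy as np

class CuriousReplayBuffer:
    """Hypothetical in-memory buffer showing the hooks Curious Replay needs."""

    def __init__(self, c=1e4, beta=0.7, alpha=0.7, eps=0.01, seed=0):
        self.items, self.counts, self.losses = [], [], []
        self.c, self.beta, self.alpha, self.eps = c, beta, alpha, eps  # assumed constants
        self.rng = np.random.default_rng(seed)

    def add(self, item):
        self.items.append(item)
        self.counts.append(0)    # never sampled yet: count term is at its maximum
        self.losses.append(0.0)  # no model loss computed for this item yet

    def priorities(self):
        counts = np.asarray(self.counts, dtype=np.float64)
        losses = np.asarray(self.losses, dtype=np.float64)
        return self.c * self.beta ** counts + (np.abs(losses) + self.eps) ** self.alpha

    def sample(self, batch_size):
        p = self.priorities()
        idx = self.rng.choice(len(self.items), size=batch_size, p=p / p.sum())
        for i in idx:
            self.counts[i] += 1  # each draw counts as one visit
        return idx, [self.items[i] for i in idx]

    def update_losses(self, indices, per_item_losses):
        """Called by the training step with the per-sample world-model losses."""
        for i, loss in zip(indices, per_item_losses):
            self.losses[i] = float(loss)

buf = CuriousReplayBuffer(seed=0)
for step in range(10):
    buf.add({"step": step})
idx, batch = buf.sample(batch_size=4)
buf.update_losses(idx, [1.0] * len(idx))  # placeholder losses; an agent would report real ones
```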

Tips

  • Efficient debugging. You can use the debug config as in --configs crafter debug. This reduces the batch size, increases the evaluation frequency, and disables tf.function graph compilation for easy line-by-line debugging.

  • Infinite gradient norms. These are normal and are described under loss scaling in the mixed precision guide. You can disable mixed precision by passing --precision 32 to the training script. Mixed precision is faster but can, in principle, cause numerical instabilities.

  • Accessing logged metrics. The metrics are stored in both TensorBoard and JSON lines format. You can load the JSON lines file directly with pandas.read_json(path, lines=True). The plotting script also stores the binned and aggregated metrics of multiple runs in a single JSON file for easy manual plotting.
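To make the tip about infinite gradient norms concrete: with dynamic loss scaling, the loss is multiplied by a large factor before backpropagation so that small float16 gradients do not underflow. When the scaled gradients overflow, the step is skipped, the scale is halved, and the logged gradient norm can show up as infinite. The plain-Python sketch below mimics that control loop in simplified form; it is not TensorFlow's LossScaleOptimizer, and the defaults (an initial scale of 2**15, doubling after 2000 consecutive finite steps) reflect TensorFlow's documented dynamic defaults as we understand them.

```python
import math

class DynamicLossScale:
    """Simplified dynamic loss scaling, for illustration only."""

    def __init__(self, initial_scale=2.0 ** 15, growth_interval=2000):
        self.scale = initial_scale
        self.growth_interval = growth_interval
        self.good_steps = 0

    def step(self, grad_norm):
        """Given the norm of the scaled gradients, return True if the update should be applied."""
        if not math.isfinite(grad_norm):
            # Overflow (inf or nan): skip this update and shrink the scale.
            self.scale /= 2.0
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps >= self.growth_interval:
            # A long run of finite steps: try a larger scale to use more of float16's range.
            self.scale *= 2.0
            self.good_steps = 0
        return True

scaler = DynamicLossScale(initial_scale=8.0, growth_interval=2)
applied = [scaler.step(g) for g in (float("inf"), 1.0, 1.0)]  # [False, True, True]
```

An occasional infinite norm is therefore expected: it is how the scaler discovers that its current scale is too large.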
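If pandas is not available, the JSON lines metrics mentioned in the last tip can also be parsed with the standard library. The sketch below assumes each line is a JSON object containing a step field alongside metric values; the file name in the comment and the metric name train_return are examples, so check your logdir for the actual names.

```python
import json

def metric_series(lines, key):
    """Collect (step, value) pairs for `key` from an iterable of JSON lines.

    Each non-empty line is assumed to be a JSON object with a 'step' field;
    lines that do not contain `key` are skipped.
    """
    series = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        if key in record:
            series.append((record["step"], record[key]))
    return series

# With a real run, pass an open file instead, e.g. (file name assumed):
#   with open(os.path.expanduser("~/logdir/crafter/dreamerv2_cr/1/metrics.jsonl")) as f:
#       series = metric_series(f, "train_return")
demo = ['{"step": 1000, "train_return": 1.5}',
        '{"step": 1000, "fps": 30}',
        '{"step": 2000, "train_return": 2.0}']
series = metric_series(demo, "train_return")  # [(1000, 1.5), (2000, 2.0)]
```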

License: MIT License

