This repository contains JAX (Flax) implementations of reinforcement learning algorithms:
- Soft Actor Critic with learnable temperature
- Advantage Weighted Actor Critic
- Behavioral Cloning
The goal of this repository is to provide simple, clean implementations to build research on top of. Please do not use this repository for baseline results; use the original implementations instead.
Install and activate an Anaconda environment
conda env create -f environment.yml
conda activate jax-rl
If you want to run this code on a GPU, please follow the installation instructions from the official JAX repository.
Please follow the instructions to build mujoco-py with fast headless GPU rendering.
OpenAI Gym MuJoCo tasks
python train.py --env_name=HalfCheetah-v2 --save_dir=./tmp/
DeepMind Control suite (--env_name=dmc-domain-task)
python train.py --env_name=dmc-cheetah-run --save_dir=./tmp/
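The `dmc-domain-task` naming pattern above packs a DeepMind Control domain and task into a single string. As a minimal sketch of how such a name could be split, the helper below is hypothetical (`split_dmc_env_name` is not taken from this repository) and assumes the domain name contains no hyphen:

```python
from typing import Tuple


def split_dmc_env_name(env_name: str) -> Tuple[str, str]:
    """Split a name like 'dmc-cheetah-run' into (domain, task).

    Hypothetical helper for illustration only; it assumes the
    'dmc-<domain>-<task>' pattern and a domain without hyphens.
    """
    prefix, sep, rest = env_name.partition("-")
    if prefix != "dmc" or not sep:
        raise ValueError(f"expected a name starting with 'dmc-', got {env_name!r}")
    domain, sep, task = rest.partition("-")
    if not domain or not sep or not task:
        raise ValueError(f"expected 'dmc-<domain>-<task>', got {env_name!r}")
    return domain, task


domain, task = split_dmc_env_name("dmc-cheetah-run")
print(domain, task)  # cheetah run
```

Names that do not start with `dmc-`, such as `HalfCheetah-v2`, raise a `ValueError` in this sketch, which makes it easy to fall back to a regular Gym environment.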
For offline RL
python train_offline.py --env_name=halfcheetah-expert-v0 --dataset_name=d4rl --save_dir=./tmp/
For RL finetuning
python train_finetuning.py --env_name=HalfCheetah-v2 --dataset_name=awac --save_dir=./tmp/
If you experience out-of-memory errors, especially with video saving enabled, please consider reading the JAX docs on GPU memory allocation. You can also try running with the following environment variable:
XLA_PYTHON_CLIENT_MEM_FRACTION=0.80 python ...
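The same setting can also be applied from inside a Python script, as long as it happens before JAX is first imported. The sketch below is not part of this repository: `configure_jax_memory` is a hypothetical helper, and the optional `XLA_PYTHON_CLIENT_PREALLOCATE=false` setting is the one described in the JAX GPU memory allocation docs.

```python
import os


def configure_jax_memory(fraction: float = 0.80, preallocate: bool = True) -> None:
    """Set XLA client memory options (hypothetical helper).

    Must be called before the first `import jax`, because the
    variables are read when the GPU backend is initialized.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    # Fraction of total GPU memory that JAX may preallocate.
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"
    if not preallocate:
        # Allocate memory on demand instead of preallocating a large block.
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


configure_jax_memory(fraction=0.80)
# import jax  # import JAX only after the variables are set
```

Disabling preallocation can reduce peak reserved memory at the cost of more fragmentation, so it is worth trying only if lowering the fraction alone does not help.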
Launch tensorboard to see training and evaluation logs
tensorboard --logdir=./tmp/
When contributing to this repository, please first discuss the change you wish to make via an issue. If you are not familiar with pull requests, please read this documentation.