Brax is a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators. Brax is written in JAX and is designed for use on accelerator hardware. It is efficient for single-device simulation and scales to massively parallel simulation on multiple devices, without the need for pesky datacenters.
(Figure: some policies trained via Brax.) Brax simulates these environments at millions of physics steps per second on TPU.
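The two properties above, differentiability and cheap parallelism, come from JAX transformations. The sketch below is a hypothetical toy, not Brax's API or integrator: a single point mass under gravity stepped with semi-implicit Euler inside `jax.lax.scan`. `jax.vmap` turns the single-environment rollout into a batch of 1024 independent simulations, `jax.jit` compiles it for the default device, and `jax.grad` differentiates the final height with respect to the launch velocity. The timestep, step count, and names (`step`, `rollout`) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

DT = 0.01  # integration timestep in seconds (arbitrary choice for this sketch)
NUM_STEPS = 100
GRAVITY = jnp.array([0.0, 0.0, -9.81])


def step(state, _):
    """Advance one point mass by a semi-implicit Euler step (toy integrator, not Brax's)."""
    pos, vel = state
    vel = vel + DT * GRAVITY
    pos = pos + DT * vel
    return (pos, vel), None


def rollout(init_vel):
    """Launch a point mass from the origin and return its position after NUM_STEPS."""
    init_state = (jnp.zeros(3), init_vel)
    (final_pos, _), _ = jax.lax.scan(step, init_state, None, length=NUM_STEPS)
    return final_pos


# Parallelism: vmap maps the single-environment rollout over a batch axis,
# and jit compiles the batched computation for the default device.
batched_rollout = jax.jit(jax.vmap(rollout))
init_vels = jnp.tile(jnp.array([1.0, 0.0, 5.0]), (1024, 1))
final_positions = batched_rollout(init_vels)  # shape (1024, 3)

# Differentiability: gradient of the final height w.r.t. the launch velocity.
height_grad = jax.jit(jax.grad(lambda v: rollout(v)[2]))
dheight_dvel = height_grad(jnp.array([1.0, 0.0, 5.0]))
print(final_positions.shape, dheight_dvel)
```

For this integrator the final height is linear in the initial vertical velocity with slope `NUM_STEPS * DT`, so the gradient's z component should be close to 1.0 and its x and y components exactly zero.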
Brax also includes a suite of learning algorithms that train agents in seconds to minutes:
- Baseline learning algorithms such as PPO, SAC, and evolutionary strategies.
- Experimental algorithms such as Variational GCRL, Adversarial Inverse RL, and State Marginal Matching.
- Learning algorithms that leverage the differentiability of the simulator, such as analytic policy gradients.
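To illustrate what "leveraging the differentiability of the simulator" means for analytic policy gradients, here is a minimal sketch under stated assumptions; it is not Brax's APG implementation. A hypothetical 1-D point mass must reach `TARGET`, the policy is a hypothetical linear feedback controller with parameters `theta`, and `jax.value_and_grad` backpropagates the rollout cost through every simulator step to get an exact gradient, which plain gradient descent then follows. All constants and function names are illustrative.

```python
import jax
import jax.numpy as jnp

DT = 0.05           # timestep in seconds (assumed)
HORIZON = 50        # steps per rollout (assumed)
TARGET = 1.0        # goal position for this toy task (assumed)
LEARNING_RATE = 0.1


def dynamics(state, force):
    """Unit-mass 1-D point mass; semi-implicit Euler update of [position, velocity]."""
    pos, vel = state[0], state[1]
    vel = vel + DT * force
    pos = pos + DT * vel
    return jnp.stack([pos, vel])


def policy(theta, state):
    """Linear feedback policy on (position error, velocity)."""
    features = jnp.stack([state[0] - TARGET, state[1]])
    return jnp.dot(theta, features)


def rollout_loss(theta):
    """Mean per-step cost of one rollout; differentiable end to end in theta."""
    def body(state, _):
        force = policy(theta, state)
        next_state = dynamics(state, force)
        cost = (next_state[0] - TARGET) ** 2 + 1e-3 * force ** 2
        return next_state, cost

    _, costs = jax.lax.scan(body, jnp.zeros(2), None, length=HORIZON)
    return jnp.mean(costs)


# Analytic gradient of the rollout cost w.r.t. the policy parameters,
# backpropagated through all HORIZON simulator steps.
loss_and_grad = jax.jit(jax.value_and_grad(rollout_loss))

theta = jnp.zeros(2)
initial_loss, _ = loss_and_grad(theta)
for _ in range(100):
    loss, grad = loss_and_grad(theta)
    theta = theta - LEARNING_RATE * grad
final_loss, _ = loss_and_grad(theta)
print(f"loss: {float(initial_loss):.3f} -> {float(final_loss):.3f}")
```

With `theta` at zero the mass never moves, so the initial loss is exactly 1.0; the gradient steps should lower it. Unlike PPO or SAC, no reward sampling or value function is involved: the gradient comes directly from differentiating the dynamics.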
Explore Brax easily and quickly through a series of Colab notebooks:
- Brax Basics introduces the Brax API, and shows how to simulate basic physics primitives.
- Brax Training introduces Brax environments and training algorithms, and lets you train your own policies directly within the Colab.
- Brax + PyTorch demonstrates how to use Brax environments efficiently from PyTorch.
To install Brax from source, clone this repo, `cd` to it, and then run:
```
python3 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install -e .
```
To train a model:
```
learn
```
Training on NVIDIA GPUs is supported, but you must first install CUDA, cuDNN, and JAX with GPU support.
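After installing CUDA, cuDNN, and a GPU-enabled JAX, a quick sanity check is to ask JAX which backend and devices it sees. This sketch uses only JAX's public `jax.default_backend()` and `jax.devices()`; it does not touch Brax itself.

```python
import jax

# The platform JAX will run on by default, typically "cpu", "gpu", or "tpu".
backend = jax.default_backend()
devices = jax.devices()

print(f"JAX default backend: {backend}")
for device in devices:
    # Each device reports its platform and a short description.
    print(device.platform, device)
```

If this prints `cpu` on a machine with an NVIDIA GPU, JAX was most likely installed without GPU support, and training will fall back to the CPU.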
For a deep dive into Brax's design and performance characteristics, please see our paper, Brax -- A Differentiable Physics Engine for Large Scale Rigid Body Simulation, to appear in the Datasets and Benchmarks Track at NeurIPS 2021.
If you would like to reference Brax in a publication, please use:
```
@software{brax2021github,
author = {C. Daniel Freeman and Erik Frey and Anton Raichuk and Sertan Girgin and Igor Mordatch and Olivier Bachem},
title = {Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation},
url = {http://github.com/google/brax},
version = {0.1.0},
year = {2021},
}
```