lowrollr / turbozero

fast + parallel AlphaZero in JAX

Home Page: https://github.com/lowrollr/turbozero/wiki

turbozero 🏎️ 🏎️ 🏎️ 🏎️

📣 If you're looking for the old PyTorch version of turbozero, it's been moved here: turbozero_torch 📣

turbozero is a vectorized implementation of AlphaZero written in JAX

It contains:

  • Monte Carlo Tree Search with subtree persistence
  • Batched Replay Memory
  • A complete, customizable training/evaluation loop
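
As a rough illustration of what a batched replay memory can look like in JAX, here is a minimal sketch: one fixed-size ring buffer per parallel environment, all stored in a single array. The names `ReplayBuffer`, `init_buffer`, `add`, and `sample` are hypothetical and are not turbozero's API; see the repository for the actual implementation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ReplayBuffer(NamedTuple):
    """Hypothetical batched ring buffer: one independent buffer per environment."""
    data: jnp.ndarray      # (batch, capacity, feature_dim)
    next_idx: jnp.ndarray  # (batch,) next write position in each buffer
    size: jnp.ndarray      # (batch,) number of valid entries in each buffer


def init_buffer(batch: int, capacity: int, feature_dim: int) -> ReplayBuffer:
    return ReplayBuffer(
        data=jnp.zeros((batch, capacity, feature_dim)),
        next_idx=jnp.zeros((batch,), dtype=jnp.int32),
        size=jnp.zeros((batch,), dtype=jnp.int32),
    )


@jax.jit
def add(buffer: ReplayBuffer, items: jnp.ndarray) -> ReplayBuffer:
    """Write one item per environment, overwriting the oldest entry when full."""
    batch, capacity, _ = buffer.data.shape
    rows = jnp.arange(batch)
    data = buffer.data.at[rows, buffer.next_idx].set(items)
    return ReplayBuffer(
        data=data,
        next_idx=(buffer.next_idx + 1) % capacity,
        size=jnp.minimum(buffer.size + 1, capacity),
    )


def sample(buffer: ReplayBuffer, key, num_samples: int) -> jnp.ndarray:
    """Uniformly sample `num_samples` valid entries from each environment's buffer."""
    batch = buffer.data.shape[0]
    max_idx = jnp.maximum(buffer.size, 1)[:, None]  # avoid an empty range before any writes
    idx = jax.random.randint(key, (batch, num_samples), 0, max_idx)
    return buffer.data[jnp.arange(batch)[:, None], idx]


# Usage: 2 environments, capacity 3, 1 feature; the 4th write wraps around.
buf = init_buffer(batch=2, capacity=3, feature_dim=1)
for t in range(4):
    buf = add(buf, jnp.full((2, 1), float(t)))
samples = sample(buf, jax.random.PRNGKey(0), num_samples=5)
```

Keeping every buffer in one array means a write or a sample for all environments is a single indexed operation, which is what lets it sit inside a JIT-compiled loop.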

turbozero is fast and parallelized:

  • every consequential part of the training loop is JIT-compiled
  • partitions across multiple GPUs by default when available 🚀 NEW! 🚀
  • self-play and evaluation episodes are batched/vmapped with hardware acceleration in mind
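
To make the batching idea concrete, here is a small sketch of the general JAX pattern: write a transition function for a single environment, then lift it to a batch with `jax.vmap` and compile it with `jax.jit`. The toy `step` function below is a made-up example, not an environment or function from turbozero.

```python
import jax
import jax.numpy as jnp


def step(state: jnp.ndarray, action: jnp.ndarray):
    """Toy single-environment transition (hypothetical): walk along positions 0..9."""
    new_state = jnp.clip(state + action, 0, 9)
    reward = jnp.where(new_state == 9, 1.0, 0.0)  # reward only at the last position
    return new_state, reward


# vmap maps the single-environment function over a leading batch axis;
# jit compiles the batched function once for these shapes and dtypes.
batched_step = jax.jit(jax.vmap(step))

states = jnp.array([0, 8, 5])
actions = jnp.array([1, 1, -1])
new_states, rewards = batched_step(states, actions)
```

The single-environment function never mentions the batch dimension, so the same code serves one episode or thousands.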

turbozero is extendable:

turbozero is flexible:

  • Easy to integrate with your custom JAX environment or neural network architecture.
  • Use the provided training and evaluation utilities, or pick and choose the components that you need.

To get started, check out the Hello World Notebook

Installation

turbozero uses poetry for dependency management. You can install poetry with:

pip install poetry==1.7.1

Then, to install dependencies:

poetry install

If you're using a GPU/TPU/etc., after running the previous command you'll need to install the device-specific version of JAX.

For a GPU w/ CUDA 12:

poetry source add jax https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

to point poetry towards the JAX CUDA releases, then use

poetry add "jax[cuda12_pip]==0.4.24"

to install the CUDA 12 release of JAX (the quotes keep shells such as zsh from interpreting the square brackets). See https://jax.readthedocs.io/en/latest/installation.html for other devices/CUDA versions.

I have tested this project with CUDA 11 and CUDA 12.
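
One way to check that the device-specific install worked is to ask JAX which backend and devices it sees, then run a tiny compiled computation. This is a generic sanity check using standard JAX calls, not a script shipped with turbozero; run it inside the poetry environment (for example via `poetry run python`).

```python
import jax
import jax.numpy as jnp

# Devices visible to JAX; with a working CUDA install these should be GPU devices.
devices = jax.devices()
backend = jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"
print(f"backend={backend}, devices={devices}")

# A tiny jitted computation confirms that compilation and execution work end to end.
result = jax.jit(lambda x: (x * 2).sum())(jnp.arange(4))
print(result)
```

If the backend is reported as `cpu` on a machine with a GPU, the CUDA build of JAX was most likely not picked up.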

To register an IPython kernel for this environment (so it can be selected in Jupyter), run:

poetry run python -m ipykernel install --user --name turbozero

Issues

If you use this project and encounter an issue, error, or undesired behavior, please submit a GitHub Issue and I will do my best to resolve it as soon as I can. You may also contact me directly via hello@jacob.land.

Contributing

Contributions, improvements, and fixes are more than welcome! For now I don't have a formal process for this, other than creating a Pull Request. For large changes, consider creating an Issue beforehand.

If you are interested in contributing but don't know what to work on, please reach out. I have plenty of things you could do.

Cite This Work

If you found this work useful, please cite it with:

@software{turbozero,
  author = {Marshall, Jacob},
  title = {{turbozero: fast + parallel AlphaZero}},
  url = {https://github.com/lowrollr/turbozero}
}

License: Apache License 2.0

