ekzhang / archax

archax

Experiments in multi-architecture parallelism for deep learning with JAX.

Figure: an example JAX computation graph.
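For reference, a computation graph like the one pictured can be inspected directly from JAX. The sketch below uses jax.make_jaxpr to print the intermediate representation of a small, made-up function; the function f and its shapes are illustrative, not part of this repository.

```python
import jax
import jax.numpy as jnp

def f(x, w):
    # A small example computation; its graph (jaxpr) is what JAX's
    # compiler sees and what a parallelism planner would partition.
    return jnp.sum(jnp.tanh(x @ w))

x = jnp.ones((4, 8))
w = jnp.ones((8, 2))

# Print JAX's intermediate representation of the traced computation.
print(jax.make_jaxpr(f)(x, w))
```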

What if we could create a new kind of multi-architecture parallelism library for deep learning compilers, supporting expressive frontends like JAX? Such a library would optimize a mix of pipeline and operator parallelism across accelerated devices: use CPUs, GPUs, and/or TPUs in the same program, and automatically interleave computation between them.
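A minimal sketch of what mixing device types in one JAX program could look like, using only public JAX APIs (jax.devices, jax.device_put, jax.jit). The two-stage pipeline split and device choices here are illustrative assumptions, not code from this repository.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
try:
    accel = jax.devices("gpu")[0]   # or jax.devices("tpu")[0]
except RuntimeError:
    accel = cpu                     # fall back if no accelerator is present

@jax.jit
def stage1(x):
    # Illustrative first pipeline stage, pinned to the CPU via its input.
    return jnp.tanh(x @ x.T)

@jax.jit
def stage2(h):
    # Illustrative second stage, intended for the accelerator.
    return jnp.sum(jnp.exp(h), axis=-1)

x = jax.device_put(jnp.ones((128, 128)), cpu)  # commit the input to the CPU
h = stage1(x)                                  # runs on the CPU
h = jax.device_put(h, accel)                   # hand off between architectures
y = stage2(h)                                  # runs on the GPU/TPU if available
print(y.shape)
```

A library like the one described above would make this kind of interleaving automatic rather than hand-written.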

The experiments in this repository are dated and annotated with brief descriptions.

License

All code and notebooks in this repository are distributed under the terms of the MIT license.

Languages

Python 79.3%, Jupyter Notebook 20.7%