sbhavani / JAX-Toolbox

JAX Toolbox

| | Image | Build | Tests |
|---|---|---|---|
| | container-badge-base | build-badge-base | n/a |
| Frameworks | container-badge-jax | build-badge-jax | test-badge-jax-V100, test-badge-jax-A100 |
| | container-badge-t5x | build-badge-t5x | test-badge-t5x |
| | container-badge-pax | build-badge-pax | test-badge-pax |
| | container-badge-te | build-badge-te | unit-test-badge-te, integration-test-badge-te |
| Rosetta | container-badge-rosetta-t5x | build-badge-rosetta-t5x | test-badge-rosetta-t5x |
| | container-badge-rosetta-pax | build-badge-rosetta-pax | test-badge-rosetta-pax |

Note

This repo currently hosts a public CI for JAX on NVIDIA GPUs. It covers several JAX libraries, including T5X, PAXML, and Transformer Engine, with more to come.

Supported Models

We currently enable training and evaluation for the following models:

| Model Name | Pretraining | Fine-tuning | Evaluation |
|---|---|---|---|
| GPT-3 (paxml) | ✔️ | ✔️ | ✔️ |
| T5 (t5x) | ✔️ | ✔️ | ✔️ |
| ViT | ✔️ | ✔️ | ✔️ |

We will update this table as new models become available, so stay tuned.

Environment Variables

The JAX image ships with the following XLA flags and environment variables preset for performance tuning:

| XLA Flags | Value | Explanation |
|---|---|---|
| --xla_gpu_enable_latency_hiding_scheduler | true | allows XLA to move communication collectives to increase overlap with compute kernels |
| --xla_gpu_enable_async_all_gather | true | allows XLA to run NCCL AllGather kernels on a separate CUDA stream to allow overlap with compute kernels |
| --xla_gpu_enable_async_reduce_scatter | true | allows XLA to run NCCL ReduceScatter kernels on a separate CUDA stream to allow overlap with compute kernels |
| --xla_gpu_enable_triton_gemm | false | use cuBLAS instead of Triton GEMM kernels |

| Environment Variable | Value | Explanation |
|---|---|---|
| CUDA_DEVICE_MAX_CONNECTIONS | 1 | use a single queue for GPU work to lower latency of stream operations; OK since XLA already orders launches |
| NCCL_IB_SL | 1 | defines the InfiniBand Service Level (1) |
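
If you want to reproduce these settings outside the container, or confirm what a job will see, the values listed above can be assembled into `XLA_FLAGS` and the process environment before JAX is imported. The following is a minimal sketch, not part of this repo: the helper names `build_xla_flags` and `apply_defaults` are hypothetical, and whether a given XLA build still accepts each flag depends on its version.

```python
import os

# Values copied from the flag and environment variable tables in this README.
XLA_FLAG_DEFAULTS = {
    "xla_gpu_enable_latency_hiding_scheduler": "true",
    "xla_gpu_enable_async_all_gather": "true",
    "xla_gpu_enable_async_reduce_scatter": "true",
    "xla_gpu_enable_triton_gemm": "false",
}
ENV_DEFAULTS = {
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "NCCL_IB_SL": "1",
}


def build_xla_flags(flags):
    """Join a {name: value} mapping into a single XLA_FLAGS string."""
    return " ".join(f"--{name}={value}" for name, value in flags.items())


def apply_defaults(env=os.environ):
    """Fill in the documented defaults without overriding existing values."""
    # setdefault leaves anything the user already exported untouched.
    env.setdefault("XLA_FLAGS", build_xla_flags(XLA_FLAG_DEFAULTS))
    for name, value in ENV_DEFAULTS.items():
        env.setdefault(name, value)
    return env
```

Call `apply_defaults()` before `import jax` so the settings are in place when the GPU backend initializes. Because the sketch uses `setdefault`, an `XLA_FLAGS` value that is already exported is kept as-is rather than merged.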

FAQ (Frequently Asked Questions)

Question: "bus error" when running JAX code

Q: When I run my JAX code, it fails with a bus error. How can I fix this?

A: A bus error can occur when /dev/shm is too small. Increase the shared memory size with the --shm-size option when launching your container. For example, with Docker:

docker run -it --shm-size=1g ...
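
To check how much shared memory a running container actually has, you can inspect /dev/shm from Python. This is a minimal sketch using only the standard library; the function name `shm_report` is hypothetical, and the 1 GiB threshold is an assumption taken from the `--shm-size=1g` example above, not a documented requirement.

```python
import shutil


def shm_report(path="/dev/shm", min_bytes=1 << 30):
    """Report the size of the shared-memory mount and whether it meets min_bytes."""
    usage = shutil.disk_usage(path)
    return {
        "total_bytes": usage.total,
        "free_bytes": usage.free,
        # True when the mount is at least as large as the chosen threshold.
        "ok": usage.total >= min_bytes,
    }
```

Running `shm_report()` inside the container before launching a job shows whether the `--shm-size` value took effect. Adjust `min_bytes` to suit your workload.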

JAX on Public Clouds

Resources

License: Apache License 2.0