Alpa

Documentation | Slack


Alpa is a system for training large-scale neural networks. Scaling neural networks to hundreds of billions of parameters has enabled dramatic breakthroughs such as GPT-3, but training these large-scale neural networks requires complicated distributed training techniques. Alpa aims to automate large-scale distributed training with just a few lines of code.

The key features of Alpa include:

💻 Automatic Parallelization. Alpa automatically parallelizes users' single-device code on distributed clusters with data, operator, and pipeline parallelism.

🚀 Excellent Performance. Alpa achieves linear scaling when training models with billions of parameters on distributed clusters.

✨ Tight Integration with the Machine Learning Ecosystem. Alpa is backed by open-source, high-performance, and production-ready libraries such as Jax, XLA, and Ray.

Quick Start

Use Alpa's @alpa.parallelize decorator to scale your single-device training code to distributed clusters.

import alpa
import jax.numpy as jnp
from jax import grad

# Parallelize the training step in Jax by simply using a decorator
@alpa.parallelize
def train_step(model_state, batch):
    def loss_func(params):
        out = model_state.forward(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    grads = grad(loss_func)(model_state.params)
    new_model_state = model_state.apply_gradient(grads)
    return new_model_state

# The training loop now automatically runs on your designated cluster
model_state = create_train_state()
for batch in data_loader:
    model_state = train_step(model_state, batch)
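In the quick-start example above, Alpa picks the parallelization strategy automatically. The sketch below illustrates how a Ray cluster can be attached and a parallelization method selected explicitly. The names used (alpa.init, alpa.ShardParallel, alpa.PipeshardParallel, num_micro_batches) follow Alpa's documented API, but treat this as a minimal sketch and consult the documentation for the exact signatures and options.

# A minimal sketch of choosing the parallelization method explicitly.
# Assumes a Ray cluster is already running.
import alpa

# Attach Alpa to the Ray cluster before executing distributed code.
alpa.init(cluster="ray")

# Intra-operator (data/operator) parallelism only:
@alpa.parallelize(method=alpa.ShardParallel())
def train_step_shard(model_state, batch):
    ...  # same body as train_step above

# Combine inter-operator (pipeline) and intra-operator parallelism;
# num_micro_batches controls pipeline micro-batching.
@alpa.parallelize(method=alpa.PipeshardParallel(num_micro_batches=16))
def train_step_pipeline(model_state, batch):
    ...  # same body as train_step above

Shard-only parallelism is typically the simpler starting point; pipeshard parallelism targets models too large to fit on a single device.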

Check out the Alpa Documentation site for installation instructions, tutorials, examples, and more.

More Information

Getting Involved

  • Please read the contributor guide if you are interested in contributing to Alpa.
  • Please connect with Alpa contributors via the Alpa Slack.

License

Alpa is licensed under the Apache-2.0 license.
