fairscale


Description

fairscale is a PyTorch extension library for high-performance, large-scale training. It optimizes training on a single machine or across multiple machines/nodes, and extends basic PyTorch capabilities while adding new experimental ones.

fairscale supports:

  • Parallelism:
    • pipeline parallelism (fairscale.nn.Pipe)
    • tensor parallelism (fairscale.nn.model_parallel); see the sketch after this list
  • Optimization:
    • optimizer state sharding (fairscale.optim.oss)
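
Pipe and optimizer state sharding are shown under Examples below. For tensor parallelism, here is a minimal sketch; it assumes the Megatron-style entry points kept in the fairscale.nn.model_parallel fork (initialize_model_parallel, ColumnParallelLinear, RowParallelLinear), one process per GPU, and hypothetical layer sizes.

import torch
import torch.distributed as dist
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear

# torch.distributed must be initialized before the model-parallel groups are created
dist.init_process_group(backend="nccl", init_method="env://")
initialize_model_parallel(2)  # shard each parallel layer across 2 GPUs

# Megatron-style MLP block: the column-parallel layer keeps its output sharded,
# the row-parallel layer consumes the shard and all-reduces the final output.
model = torch.nn.Sequential(
    ColumnParallelLinear(1024, 4096, gather_output=False),
    torch.nn.ReLU(),
    RowParallelLinear(4096, 1024, input_is_parallel=True),
).cuda()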

Requirements

  • PyTorch >= 1.5.1

Installation

Normal installation:

pip install fairscale

Development mode:

cd fairscale
pip install -r requirements.txt
pip install -e .

Getting Started

The full documentation (https://fairscale.readthedocs.io/) contains instructions for getting started and extending fairscale.

Examples

Pipe

Run a 4-layer model on 2 GPUs. The first two layers run on cuda:0 and the next two layers run on cuda:1.

import torch
import torch.nn as nn

import fairscale

# Four example layers (sizes are arbitrary placeholders)
a, b, c, d = nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10)

model = nn.Sequential(a, b, c, d)
model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)
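
As a usage sketch for the example above (hypothetical shapes matching the placeholder layers): inputs must live on the first device, and the output is returned on the last one, so the loss is computed on cuda:1.

# Hypothetical forward/backward pass: the batch goes to the first partition's
# device, the output (and therefore the loss) lives on the last device.
x = torch.randn(64, 10).cuda(0)
target = torch.randn(64, 10).cuda(1)
output = model(x)  # returned on cuda:1
loss = torch.nn.functional.mse_loss(output, target)
loss.backward()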

Optimizer state sharding (ZeRO)

See a more complete example here; a minimal example could look like the following:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

def train(
    rank: int,
    world_size: int,
    epochs: int):

    # DDP init example
    dist.init_process_group(backend='nccl', init_method="tcp://localhost:29501", rank=rank, world_size=world_size)

    # Problem statement
    model = myAwesomeModel().to(rank)
    dataloader = mySuperFastDataloader()
    loss_fn = myVeryRelevantLoss()
    base_optimizer = torch.optim.SGD # pick any pytorch compliant optimizer here
    base_optimizer_arguments = {} # pass any optimizer specific arguments here, or directly below when instantiating OSS

    # Wrap the optimizer in OSS so that its state is sharded across the ranks
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

    # ShardedDDP takes the sharded optimizer and handles the gradient reduction automatically
    model = ShardedDDP(model, optimizer)

    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    for e in range(epochs):
        for batch in dataloader:
            # Train
            model.zero_grad()
            outputs = model(batch["inputs"])
            loss = loss_fn(outputs, batch["label"])
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    # Supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
    mp.spawn(
        train,
        args=(
            WORLD_SIZE,
            EPOCHS,
        ),
        nprocs=WORLD_SIZE,
        join=True,
    )
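
Because the optimizer state is sharded, each rank only holds its own shard. To save a full checkpoint, the state is typically consolidated on a single rank first. A minimal sketch, meant to go at the end of train() above, assuming OSS's consolidate_state_dict helper:

# Gather the sharded optimizer state on rank 0 (assumed helper on fairscale's OSS),
# then save model and optimizer state from that rank only.
optimizer.consolidate_state_dict(recipient_rank=0)
if rank == 0:
    torch.save({"model": model.state_dict(), "optim": optimizer.state_dict()}, "checkpoint.pt")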

Testing

We use CircleCI to test against PyTorch versions 1.5.1, 1.6.0, and 1.7.0 with CUDA 10.1. Please create an issue if you are having trouble with installation.

Contributors

See the CONTRIBUTING file for how to help out.

License

fairscale is licensed under the BSD-3-Clause License.

fairscale.nn.pipe is forked from torchgpipe, Copyright 2019, Kakao Brain, licensed under the Apache License.

fairscale.nn.model_parallel is forked from Megatron-LM, Copyright 2020, NVIDIA CORPORATION, licensed under the Apache License.

fairscale.optim.adascale is forked from AdaptDL, Copyright 2020, Petuum, Inc., licensed under the Apache License.

References

Here is a list of all authors on relevant research papers this work is based on:

  • torchgpipe: Chiheon Kim, Heungsub Lee, Myungryong Jeong, Woonhyuk Baek, Boogeon Yoon, Ildoo Kim, Sungbin Lim, Sungwoong Kim. [Paper] [Code]
  • ZeRO: Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He. [Paper] [Code]
  • Megatron-LM: Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper, Bryan Catanzaro. [Paper] [Code]
  • AdaScale SGD: Tyler B. Johnson, Pulkit Agrawal, Haijie Gu, Carlos Guestrin. [Paper]
