FluxMPI.jl

Caution

This package should be considered deprecated and won't receive any updates. Distributed training will become a native feature of Lux, so it makes little sense for me to maintain an additional package that does the same thing. Track LuxDL/Lux.jl#494 for further updates.

Distributed Data Parallel Training of Neural Networks

Installation

Stable release:

] add FluxMPI

Latest development version:

] add FluxMPI#main

Quick Start

using CUDA, FluxMPI, Lux, Optimisers, Random, Zygote

FluxMPI.Init()
CUDA.allowscalar(false)

model = Chain(Dense(1 => 256, tanh), Dense(256 => 512, tanh), Dense(512 => 256, tanh),
              Dense(256 => 1))
rng = Random.default_rng()
Random.seed!(rng, local_rank())
ps, st = Lux.setup(rng, model) .|> gpu

# Synchronize parameters and states across processes, using rank 0 as the source
ps = FluxMPI.synchronize!(ps; root_rank = 0)
st = FluxMPI.synchronize!(st; root_rank = 0)

x = rand(rng, 1, 16) |> gpu
y = x .^ 2

opt = DistributedOptimizer(Adam(0.001f0))
st_opt = Optimisers.setup(opt, ps)

loss(p) = sum(abs2, model(x, p, st)[1] .- y)

st_opt = FluxMPI.synchronize!(st_opt; root_rank = 0)

# Warm-up step so that compilation does not pollute the timing below
gs_ = gradient(loss, ps)[1]
Optimisers.update(st_opt, ps, gs_)

t1 = time()

for epoch in 1:100
  global ps, st_opt
  l, back = Zygote.pullback(loss, ps)
  FluxMPI.fluxmpi_println("Epoch $epoch: Loss $l")
  gs = back(one(l))[1]
  st_opt, ps = Optimisers.update(st_opt, ps, gs)
end

FluxMPI.fluxmpi_println(time() - t1)

Run the code using mpiexecjl -n 3 julia --project=. <filename>.jl.

Examples

Style Guide

We follow the Lux Style Guide. All contributions must adhere to this style guide.

Changelog

v0.7

  • Dropped support for MPI v0.19.
  • FLUXMPI_DISABLE_CUDAMPI_SUPPORT is no longer used. Instead, use FluxMPI.disable_cudampi_support() to set up a LocalPreferences.toml file (see the sketch below).
  • clean_(print/println) functions are now fluxmpi_(print/println).
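
A minimal sketch of the renamed API, based solely on the entries above (argument handling of fluxmpi_println is assumed to mirror println):

using FluxMPI

# One-time setup: persists the preference in LocalPreferences.toml (replaces the old
# FLUXMPI_DISABLE_CUDAMPI_SUPPORT environment variable). As with Preferences.jl-based
# switches in general, a Julia restart is typically needed before it takes effect.
FluxMPI.disable_cudampi_support()

# Renamed printing helpers (formerly clean_print / clean_println)
FluxMPI.Init()
FluxMPI.fluxmpi_println("Hello from rank ", local_rank())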

v0.6

  • Dropped support for LearnBase.jl (and hence DataLoaders.jl). DistributedDataContainer is now only compatible with MLUtils.jl (see the sketch below).
  • DistributedOptimiser name changed to DistributedOptimizer.
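
A hedged sketch of pairing DistributedDataContainer with MLUtils.jl; the toy data and the DataLoader keywords are illustrative, not taken from the package docs:

using FluxMPI, MLUtils

FluxMPI.Init()

data = randn(Float32, 1, 1024)            # toy dataset: each column is one sample
ddata = DistributedDataContainer(data)    # each rank only sees its own shard of the data
loader = MLUtils.DataLoader(ddata; batchsize = 32, shuffle = true)

for x in loader
    # x is a batch drawn from this rank's shard only
end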

v0.5

v0.5.3

  • Introduces a new API for gradient synchronization
    • Don't wrap the optimizer in DistributedOptimiser
    • Instead, add a call to allreduce_gradients(gs::NamedTuple) after computing the gradients (see the sketch below)
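
A sketch contrasting the two styles, reusing ps, loss, and the Zygote/Optimisers setup from the Quick Start above; whether allreduce_gradients returns the reduced gradients (as assumed here) and whether they are summed or averaged is not specified by the entry above:

# Old style: wrap the optimizer so synchronization happens inside the update
opt = DistributedOptimizer(Adam(0.001f0))
st_opt = Optimisers.setup(opt, ps)

# New style (v0.5.3+): plain optimizer plus an explicit gradient synchronization step
opt = Adam(0.001f0)
st_opt = Optimisers.setup(opt, ps)
gs = Zygote.gradient(loss, ps)[1]
gs = allreduce_gradients(gs)   # sketch assumes the reduced gradients are returned
st_opt, ps = Optimisers.update(st_opt, ps, gs)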

v0.5.1

  • Internal MPIExtensions functions renamed
    • Allreduce! --> allreduce!
    • Bcast! --> bcast!
    • Reduce! --> reduce!
  • CUDA-unaware MPI bug resolved LuxDL/Lux.jl#18
  • Disable CUDA-aware MPI support from FluxMPI using FLUXMPI_DISABLE_CUDAMPI_SUPPORT=true
  • Temporarily re-added dependencies on MLDataUtils and LearnBase to ensure DataLoaders.jl still works -- This will be dropped in a future release

v0.5.0

  • DistributedOptimiser no longer averages the gradients. Instead, the values are summed across the processes. To recover averaging, divide the loss by total_workers() (see the sketch below).
  • rrules and frules are defined for local_rank() and total_workers() -- they can now be used safely inside loss functions.
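
For example, scaling the Quick Start loss by the number of workers recovers averaged gradients under the summed behavior (total_workers() is safe inside the loss thanks to the rrule/frule above):

loss(p) = sum(abs2, model(x, p, st)[1] .- y) / total_workers()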

v0.4

  • fluxmpi_print and fluxmpi_println print the current time even if FluxMPI has not been initialized.
  • Calling local_rank or total_workers before FluxMPI.Init no longer leads to a segfault; an error is thrown instead.
  • MLDataUtils and LearnBase dependencies have been dropped (See #17)
  • Zygote and Flux dependencies have been removed
    • FluxMPI.synchronize! no longer has a dispatch for Zygote.Params. Instead, users should manually broadcast the function over the Zygote.Params (see the sketch below).
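
A hypothetical sketch of the manual approach; it assumes FluxMPI.synchronize! accepts and updates individual parameter arrays in place, which the entry above does not spell out:

# Hypothetical: synchronize each parameter array of a Zygote.Params collection in a loop
for p in ps                           # ps::Zygote.Params
    FluxMPI.synchronize!(p; root_rank = 0)
end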

v0.3

  • broadcast_parameters has been renamed to FluxMPI.synchronize! since it now synchronizes much more than just the trainable parameters.
  • DistributedOptimiser is no longer tied to Flux. It can be used with essentially any training loop that is compatible with Optimisers.jl.

License

MIT License

