google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io



First accelerator takes much more memory than the other accelerators during parallel training

qsh-zh opened this issue

Provide as much information as possible. At least, this should include a description of your issue and steps to reproduce the problem. If possible also provide a summary of what steps or workarounds you have already tried.

Problem you have encountered:

The context is similar to the one in the tutorial. I want to use more than one accelerator (GPU or TPU) to train networks.
However, I found that the networks and weights are initialized on the first accelerator by default instead of being distributed equally across the accelerators.

What you expected to happen:

I hope that when I initialize with

p_train_state = flax.jax_utils.replicate(train_state)

the state is distributed equally across the accelerators instead of the whole p_train_state being placed on the first accelerator. I would expect the same for preparing datasets.

This would reduce memory consumption on the first accelerator and avoid out-of-memory (OOM) errors while the other accelerators still have enough free memory.
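For the dataset part, the usual pmap-based data parallelism in JAX already splits the data evenly: a global batch is reshaped so that its leading axis equals the number of local devices, and jax.pmap hands each device only its own slice. Parameters are treated differently: flax.jax_utils.replicate puts a full copy on every device rather than splitting them. A minimal sketch, assuming the global batch size is divisible by jax.local_device_count(); the helper shard_batch and the batch contents are hypothetical:

```python
import numpy as np
import jax


def shard_batch(batch):
    """Hypothetical helper: reshape every leaf from (global_batch, ...) to
    (num_devices, per_device_batch, ...) so pmap can map over the device axis."""
    n = jax.local_device_count()
    return jax.tree_util.tree_map(
        lambda x: x.reshape((n, -1) + x.shape[1:]), batch)


n = jax.local_device_count()
global_batch = {
    "image": np.zeros((8 * n, 28, 28, 1), np.float32),  # hypothetical inputs
    "label": np.arange(8 * n, dtype=np.int32),          # hypothetical labels
}
batch = shard_batch(global_batch)  # image shape: (num_devices, 8, 28, 28, 1)

# pmap maps over the leading axis: device i only receives slice i of each leaf,
# so each device computes on its own 8 examples.
per_device_label_sum = jax.pmap(lambda b: b["label"].sum())(batch)
print(per_device_label_sum)
```

With a real input pipeline, the same reshape would be applied to each batch before it is passed to the pmapped training step.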

Logs, error messages, etc:

It is not a bug. I used PyTorch before, and DDP in PyTorch has the feature I describe. There may be a good approach for working around the problem, but I cannot find one in the examples or tutorials.
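One possible workaround, assuming the extra memory on the first accelerator comes from the unreplicated train_state (JAX places new arrays on the first device by default, and that copy stays alive next to the replicas): run the initializer under jax.pmap with the same PRNG key on every device, so each device builds its own identical replica and no full extra copy is created on device 0. This is a sketch under that assumption; init_params and its shapes are hypothetical stand-ins for a model's init function:

```python
import numpy as np
import jax
import jax.numpy as jnp


def init_params(key):
    """Hypothetical stand-in for a model's init function."""
    return {"w": jax.random.normal(key, (512, 512)), "b": jnp.zeros((512,))}


n = jax.local_device_count()
key = jax.random.PRNGKey(0)
# Give every device the same key so all replicas are initialized identically.
keys = jnp.broadcast_to(key, (n,) + key.shape)

# Each device runs init_params on its own slice of `keys`, so the parameters are
# created directly on every device (leading axis = device) instead of being built
# on device 0 and then copied.
params = jax.pmap(init_params)(keys)
print(params["w"].shape)
```

A simpler alternative, if the state is already built, is to drop every reference to the unreplicated train_state after calling replicate (for example `del train_state`), which should let JAX free that copy on the first device.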

Steps to reproduce:

Whenever possible, please provide a minimal example. Please consider submitting it as a Colab link.

Do you want to shard your model weights (i.e., model parallelism), or replicate them (i.e., data parallelism) over your devices? The way you are describing it sounds like you want to do model parallelism, but DDP in PyTorch is for data parallelism. From the DDP docs:

"Distributed Data-Parallel Training (DDP) is a widely adopted single-program multiple-data training paradigm. With DDP, the model is replicated on every process, and every model replica will be fed with a different set of input data samples. DDP takes care of gradient communication to keep model replicas synchronized and overlaps it with the gradient computations to speed up training."
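In JAX, the data-parallel scheme described in that quote can be sketched with jax.pmap: each device holds a replica of the parameters, receives its own slice of the batch, and jax.lax.pmean averages the gradients across devices so the replicas stay synchronized. This is a minimal sketch, not Flax's own training loop; the linear model, loss_fn, the learning rate of 0.1, and the synthetic data are all hypothetical:

```python
import functools

import numpy as np
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Hypothetical linear model with a mean-squared-error loss."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, x, y):
    # Runs once per device on that device's replica of params and its batch slice.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # All-reduce: average gradients (and the loss) across devices, as DDP does.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    # Every replica applies the same averaged update, so replicas stay in sync.
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss


n = jax.local_device_count()
rng = np.random.default_rng(0)
init = {"w": np.zeros((3, 1), np.float32), "b": np.zeros((1,), np.float32)}
# Replicate: add a leading device axis holding one identical copy per device.
params = jax.tree_util.tree_map(lambda p: np.stack([p] * n), init)
# Shard the data: each device gets its own 16 examples.
x = rng.normal(size=(n, 16, 3)).astype(np.float32)
y = x @ np.array([[1.0], [2.0], [3.0]], np.float32)

for _ in range(100):
    params, loss = train_step(params, x, y)
print(loss)
```

One design difference from DDP: DDP usually runs one process per GPU, whereas pmap drives all local devices from a single Python process.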