google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Distributed training is stuck

syyxsxx opened this issue

Description

I use two hosts with RTX 4090 GPUs for data-parallel distributed training with jax.distributed, like this:

jax.distributed.initialize(coordinator_address="[ip]:[port]",
                           num_processes=2,
                           process_id=[index])

The training gets stuck when doing all_reduce ops.
How can I debug this problem?
Are there any examples of parallel distributed training?
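A quick first check for this kind of hang is to confirm that every process actually joined the job before any collective runs. The sketch below uses a hypothetical helper, report_topology (not a JAX API), that prints what the current process sees and fails fast if the process count is wrong; it assumes it is called right after jax.distributed.initialize(). Setting NCCL_DEBUG=INFO in the environment additionally makes NCCL log how it sets up its communicators.

```python
import jax


def report_topology(expected_process_count: int) -> dict:
    """Print this process's view of the job and fail fast if it is wrong.

    Hypothetical helper; call it right after jax.distributed.initialize().
    """
    info = {
        "process_index": jax.process_index(),
        "process_count": jax.process_count(),
        "local_device_count": jax.local_device_count(),
        "device_count": jax.device_count(),
    }
    # flush=True so the line is visible even if a later collective hangs.
    print(info, flush=True)
    if info["process_count"] != expected_process_count:
        raise RuntimeError(
            f"expected {expected_process_count} processes, "
            f"but JAX reports {info['process_count']}"
        )
    return info
```

With the two-host setup above, each host would call report_topology(2); a reported process_count of 1 means jax.distributed.initialize() did not take effect in that process.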

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.23
jaxlib: 0.4.23
numpy: 1.26.3
python: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]
jax.devices (10 total, 10 local): [cuda(id=0) cuda(id=1) ... cuda(id=8) cuda(id=9)]
process_count: 1

$ nvidia-smi
Mon May 13 19:53:47 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 4090 Off | 00000000:45:00.0 Off | Off |
| 30% 28C P2 39W / 450W | 406MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA GeForce RTX 4090 Off | 00000000:46:00.0 Off | Off |
| 72% 61C P2 411W / 450W | 19697MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA GeForce RTX 4090 Off | 00000000:49:00.0 Off | Off |
| 78% 62C P2 418W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA GeForce RTX 4090 Off | 00000000:4E:00.0 Off | Off |
| 73% 61C P2 396W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 4 NVIDIA GeForce RTX 4090 Off | 00000000:4F:00.0 Off | Off |
| 71% 61C P2 407W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 5 NVIDIA GeForce RTX 4090 Off | 00000000:C5:00.0 Off | Off |
| 73% 61C P2 411W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 6 NVIDIA GeForce RTX 4090 Off | 00000000:C6:00.0 Off | Off |
| 80% 63C P2 416W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 7 NVIDIA GeForce RTX 4090 Off | 00000000:C9:00.0 Off | Off |
| 78% 62C P2 402W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 8 NVIDIA GeForce RTX 4090 Off | 00000000:CE:00.0 Off | Off |
| 73% 61C P2 382W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 9 NVIDIA GeForce RTX 4090 Off | 00000000:CF:00.0 Off | Off |
| 78% 62C P2 404W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+

The latest jax and jaxlib versions are 0.4.28; can you try them first? 0.4.23 is pretty old.

On GPU, it might also be worth trying a configuration that has one GPU per process. That may avoid a class of deadlocks in NVIDIA's NCCL library.

@hawkinsp Hi, how do I configure JAX to use one GPU per process? P.S. On a single host I can train on multiple GPUs with pmap, but across multiple hosts the training gets stuck.

How did you launch the job? Are you using a cluster scheduler of some kind? If you're using one that JAX already integrates with (e.g., SLURM), we already have code to handle this, but perhaps you're not using one.

Basically you need to do two things:
a) Run one process per GPU, and arrange that each process can see only the GPU it is supposed to use. If you are running multiple processes on a single machine with multiple GPUs, you can limit which GPUs a given JAX process sees by setting JAX_CUDA_VISIBLE_DEVICES to 0, 1, ... for each process on that machine.

b) When you call jax.distributed.initialize to set up a distributed training job, set process_id and num_processes to reflect the fact that you have one process per GPU
(https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html).
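Putting (a) and (b) together, here is a minimal sketch for one process per GPU, assuming every host has the same number of GPUs and that processes are numbered host by host. The helper names one_process_per_gpu and init_one_process_per_gpu are hypothetical. It passes local_device_ids, the jax.distributed.initialize parameter that restricts a process to the listed local GPUs, as an alternative to setting JAX_CUDA_VISIBLE_DEVICES by hand.

```python
def one_process_per_gpu(host_index: int, local_rank: int,
                        num_hosts: int, gpus_per_host: int) -> tuple[int, int]:
    """Return (process_id, num_processes) when every GPU gets its own process.

    Host 0 owns process ids 0..gpus_per_host-1, host 1 owns the next
    gpus_per_host ids, and so on.
    """
    if not 0 <= host_index < num_hosts:
        raise ValueError(f"host_index {host_index} out of range [0, {num_hosts})")
    if not 0 <= local_rank < gpus_per_host:
        raise ValueError(f"local_rank {local_rank} out of range [0, {gpus_per_host})")
    return host_index * gpus_per_host + local_rank, num_hosts * gpus_per_host


def init_one_process_per_gpu(coordinator_address: str, host_index: int,
                             local_rank: int, num_hosts: int,
                             gpus_per_host: int) -> None:
    """Join the distributed job as the process that owns GPU `local_rank`."""
    import jax  # deferred so the layout helper above has no JAX dependency

    process_id, num_processes = one_process_per_gpu(
        host_index, local_rank, num_hosts, gpus_per_host)
    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=num_processes,
        process_id=process_id,
        # Make only this one local GPU visible to the process.
        local_device_ids=[local_rank],
    )
```

For two hosts with 10 GPUs each, as in the nvidia-smi output above, this means launching 10 processes per host (local_rank 0 to 9) with num_hosts=2 and gpus_per_host=10, for 20 processes in total; each process should then see jax.local_device_count() == 1.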

Does that answer the question?

@hawkinsp Hi, I launch the job by starting a process manually on each machine. I select the GPUs with CUDA_VISIBLE_DEVICES, and I have also set process_id and num_processes. The code looks like this:
host0:

jax.distributed.initialize(coordinator_address="66.181.42.141:8889",
                           num_processes=2,
                           process_id=0)

host1:

jax.distributed.initialize(coordinator_address="66.181.42.141:8889",
                           num_processes=2,
                           process_id=1)
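Passing the arguments explicitly like this does not require a cluster scheduler. To keep a single script for both hosts, one option is to read the arguments from environment variables set at launch. A minimal sketch; the variable names TRAIN_COORDINATOR, TRAIN_NUM_PROCESSES and TRAIN_PROCESS_ID and the helper distributed_init_kwargs are hypothetical choices made for this example.

```python
import os
from typing import Mapping, Optional


def distributed_init_kwargs(env: Optional[Mapping[str, str]] = None) -> dict:
    """Build jax.distributed.initialize() keyword arguments from the environment.

    The TRAIN_* variable names are hypothetical; use whatever your launch
    scripts export.
    """
    env = os.environ if env is None else env
    num_processes = int(env["TRAIN_NUM_PROCESSES"])
    process_id = int(env["TRAIN_PROCESS_ID"])
    # Catch a mistyped id early instead of waiting on a process that never joins.
    if not 0 <= process_id < num_processes:
        raise ValueError(
            f"TRAIN_PROCESS_ID={process_id} must be in [0, {num_processes})")
    return {
        "coordinator_address": env["TRAIN_COORDINATOR"],
        "num_processes": num_processes,
        "process_id": process_id,
    }
```

Host 1 could then be started with TRAIN_COORDINATOR=66.181.42.141:8889 TRAIN_NUM_PROCESSES=2 TRAIN_PROCESS_ID=1 python train.py, where the (hypothetical) train.py calls jax.distributed.initialize(**distributed_init_kwargs()). On SLURM, jax.distributed.initialize() can instead be called with no arguments, because JAX detects these values from the SLURM environment.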

I have also tried launching the distributed training on two A100 machines: single-machine training works, but distributed training gets stuck.
[Screenshot of the log output showing warnings]

I am trying SLURM but have not gotten it working yet. Is SLURM necessary for distributed JAX?

Please confirm you're using jax 0.4.28.

@syyxsxx Those warnings might mean compilation is slow, but shouldn't cause a deadlock.

If you're using pmap yourself explicitly, make sure that both processes perform the same pmaps in the same order.
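A pmap that contains a collective, such as the all-reduce in the original report, only completes once every participating device on every process reaches it, so all processes must run the same pmapped functions with the same input shapes in the same order. Below is a minimal smoke test in that style, assuming every process has the same number of local GPUs; global_mean and all_reduce_smoke_test are hypothetical names. Each device contributes its global index, so on a healthy job every output entry should equal (jax.device_count() - 1) / 2.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.pmap, axis_name="devices")
def global_mean(x):
    # All-reduce (mean) across every device of every process.
    return jax.lax.pmean(x, axis_name="devices")


def all_reduce_smoke_test():
    """Run one all-reduce; every process must call this at the same point."""
    n_local = jax.local_device_count()
    # Assumes every process has the same number of local devices, so the
    # devices of process p hold the global indices p*n_local .. p*n_local+n_local-1.
    first = jax.process_index() * n_local
    x = jnp.arange(first, first + n_local, dtype=jnp.float32)
    return global_mean(x)
```

Calling all_reduce_smoke_test() on every process right after jax.distributed.initialize() can help separate a communication problem from a problem in the training loop itself.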

I think I'll need a reproduction of the problem to help further.

@hawkinsp Hi, I am using jax==0.4.23. I tried jax==0.4.28, but got some errors; I will try jax==0.4.28 again later.
The pmap code is:

jax.pmap(jax.random.PRNGKey)(jnp.arange(jax.local_device_count())).block_until_ready()