Distributed training is stuck
syyxsxx opened this issue
Description
I use two 4090 hosts for data-parallel distributed training with jax.distributed, like this:
jax.distributed.initialize(coordinator_address="[ip]:[port]",
num_processes=2,
process_id=[index])
Training gets stuck when performing all_reduce ops.
How can I debug this problem?
Are there any examples of data-parallel distributed training?
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.23
jaxlib: 0.4.23
numpy: 1.26.3
python: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]
jax.devices (10 total, 10 local): [cuda(id=0) cuda(id=1) ... cuda(id=8) cuda(id=9)]
process_count: 1
$ nvidia-smi
Mon May 13 19:53:47 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 4090 Off | 00000000:45:00.0 Off | Off |
| 30% 28C P2 39W / 450W | 406MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA GeForce RTX 4090 Off | 00000000:46:00.0 Off | Off |
| 72% 61C P2 411W / 450W | 19697MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA GeForce RTX 4090 Off | 00000000:49:00.0 Off | Off |
| 78% 62C P2 418W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA GeForce RTX 4090 Off | 00000000:4E:00.0 Off | Off |
| 73% 61C P2 396W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 4 NVIDIA GeForce RTX 4090 Off | 00000000:4F:00.0 Off | Off |
| 71% 61C P2 407W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 5 NVIDIA GeForce RTX 4090 Off | 00000000:C5:00.0 Off | Off |
| 73% 61C P2 411W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 6 NVIDIA GeForce RTX 4090 Off | 00000000:C6:00.0 Off | Off |
| 80% 63C P2 416W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 7 NVIDIA GeForce RTX 4090 Off | 00000000:C9:00.0 Off | Off |
| 78% 62C P2 402W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 8 NVIDIA GeForce RTX 4090 Off | 00000000:CE:00.0 Off | Off |
| 73% 61C P2 382W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 9 NVIDIA GeForce RTX 4090 Off | 00000000:CF:00.0 Off | Off |
| 78% 62C P2 404W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
The latest jax and jaxlib versions are 0.4.28. Can you try them first? 0.4.23 is pretty old.
On GPU, it might also be worth trying a configuration that has one GPU per process. That may avoid a class of deadlocks in NVIDIA's NCCL library.
@hawkinsp Hi, how do I configure JAX to run one GPU per process? P.S. On a single host I can train on multiple GPUs with pmap, but with multiple hosts training gets stuck.
How did you launch the job? Are you using a cluster scheduler of some kind? If you're using one that JAX already integrates with (e.g., SLURM), we have code to handle this already, but perhaps you're not using one.
Basically you need to do two things:
a) Run one process per GPU, and arrange that it has visibility to only the GPU it is supposed to have. If you are running multiple processes on a single machine with multiple GPUs, you can limit which GPUs any given JAX process sees by setting JAX_CUDA_VISIBLE_DEVICES to 0, 1, ... for each process within a machine.
b) When you call jax.distributed.initialize to set up a distributed training job, set process_id and num_processes to reflect the fact you have one process per GPU (https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html).
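Steps (a) and (b) could be sketched roughly as follows. This is only an illustration under assumptions: process_layout and init_one_process_per_gpu are hypothetical helper names, every host is assumed to have the same number of GPUs, and the environment variable has to be set before jax is imported in that process.

```python
import os


def process_layout(host_index, local_gpu, gpus_per_host, num_hosts):
    """Return (process_id, num_processes) for a one-process-per-GPU job.

    Hypothetical helper: assumes every host has `gpus_per_host` GPUs and
    numbers processes host by host.
    """
    if not (0 <= host_index < num_hosts and 0 <= local_gpu < gpus_per_host):
        raise ValueError("host_index or local_gpu out of range")
    process_id = host_index * gpus_per_host + local_gpu
    num_processes = num_hosts * gpus_per_host
    return process_id, num_processes


def init_one_process_per_gpu(coordinator_address, host_index, local_gpu,
                             gpus_per_host, num_hosts):
    """Hypothetical wrapper: pin this process to one GPU, then initialize."""
    # Step (a): restrict this process to a single GPU before JAX starts.
    os.environ["JAX_CUDA_VISIBLE_DEVICES"] = str(local_gpu)

    # Imported here so the environment variable is already set.
    import jax

    # Step (b): process_id and num_processes count GPUs, not hosts.
    process_id, num_processes = process_layout(
        host_index, local_gpu, gpus_per_host, num_hosts)
    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=num_processes,
        process_id=process_id,
    )
    return jax
```

With two hosts of 10 GPUs each, this layout gives num_processes=20, and the process for GPU 3 on the second host gets process_id=13. Each of the 20 processes would be started separately with its own host_index and local_gpu.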
Does that answer the question?
@hawkinsp
Hi,
I launch the job by starting the processes manually on each machine. I select the GPUs with CUDA_VISIBLE_DEVICES, and I have also set process_id and num_processes. The code looks like this:
host0:
jax.distributed.initialize(coordinator_address="66.181.42.141:8889",
num_processes=2,
process_id=0)
host1:
jax.distributed.initialize(coordinator_address="66.181.42.141:8889",
num_processes=2,
process_id=1)
I have also tried launching distributed training on two A100 machines: single-machine training works, but distributed training gets stuck.
I am trying SLURM but have not succeeded yet. Is SLURM necessary for distributed JAX?
Please confirm you're using jax 0.4.28.
@syyxsxx Those warnings might mean compilation is slow, but shouldn't cause a deadlock.
If you're using pmap yourself explicitly, then one thing to make sure is that both processes are performing the same pmaps in the same order.
I think I'll need a reproduction of the problem to help further.
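A reproduction could be as small as one script that every process runs unchanged, so that all processes issue the same single pmap in the same order. The sketch below is an assumption about what such a script might look like, not something taken from this thread: topology_problems and main are hypothetical names, and the check only compares what JAX reports against what was requested.

```python
def topology_problems(process_count, device_count, local_device_count,
                      expected_processes):
    """Hypothetical sanity check: list mismatches in the reported topology."""
    problems = []
    if process_count != expected_processes:
        problems.append(
            f"process_count is {process_count}, expected {expected_processes}")
    if expected_processes > 1 and device_count <= local_device_count:
        problems.append(
            "no remote devices visible; processes may not have connected")
    return problems


def main(coordinator_address, num_processes, process_id):
    """Run on every process with the same num_processes and its own id."""
    import jax
    import jax.numpy as jnp

    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=num_processes,
        process_id=process_id,
    )

    # Report what this process sees before attempting any collective.
    problems = topology_problems(
        jax.process_count(), jax.device_count(),
        jax.local_device_count(), num_processes)
    print("process", jax.process_index(), "problems:", problems)

    # One pmap, identical on every process: a psum over all devices.
    x = jnp.ones((jax.local_device_count(),))
    out = jax.pmap(lambda v: jax.lax.psum(v, "i"), axis_name="i")(x)
    # Each entry should equal the global device count if the psum completes.
    print("psum result:", out, "global devices:", jax.device_count())
```

If the topology report already shows a mismatch, the hang is more likely a setup problem (addresses, ports, process ids) than a problem in the collective itself.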