google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

NCCL error when trying xmap on 4 T4 GPUs

ranzenTom opened this issue · comments

Hi,
First, thank you for making this library available; I love it!

I tried to reproduce the example from the xmap documentation (https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) on a VM with 4 T4 GPUs.

I am using conda, and the environment requirements are:

name: jax-env
channels:
  - defaults
  - conda-forge
dependencies:
  - python=3.8
  - numpy==1.19.5
  - pandas==1.2.5
  - six==1.15.0
  - pip=21.0.1
  - pip:
      - --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
      - jax==0.3.2
      - optax==0.0.9
      - dm-haiku==0.0.6
      - regex==2022.1.18
      - tokenizers==0.11.5
      - tensorflow==2.6.0
      - neptune-client==0.14.3
      - biopython==1.79
      - jaxlib==0.3.0+cuda11.cudnn82
      - tqdm==4.56.0

I launched this snippet of code (taken from the documentation):

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental.maps import Mesh, xmap
from jax.nn import one_hot, relu
from jax.scipy.special import logsumexp


def named_predict(w1: jnp.ndarray, w2: jnp.ndarray, image: jnp.ndarray) -> jnp.ndarray:
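    # Forward pass of a simple MLP written with named axes: lax.pdot contracts
    # over the named "inputs" and "hidden" axes instead of positional dimensions.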
    hidden = relu(lax.pdot(image, w1, "inputs"))
    logits = lax.pdot(hidden, w2, "hidden")
    return logits - logsumexp(logits, "classes")


def named_loss(
    w1: jnp.ndarray, w2: jnp.ndarray, images: jnp.ndarray, labels: jnp.ndarray
) -> jnp.ndarray:
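    # Cross-entropy loss via named-axis collectives: psum over "classes", pmean over "batch".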
    predictions = named_predict(w1, w2, images)
    num_classes = lax.psum(1, "classes")
    targets = one_hot(labels, num_classes, axis="classes")
    losses = lax.psum(targets * predictions, "classes")
    return -lax.pmean(losses, "batch")


if __name__ == "__main__":
    # Start script
    devices = jax.local_devices()
    print(f"Start test, detected devices: {devices}")

    # Generate dummy data
    w1 = jnp.zeros((784, 512))
    w2 = jnp.zeros((512, 10))
    images = jnp.zeros((128, 784))
    labels = jnp.zeros(128, dtype=jnp.int32)

    # Prepare xmapped function
    in_axes = [
        ["inputs", "hidden", ...],
        ["hidden", "classes", ...],
        ["batch", "inputs", ...],
        ["batch", ...],
    ]
    loss = xmap(
        named_loss,
        in_axes=in_axes,
        out_axes=[...],
        axis_resources={"batch": "x", "hidden": "y"},
    )

    # Prepare devices mesh
    num_devices = len(devices)
    assert num_devices >= 2 and num_devices % 2 == 0
    devices = np.array(jax.local_devices()).reshape((2, num_devices // 2))
    print(f"Device mesh: {devices}")

    # Run meshed computations
    with Mesh(devices, ("x", "y")):
        print("Loss computed on mesh: ", loss(w1, w2, images, labels))

    print("End of test.")

I got the following output:

Start test, detected devices: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)]
/opt/conda/envs/jax-env/lib/python3.8/site-packages/jax/experimental/maps.py:492: UserWarning: xmap is an experimental feature and probably has bugs!
  warn("xmap is an experimental feature and probably has bugs!")
Device mesh: [[GpuDevice(id=0, process_index=0) GpuDevice(id=1, process_index=0)]
 [GpuDevice(id=2, process_index=0) GpuDevice(id=3, process_index=0)]]
2022-04-27 12:30:21.693000: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2124] Execution of replica 2 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:326: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled system error
2022-04-27 12:30:21.693490: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2124] Execution of replica 3 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:326: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled system error
2022-04-27 12:30:31.317787: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:144] This thread has been waiting for 10s and may be stuck:
2022-04-27 12:30:31.317860: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:144] This thread has been waiting for 10s and may be stuck:
2022-04-27 12:30:31.693344: F external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2264] Replicated computation launch failed, but not all replicas terminated. Aborting process to work around deadlock. Failure message (there may have been multiple failures, see the error log for all failures): 

external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:326: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled system error
Aborted (core dumped)

If I change the jax version from 0.3.2 to 0.3.0 (without changing anything else), I get a different error:

Start test, detected devices: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)]
/opt/conda/envs/jax-env/lib/python3.8/site-packages/jax/experimental/maps.py:544: UserWarning: xmap is an experimental feature and probably has bugs!
  warn("xmap is an experimental feature and probably has bugs!")
Device mesh: [[GpuDevice(id=0, process_index=0) GpuDevice(id=1, process_index=0)]
 [GpuDevice(id=2, process_index=0) GpuDevice(id=3, process_index=0)]]
Traceback (most recent call last):
  File "scripts/xmap_test.py", line 62, in <module>
    with Mesh(devices, ("x", "y")):
AttributeError: __enter__

Please do not hesitate to ask if you need more information.
Thanks a lot!

Hmm. This might be tough for us to figure out, given that it's happening inside one of NVIDIA's libraries (NCCL).

One hypothesis: perhaps we are low on memory? It's worth ruling out, at least: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
Try setting XLA_PYTHON_CLIENT_ALLOCATOR=platform; it is slow, but let's see whether it fixes the problem.
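For reference, a minimal sketch of how to apply that flag (it has to be set before jax is first imported, either in the shell or at the very top of the script):

import os

# "platform" allocates and frees GPU memory on demand instead of preallocating;
# this is slower, but useful for ruling out out-of-memory problems.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax  # import only after the environment variable is set

print(jax.local_devices())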

@ranzenTom I ran into a similar problem to yours, and I later found out that there was a mismatch between the CUDA version the driver was using (11.4) and the CUDA libraries I had loaded (11.1).

I was experimenting with pjit and got the following error:

2022-06-29 12:36:12.905032: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2129] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:245: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error
2022-06-29 12:36:12.907134: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2129] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:245: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error
2022-06-29 12:36:12.907630: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2129] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:245: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error
2022-06-29 12:36:12.907832: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2129] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:245: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error
Traceback (most recent call last):                                                                                                                                                                                                             
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/jax/partitioning.py", line 25, in <module>                                                                                                                              
    output = f(M, M)                                                                                                                                                                                                                           
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/experimental/pjit.py", line 355, in wrapped                                                                                        
    out = pjit_p.bind(*args_flat, **params)                                                                                                                                                                                                    
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/core.py", line 327, in bind                                                                                                        
    return self.bind_with_trace(find_top_trace(args), args, params)                                                                                                                                                                            
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/core.py", line 330, in bind_with_trace                                                                                             
    out = trace.process_primitive(self, map(trace.full_raise, args), params)                                                                                                                                                                   
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/core.py", line 680, in process_primitive                                                                                           
    return primitive.impl(*tracers, **params)                                                                                                                                                                                                  
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/experimental/pjit.py", line 722, in _pjit_call_impl                                                                                
    return compiled.unsafe_call(*args)                                                                                                                                                                                                         
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/_src/profiler.py", line 312, in wrapper                                                                                            
    return func(*args, **kwargs)                                                                                                                                                                                                               
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 1687, in __call__                                                                                      
    out_bufs = self.xla_executable.execute_sharded_on_local_devices(input_bufs)                                                                                                                                                                
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:245: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

In my case, I work on an HPC cluster where I can load different versions of CUDA via environment modules, so the fix was as simple as running module load cuda/11.4. I am not sure whether this applies to your setup, but I hope it helps.
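In case it is useful, here is a rough, illustrative sketch for spotting such a mismatch (it assumes nvidia-smi is on PATH; the regex over its banner output is only an approximation, as the format can vary between driver versions):

import re
import subprocess

import jax
import jaxlib

print("jax version:   ", jax.__version__)
print("jaxlib version:", jaxlib.__version__)  # should match the CUDA/cuDNN build you installed

# nvidia-smi reports the highest CUDA version supported by the installed driver.
smi = subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout
match = re.search(r"CUDA Version:\s*([\d.]+)", smi)
if match:
    print("Driver-supported CUDA version:", match.group(1))

# Compare this against the CUDA toolkit/libraries actually loaded
# (e.g. via `nvcc --version` or `module list` on a cluster).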

The example code that produced the error:
#!/usr/bin/env python3

import jax
import jax.numpy as jnp
import numpy as np

from jax.experimental.pjit import pjit
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec

# Create device mesh
mesh = Mesh(np.array(jax.devices()).reshape((2, 2)), ("x", "y"))

# Create matrix and vector
M = jnp.eye(8)
v = jnp.arange(8).reshape((8, 1))

spec = PartitionSpec("x", "y")

f = pjit(jnp.dot,
         in_axis_resources=(spec, None),
         out_axis_resources=spec)

with mesh:
    output = f(M, M)

print(output)

Given there was no follow-up and we can't reproduce this, I'm going to declare the issue stale.

(Please retry with an up-to-date jax/jaxlib if it's still a problem.)