NCCL error when trying xmap on 4 T4 GPUs
ranzenTom opened this issue
Hi,
First, thank you for making this library available, I love it!
I tried to reproduce the example from the xmap documentation (https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) on a VM with 4 T4 GPUs.
I am using conda, and the environment file is:
```yaml
name: jax-env
channels:
  - defaults
  - conda-forge
dependencies:
  - python=3.8
  - numpy==1.19.5
  - pandas==1.2.5
  - six==1.15.0
  - pip=21.0.1
  - pip:
    - --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
    - jax==0.3.2
    - optax==0.0.9
    - dm-haiku==0.0.6
    - regex==2022.1.18
    - tokenizers==0.11.5
    - tensorflow==2.6.0
    - neptune-client==0.14.3
    - biopython==1.79
    - jaxlib==0.3.0+cuda11.cudnn82
    - tqdm==4.56.0
```
I launched this snippet of code (taken from the documentation):
```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental.maps import Mesh, xmap
from jax.nn import one_hot, relu
from jax.scipy.special import logsumexp


def named_predict(w1: jnp.ndarray, w2: jnp.ndarray, image: jnp.ndarray) -> jnp.ndarray:
    hidden = relu(lax.pdot(image, w1, "inputs"))
    logits = lax.pdot(hidden, w2, "hidden")
    return logits - logsumexp(logits, "classes")


def named_loss(
    w1: jnp.ndarray, w2: jnp.ndarray, images: jnp.ndarray, labels: jnp.ndarray
) -> jnp.ndarray:
    predictions = named_predict(w1, w2, images)
    num_classes = lax.psum(1, "classes")
    targets = one_hot(labels, num_classes, axis="classes")
    losses = lax.psum(targets * predictions, "classes")
    return -lax.pmean(losses, "batch")


if __name__ == "__main__":
    # Start script
    devices = jax.local_devices()
    print(f"Start test, detected devices: {devices}")

    # Generate dummy data
    w1 = jnp.zeros((784, 512))
    w2 = jnp.zeros((512, 10))
    images = jnp.zeros((128, 784))
    labels = jnp.zeros(128, dtype=jnp.int32)

    # Prepare xmapped function
    in_axes = [
        ["inputs", "hidden", ...],
        ["hidden", "classes", ...],
        ["batch", "inputs", ...],
        ["batch", ...],
    ]
    loss = xmap(
        named_loss,
        in_axes=in_axes,
        out_axes=[...],
        axis_resources={"batch": "x", "hidden": "y"},
    )

    # Prepare device mesh
    num_devices = len(devices)
    assert num_devices >= 2 and num_devices % 2 == 0
    devices = np.array(jax.local_devices()).reshape((2, num_devices // 2))
    print(f"Device mesh: {devices}")

    # Run meshed computation
    with Mesh(devices, ("x", "y")):
        print("Loss computed on mesh: ", loss(w1, w2, images, labels))
    print("End of test.")
```
I got the following output:
```
Start test, detected devices: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)]
/opt/conda/envs/jax-env/lib/python3.8/site-packages/jax/experimental/maps.py:492: UserWarning: xmap is an experimental feature and probably has bugs!
  warn("xmap is an experimental feature and probably has bugs!")
Device mesh: [[GpuDevice(id=0, process_index=0) GpuDevice(id=1, process_index=0)]
 [GpuDevice(id=2, process_index=0) GpuDevice(id=3, process_index=0)]]
2022-04-27 12:30:21.693000: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2124] Execution of replica 2 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:326: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled system error
2022-04-27 12:30:21.693490: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2124] Execution of replica 3 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:326: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled system error
2022-04-27 12:30:31.317787: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:144] This thread has been waiting for 10s and may be stuck:
2022-04-27 12:30:31.317860: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:144] This thread has been waiting for 10s and may be stuck:
2022-04-27 12:30:31.693344: F external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2264] Replicated computation launch failed, but not all replicas terminated. Aborting process to work around deadlock. Failure message (there may have been multiple failures, see the error log for all failures):
external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:326: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled system error
Aborted (core dumped)
```
If I change the version of Jax from 0.3.2 to 0.3.0 (without changing anything else), then I obtain a different error:
```
Start test, detected devices: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)]
/opt/conda/envs/jax-env/lib/python3.8/site-packages/jax/experimental/maps.py:544: UserWarning: xmap is an experimental feature and probably has bugs!
  warn("xmap is an experimental feature and probably has bugs!")
Device mesh: [[GpuDevice(id=0, process_index=0) GpuDevice(id=1, process_index=0)]
 [GpuDevice(id=2, process_index=0) GpuDevice(id=3, process_index=0)]]
Traceback (most recent call last):
  File "scripts/xmap_test.py", line 62, in <module>
    with Mesh(devices, ("x", "y")):
AttributeError: __enter__
```
Please do not hesitate to ask if you need more information.
Thanks a lot!
Hmm. This might be tough for us to figure out, given that it's happening inside one of NVIDIA's libraries (NCCL).
One hypothesis: perhaps we are low on memory? It's worth ruling out, at least. https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
Try setting `XLA_PYTHON_CLIENT_ALLOCATOR=platform`, which is slow, but let's see if it fixes the problem.
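For reference, a minimal sketch of applying that setting from inside the script itself (the variable must be set before JAX initializes its GPU backend; the small allocation at the end is just an illustrative smoke test, not part of the original example):

```python
import os

# Must be set before JAX allocates its GPU backend (i.e. before the first
# jax import / first device use), otherwise it has no effect.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax
import jax.numpy as jnp

print(jax.local_devices())
print(jnp.zeros((1024, 1024)).sum())  # small allocation as a quick sanity check
```

Setting it in the shell before launching (`XLA_PYTHON_CLIENT_ALLOCATOR=platform python script.py`) works just as well; the in-script version only matters if the environment is hard to control.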
@ranzenTom I ran into a problem similar to yours, and I later found out that there was a mismatch between the CUDA version the driver was using (11.4) and the CUDA libraries I had loaded (11.1).
I was experimenting with `pjit` and got the following error:
```
2022-06-29 12:36:12.905032: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2129] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:245: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error
2022-06-29 12:36:12.907134: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2129] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:245: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error
2022-06-29 12:36:12.907630: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2129] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:245: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error
2022-06-29 12:36:12.907832: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2129] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:245: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error
Traceback (most recent call last):
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/jax/partitioning.py", line 25, in <module>
    output = f(M, M)
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/experimental/pjit.py", line 355, in wrapped
    out = pjit_p.bind(*args_flat, **params)
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/core.py", line 327, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/core.py", line 330, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/core.py", line 680, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/experimental/pjit.py", line 722, in _pjit_call_impl
    return compiled.unsafe_call(*args)
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/_src/profiler.py", line 312, in wrapper
    return func(*args, **kwargs)
  File "/petrobr/parceirosbr/brcluster/gustavo.leite/research/examples/venv/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 1687, in __call__
    out_bufs = self.xla_executable.execute_sharded_on_local_devices(input_bufs)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:245: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: unhandled cuda error: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
```
In my case, I work on an HPC cluster where I can load different versions of CUDA via environment modules. The fix was as simple as loading CUDA 11.4 via `module load cuda/11.4`. I am not sure if this is your case, but I hope it helps.
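Not part of the original fix, but a quick way to spot this kind of mismatch is to compare the driver-side CUDA version reported by `nvidia-smi` with the jax/jaxlib build that is actually installed. A rough sketch (the exact `nvidia-smi` output layout may vary between driver versions):

```python
import subprocess

import jax
import jaxlib

# Driver-side CUDA version, taken from the nvidia-smi header line.
smi = subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout
print(next(line for line in smi.splitlines() if "CUDA Version" in line))

# Library-side versions: the installed jaxlib wheel encodes the CUDA/cuDNN
# it was built against (e.g. jaxlib==0.3.0+cuda11.cudnn82 in this environment).
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
```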
Here is the example code that produced the error:
```python
#!/usr/bin/env python3
import jax
import jax.numpy as jnp
import numpy as np

from jax.experimental.pjit import pjit
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec

# Create device mesh
mesh = Mesh(np.array(jax.devices()).reshape((2, 2)), ("x", "y"))

# Create matrix and vector
M = jnp.eye(8)
v = jnp.arange(8).reshape((8, 1))

spec = PartitionSpec("x", "y")
f = pjit(jnp.dot,
         in_axis_resources=(spec, None),
         out_axis_resources=spec)

with mesh:
    output = f(M, M)
    print(output)
```
Given there was no follow-up and we can't reproduce this, I'm going to declare the issue stale.
(Please retry with an up to date jax/jaxlib, if it's still a problem.)
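If anyone lands here with the same symptom, a minimal sketch of a sanity check (not from this thread) that exercises NCCL collectives across the local GPUs without involving xmap or pjit:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
print(f"local devices: {n}")

# One scalar per device; psum forces a cross-device all-reduce, which goes
# through NCCL on GPU. If NCCL is misconfigured, this fails the same way.
out = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(
    jnp.arange(n, dtype=jnp.float32)
)
print(out)  # expected: every entry equals 0 + 1 + ... + (n - 1)
```

If this simple psum already fails with an NCCL init error, the problem is in the CUDA/NCCL setup rather than in xmap itself.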