Crash with pmap on high RAM V100
alcybiades opened this issue · comments
Hi there! First off, thanks so much for publishing this code!
This issue may just come down to my GPU not having enough memory, but I thought I'd share it since the Colab mentions that pmap should work on V100s. I'm running on Colab Pro+ with the High-RAM runtime shape, but when I reach this line:
results = p_generate_256_samples(pmap_input_tokens, sample_rngs)
I get the following error (I've tried a few times):
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
[<ipython-input-9-f2eb36cfcc0b>](https://localhost:8080/#) in <module>()
17 sample_rngs = jax.random.split(sample_rng, jax.local_device_count())
---> 18 results = p_generate_256_samples(pmap_input_tokens, sample_rngs)
19
10 frames
UnfilteredStackTrace: RuntimeError: UNKNOWN: CUDNN_STATUS_NOT_SUPPORTED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4839): 'status'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
[<ipython-input-9-f2eb36cfcc0b>](https://localhost:8080/#) in <module>()
16 elif run_mode == 'pmap':
17 sample_rngs = jax.random.split(sample_rng, jax.local_device_count())
---> 18 results = p_generate_256_samples(pmap_input_tokens, sample_rngs)
19
20 # flatten the pmap results
RuntimeError: UNKNOWN: CUDNN_STATUS_NOT_SUPPORTED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4839): 'status'
It may also be notable that the call runs for a long time (20+ seconds) before throwing that error, even though pmap should make it fast. Since I get an actual out-of-memory crash at this point on other GPUs with less memory, I figured there's a chance this is a different issue.
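One more thing worth mentioning: as far as I understand, JAX preallocates most of the GPU's memory by default, which might explain why my nvidia-smi below shows the card nearly full. Something I may try next is disabling that preallocation (a sketch, using the documented XLA client environment variables; these have to be set before the first `import jax`):

```python
import os

# Must be set before jax is imported for the first time.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # allocate on demand instead of ~90% upfront
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # release buffers back to the driver when freed

# import jax  # jax would now use the on-demand allocator
```

No idea yet whether that changes the CUDNN error or just the timing, but it should at least make the nvidia-smi numbers reflect actual usage.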
Thanks again for sharing the code :)
Oh, and here is my `!nvidia-smi` output:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 |
| N/A 35C P0 40W / 300W | 16027MiB / 16160MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
+-----------------------------------------------------------------------------+
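In case it helps reproduce, here is a minimal self-contained version of the pmap pattern the notebook uses (a sketch with hypothetical stand-ins for the real sampling function and token shapes; the actual `p_generate_256_samples` is of course much heavier):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the notebook's real per-device sampling function.
def generate_samples(tokens, rng):
    noise = jax.random.normal(rng, tokens.shape)
    return tokens + noise

p_generate = jax.pmap(generate_samples)

n_dev = jax.local_device_count()
# One PRNG key per device, exactly as in the failing cell.
sample_rngs = jax.random.split(jax.random.PRNGKey(0), n_dev)
# Leading axis of every pmapped input must equal the device count.
input_tokens = jnp.zeros((n_dev, 4, 8))

results = p_generate(input_tokens, sample_rngs)
print(results.shape)
```

This toy version runs fine for me, so the crash seems specific to the size of the real model/inputs rather than the pmap setup itself.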