google-research / maskgit

Official Jax Implementation of MaskGIT



Crash with pmap on high RAM V100

alcybiades opened this issue


Hi there! First off, thanks so much for publishing this code!

This issue may just come down to my GPU not having enough memory, but I thought I'd share it since the Colab notebook mentions that pmap should work with V100s. I am running it on Colab Pro+ with the High-RAM runtime shape, but when I get to this line:

results = p_generate_256_samples(pmap_input_tokens, sample_rngs)

I get the following error (I've tried a few times):

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-9-f2eb36cfcc0b> in <module>()
     17     sample_rngs = jax.random.split(sample_rng, jax.local_device_count())
---> 18     results = p_generate_256_samples(pmap_input_tokens, sample_rngs)
     19 

UnfilteredStackTrace: RuntimeError: UNKNOWN: CUDNN_STATUS_NOT_SUPPORTED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4839): 'status'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-9-f2eb36cfcc0b> in <module>()
     16 elif run_mode == 'pmap':
     17     sample_rngs = jax.random.split(sample_rng, jax.local_device_count())
---> 18     results = p_generate_256_samples(pmap_input_tokens, sample_rngs)
     19 
     20     # flatten the pmap results

RuntimeError: UNKNOWN: CUDNN_STATUS_NOT_SUPPORTED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4839): 'status'

It may also be worth noting that the function runs for a long time (20+ seconds) before throwing the error, even though pmap should make it fast. Since I get an actual out-of-memory crash at this point on other GPUs with less memory, I figured this might be a different issue.
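A possible explanation for the delay: the first call to a pmapped function traces and compiles it with XLA for the given input shapes (and on GPU, XLA may also autotune convolution algorithms), so a long wait before the error does not by itself mean the compiled program is slow. pmap also maps over the leading axis of every input, and that axis must equal jax.local_device_count(). The sketch below illustrates both points; toy_generate is a hypothetical stand-in for the notebook's p_generate_256_samples, and the shapes are arbitrary.

```python
import time

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the notebook's sampling function; any pmapped
# function shows the same compile-on-first-call behaviour.
def toy_generate(tokens, rng):
    noise = jax.random.uniform(rng, tokens.shape)
    return jnp.argmax(tokens + noise, axis=-1)


p_toy_generate = jax.pmap(toy_generate)

n_dev = jax.local_device_count()
# pmap maps over the leading axis, which must equal the number of local devices.
tokens = jnp.zeros((n_dev, 4, 16), dtype=jnp.float32)
rngs = jax.random.split(jax.random.PRNGKey(0), n_dev)


def timed_call(fn, *args):
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()  # JAX dispatches asynchronously; wait for the result
    return out, time.perf_counter() - start


out, first_s = timed_call(p_toy_generate, tokens, rngs)   # includes trace + compile
out, second_s = timed_call(p_toy_generate, tokens, rngs)  # reuses the compiled program
print(f"first call: {first_s:.3f}s, second call: {second_s:.3f}s")
```

On a runtime with a single GPU, as in the nvidia-smi output later in this thread, jax.local_device_count() is 1, so pmap runs everything on one device and offers no parallel speedup.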

Thx again for sharing the code :)

commented

Oh, and here is my !nvidia-smi output:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    40W / 300W |  16027MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
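This output shows 16027MiB of 16160MiB in use on a single 16 GB V100 while GPU-Util is 0%. That is consistent with JAX's default behaviour of preallocating most of the GPU's memory when its GPU backend starts, so the reading alone doesn't show how much the sampling step really needs. (Colab's High-RAM option adds host RAM, not GPU memory.) One possibility is that CUDNN_STATUS_NOT_SUPPORTED here reflects a memory shortfall inside cuDNN rather than a true lack of support; a hedged way to probe that is to change JAX's allocation through its documented environment variables XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION. The helper name configure_jax_gpu_memory is hypothetical; it only sets those variables.

```python
import os


def configure_jax_gpu_memory(preallocate=False, mem_fraction=None):
    """Set JAX's GPU memory env vars (hypothetical helper).

    Must run before JAX initializes its GPU backend, i.e. before `import jax`
    in a fresh process or runtime.
    """
    # "false" makes JAX allocate GPU memory on demand instead of reserving
    # most of it up front.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # Fraction of total GPU memory JAX preallocates when preallocation is on.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
    # Return the XLA client settings now visible to this process.
    return {k: v for k, v in os.environ.items() if k.startswith("XLA_PYTHON_CLIENT_")}


if __name__ == "__main__":
    print(configure_jax_gpu_memory(preallocate=False))
    # import jax  # only after the variables above are set
```

In Colab this means running it in the first cell of a restarted runtime, since earlier cells have usually imported and initialized JAX already. JAX's documentation notes that disabling preallocation can increase fragmentation, so treat this as a diagnostic step rather than a guaranteed fix.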