google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Extremely long compilation time for pmap on A5000 GPUs

C-J-Cundy opened this issue.

Hi,
When I run a very simple parallel (pmap) MLP example on A5000 GPUs, compilation takes thousands of times longer than on other GPUs.

A minimal example can be found in this colab: https://colab.research.google.com/drive/1HTvcKQ4ozmdftXDkB715BKeEOpI3hkAX#scrollTo=LGB0i9LkPgER

On a single 2080 Ti the script compiles in 0.4 s, and with two 2080 Tis available it takes 1.1 s. A single A5000 also takes around 0.4 s. With two A5000s available, however, compilation runs for over an hour; I killed it after one hour, so I don't know how long it would have taken in total.

A different script (training a language model) was still compiling after 48 hours when it was assigned to A5000s, whereas it took a minute to compile on an A4000.
This is with jax 0.3.7, jaxlib 0.3.7+cuda11.cudnn805, and CUDA 11.2.
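The colab itself isn't reproduced here. What follows is a minimal sketch of the kind of measurement described above, assuming a small MLP trained with pmap; the model, layer sizes, learning rate and the names init_mlp, train_step and measure_pmap_compile_time are hypothetical and are not the reporter's code. The first call to a pmapped function traces, compiles and runs it, while a second call with same-shaped inputs reuses the cached executable, so the difference between the two timings roughly approximates compile time.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical MLP initializer: returns a list of (W, b) pairs."""
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (d_in, d_out)) / jnp.sqrt(d_in)
        params.append((w, jnp.zeros(d_out)))
    return params


def mlp(params, x):
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def loss(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)


@partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    # Per-device gradients, averaged across devices with an all-reduce.
    grads = jax.grad(loss)(params, x, y)
    grads = jax.lax.pmean(grads, axis_name="devices")
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)


def _wait(tree):
    # Dispatch is asynchronous; block so the timer covers the real work.
    return jax.tree_util.tree_map(lambda a: a.block_until_ready(), tree)


def measure_pmap_compile_time(sizes=(32, 64, 1), batch=8):
    n = jax.local_device_count()
    params = init_mlp(jax.random.PRNGKey(0), list(sizes))
    # pmap maps over a leading axis of size n on every input: replicate params.
    params = jax.tree_util.tree_map(
        lambda a: jnp.broadcast_to(a, (n,) + a.shape), params)
    x = jnp.ones((n, batch, sizes[0]))
    y = jnp.zeros((n, batch, sizes[-1]))

    t0 = time.perf_counter()
    _wait(train_step(params, x, y))  # trace + compile + run
    first = time.perf_counter() - t0

    t0 = time.perf_counter()
    _wait(train_step(params, x, y))  # cached executable: run only
    second = time.perf_counter() - t0
    return first, second


if __name__ == "__main__":
    first, second = measure_pmap_compile_time()
    print(f"devices={jax.local_device_count()} first call {first:.2f}s, "
          f"second call {second:.2f}s, approx. compile {first - second:.2f}s")
```

Running this once with a single GPU visible and once with two (for example by setting CUDA_VISIBLE_DEVICES) gives the kind of one-device versus two-device comparison reported above.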

That's odd! The timing shouldn't be significantly different. But this might be hard to debug if we can't find a GPU on which it reproduces. I'll see if it reproduces on A100, which is the closest match I have.

One quick thing you might look at: what does top show during the long compilation? Is it the Python process taking the time or ptxas? How do you know it is compilation taking the time (as opposed to, say, a deadlock early in execution)?

Hi Chris, can you run your example under gdb and collect call-stack backtraces to see where the program is hanging? That will help us confirm what is actually going on. Thanks!

(For reference, you can get a backtrace by running your Python process under gdb, or by attaching to an already-running Python process with the gdb command attach PID, and then using bt.)
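gdb shows native (C++) frames, which is what is needed to tell whether XLA or ptxas is busy. As a lighter-weight complement, not something suggested in this thread, Python's standard-library faulthandler module can periodically dump the Python-level stack of every thread while a suspect call is still running. It only shows Python frames, so a hang inside XLA compilation would appear as the Python line that triggered it; dump_stacks_if_slow is a hypothetical helper name used for this sketch.

```python
import faulthandler
import sys
import time
from contextlib import contextmanager


@contextmanager
def dump_stacks_if_slow(timeout_s, file=sys.stderr, repeat=True):
    """Hypothetical helper: while the body runs, dump the Python tracebacks
    of all threads to `file` every `timeout_s` seconds (once if repeat=False).
    `file` must be a real file with a fileno(), since faulthandler writes
    directly to the file descriptor."""
    faulthandler.dump_traceback_later(timeout_s, repeat=repeat, file=file)
    try:
        yield
    finally:
        # Stop the watchdog once the body finishes (or raises).
        faulthandler.cancel_dump_traceback_later()


if __name__ == "__main__":
    # Stand-in for the slow call, e.g. the first invocation of a pmapped step.
    with dump_stacks_if_slow(60.0):
        time.sleep(1.0)
```

If the dumps keep showing the same Python line (for example the first call of the pmapped function), that points at compilation or execution inside native code, and a gdb backtrace is then the way to see which.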

We tried on both A100 and A6000 and were unable to reproduce.

OK, I will definitely do so when I get the chance. That might be a while, as the A5000s are heavily used by my group in the run-up to NeurIPS.

When I finally got a chance, I was unable to reproduce this: the same script that previously caused the GPUs to hang now compiles and runs fine. I haven't changed the JAX version or anything else, so this is a bit odd, but I'm not complaining.

If it happens again, please grab a stack trace! Thanks for the report.