google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Running multiple runs on one A100

linyuhongg opened this issue

I have reproduced a paper using JAX, and the algorithm runs at least 2 times faster than the PyTorch counterpart. Both programs use only a fraction of the GPU (~10%), which makes it possible to run multiple programs on a single GPU.

However, whenever I have multiple runs on the same GPU, JAX becomes extremely slow, while the PyTorch counterpart is unaffected.

I am using jax 0.3.4 and jaxlib 0.3.2+cuda11.cudnn82, and I run with XLA_PYTHON_CLIENT_PREALLOCATE set to false.
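For reference, a minimal sketch of how the allocator setting above can be applied before importing JAX; the XLA_PYTHON_CLIENT_MEM_FRACTION cap is an extra, optional knob not mentioned in the original report:

```python
# Minimal sketch: configure JAX's GPU memory allocator before importing jax,
# so several processes can share a single A100.
import os

# Disable JAX's default preallocation of most of the GPU's memory at startup
# (the setting described above).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Optionally cap how much of the GPU this process may claim (assumed extra).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10"

import jax

print(jax.devices())  # should still list the shared GPU
```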

I was wondering whether this is due to my code or to JAX?

Is it possible to share instructions to reproduce the problem, ideally with a small program?

What does nvidia-smi show while the program is running?

Sorry for the inconvenience; I have found the problem, and it is not related to JAX.

It works perfectly fine now.