Running multiple training runs on one A100
linyuhongg opened this issue
I have reproduced a paper's algorithm in JAX, and it runs at least 2x faster than the PyTorch counterpart. Both programs use only a fraction of the GPU (~10%), which makes it possible to run multiple programs on a single GPU.
However, whenever I launch multiple runs on the same GPU, JAX becomes extremely slow, while the PyTorch counterpart is not affected.
I am using jax 0.3.4 and jaxlib 0.3.2+cuda11.cudnn82, and XLA_PYTHON_CLIENT_PREALLOCATE is set to false when running.
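For context, a minimal sketch of how the flag can be set from Python (it has to be set before jax is imported, or it is ignored; the commented-out XLA_PYTHON_CLIENT_MEM_FRACTION line is an alternative from JAX's GPU memory allocation docs, not part of my actual setup):

```python
import os

# These must be set before jax is imported, or they have no effect.
# Disable up-front preallocation so memory is allocated on demand:
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternative (assumption, from the JAX docs): keep preallocation on but cap
# each process at a fixed fraction so several processes can share one GPU.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".10"

import jax
import jax.numpy as jnp

# Quick sanity check that the device is visible and a computation runs.
print(jax.devices())
x = jnp.ones((1024, 1024))
print(float((x @ x).sum()))
```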
I was wondering whether this is due to my code or to JAX.
Is it possible to share instructions to reproduce the problem, ideally with a small program?
What does `nvidia-smi` show while the program is running?
Sorry for the inconvenience; I have found the problem, and it is not related to JAX.
It works perfectly fine now.