google / trax

Trax — Deep Learning with Clear Code and Speed

Trax ML: GPU memory allocated but completed on CPU

Description

Hi, I ran the example En→De translation script and installed jax with CUDA support, so I no longer get the usual "No GPU/TPU found, falling back to CPU" warning. However, `nvidia-smi` and `top` show that most of my GPU memory is allocated by JAX while the GPU itself sits nearly idle and my CPU runs at 100%. I checked which device JAX is using, and it reports `gpu:0`. The translation does eventually finish, but it is very slow, and I strongly suspect the computation is happening on my CPU, as shown below. Why might this be?

Environment information

OS: Ubuntu 20.04
CUDA 11.2
Jax 0.2.17
Trax 1.3.9

$ pip freeze | grep trax
trax==1.3.9

$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.5.0
tensorflow-datasets==4.3.0
tensorflow-estimator==2.5.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.1.0
tensorflow-text==2.5.0

$ pip freeze | grep jax
jax==0.2.17
jaxlib==0.1.68+cuda101

$ python3 -V
Python 3.8.10

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).device_buffer.device())
gpu:0
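
For anyone debugging the same symptom, here is a slightly fuller check than the one above (a sketch using standard JAX calls; the exact output depends on your install, and on a CPU-only machine everything will report `cpu`):

```python
import jax
import jax.numpy as jnp

# Backend chosen at import time: "gpu" on a working CUDA install,
# "cpu" if jaxlib silently fell back to the CPU backend.
print(jax.default_backend())

# Every device JAX can see.
print(jax.devices())

# Run a small jitted op and wait for it, to confirm computation
# actually completes on that backend.
x = jnp.ones((1000, 1000))
y = jax.jit(jnp.dot)(x, x)
y.block_until_ready()
print(y.shape)
```

Note that memory being allocated on the GPU (as `nvidia-smi` shows) only proves the CUDA backend initialized; it does not by itself prove that ops are dispatched there.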

Nvidia-smi

NVIDIA GeForce GTX 1650 Ti
Memory-Usage 3876/3914 MiB
Volatile GPU-Util: 9%

GPU Memory Allocation

+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 1046 G /usr/lib/xorg/Xorg 45MiB |
| 0 N/A N/A 1621 G /usr/lib/xorg/Xorg 124MiB |
| 0 N/A N/A 1795 G /usr/bin/gnome-shell 100MiB |
| 0 N/A N/A 2981 C python3 3541MiB |

CPU Usage

Command: python3
%CPU: 100.0
%MEM: 19.6

I'm having the exact same problem. When I run the TensorBoard profiler/trace viewer, it shows a tiny bit of startup activity on the GPU, then nothing: it's all CPU from there. In the same notebook, I can run a raw JAX loop and the GPU is used fine. I'm at my wits' end with this; Trax is far too slow to be useful on the CPU. How is anyone getting this to work?
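
For reference, a raw JAX loop of the kind described above might look like this (a sketch; the shapes and iteration count are arbitrary). While it runs, `nvidia-smi` should show real GPU utilization if the CUDA backend is actually doing the work:

```python
import jax
import jax.numpy as jnp

# Repeated jitted matmuls: heavy enough to show up in nvidia-smi.
@jax.jit
def step(x):
    return jnp.tanh(x @ x)

x = jnp.ones((2000, 2000))
for _ in range(100):
    x = step(x)
x.block_until_ready()  # wait for the async dispatch to finish
print(x.shape)
```

If this loop lights up the GPU but the Trax training loop does not, that points at how Trax is driving its backend rather than at the JAX/CUDA install itself.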

I experience this issue too, but in my case it happens during training. I have tried installing different jax and jaxlib versions, but it did not help. I have no idea what is wrong; can anyone help?

I'm facing the exact same problem! I've been investigating this issue for three days and couldn't find a solution.

When I use JAX alone or TensorFlow alone and monitor the GPU, I can see that they use it properly. But from Trax, only the 'tensorflow-numpy' backend actually uses the GPU (memory and computation); when I set the backend to 'jax', the memory is allocated without any computation happening on the GPU.

Any help?

Same problem for me; please post if you find a solution.

Setting `trax.fastmath.set_backend('tensorflow-numpy')` seems to help; I can see the GPU cycles being used.
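
Concretely, the workaround looks like this (a sketch: call `set_backend` once, before building or training the model, so all subsequent `fastmath` ops go through the TensorFlow-NumPy backend):

```python
import trax

# Switch Trax's math backend from the default ('jax') to
# TensorFlow-NumPy, which in this thread was observed to actually
# dispatch computation to the GPU, not just allocate memory on it.
trax.fastmath.set_backend('tensorflow-numpy')
```

This sidesteps the symptom rather than explaining why the 'jax' backend allocates GPU memory but computes on the CPU.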