StephennFernandes / t5x_cuda

A working t5x repo that's executable on NVIDIA GPUs and compatible with pretraining models on two A6000s. Note: designed for personal use; use at your own risk.

t5x training on a GPU machine locally.

FatemehMashhadi opened this issue.

Hey,
I've been trying to pretrain on a single-node GPU machine with CUDA 11.8 and cuDNN 8.8.
I'm following the t5x/contrib/gpu/scripts_gpu instructions, but I'm unable to train on the GPU; JAX falls back to the CPU:

I0513 15:22:19.096940 139889676502848 partitioning.py:452] `activation_partitioning_dims` = 1, `parameter_partitioning_dims` = 1
I0513 15:22:19.102880 139889676502848 xla_bridge.py:455] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0513 15:22:19.102970 139889676502848 xla_bridge.py:455] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0513 15:22:19.103725 139889676502848 xla_bridge.py:455] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
I0513 15:22:19.103788 139889676502848 xla_bridge.py:455] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
W0513 15:22:19.103837 139889676502848 xla_bridge.py:463] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Can you help me solve it?

Hey, it seems like the jaxlib installed on your machine (the backend that flax runs on) is only the CPU build and not the GPU build.

I saw this error on my own system months ago as well.
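One way to test the "CPU-only build" diagnosis is to look at which jaxlib wheel is actually installed and which backend JAX ends up using. The sketch below is a heuristic, not an official JAX tool, and the function names are hypothetical. It assumes that CUDA wheels from the jax_cuda_releases index carry a local version label such as +cuda11.cudnn86, while the plain PyPI jaxlib wheel is the CPU build. That held for the 0.4.1 wheels pinned in this thread, but later JAX releases ship CUDA support as separate plugin packages, so the label check does not apply there.

```python
from importlib import metadata
from typing import Optional


def is_cuda_jaxlib_build(jaxlib_version: str) -> bool:
    # Heuristic (assumption): CUDA wheels from jax_cuda_releases.html look like
    # "0.4.1+cuda11.cudnn86"; a bare "0.4.1" is the CPU-only PyPI wheel.
    _, _, local_label = jaxlib_version.partition("+")
    return local_label.startswith("cuda")


def installed_version(package: str) -> Optional[str]:
    # Version of an installed distribution, or None if it is not installed.
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def jax_default_backend() -> Optional[str]:
    # Ask JAX which backend it actually picked ("cpu", "gpu", ...).
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    jaxlib_version = installed_version("jaxlib")
    if jaxlib_version is None:
        print("jaxlib is not installed")
    else:
        kind = "CUDA" if is_cuda_jaxlib_build(jaxlib_version) else "CPU-only (or unlabeled)"
        print(f"jaxlib {jaxlib_version}: looks like a {kind} build")
    print("JAX default backend:", jax_default_backend())
```

If the jaxlib version carries a CUDA label but the default backend still comes back as "cpu", the CPU-only explanation does not fit, and a version mismatch between jax and jaxlib becomes the more likely suspect.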

I install jax and jaxlib from the Dockerfile and install flax through the T5x setup (pip install "flax @ git+https://github.com/google/flax#egg=flax"):

RUN pip install --upgrade "jax[cuda11_pip]"==0.4.1 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 
RUN pip install --upgrade "jaxlib[cuda11_pip]"==0.4.1 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 
RUN pip install -e '.[test,gpu]'

How do I install the GPU version of flax?
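The Dockerfile steps above pin jax and jaxlib to 0.4.1, but the later pip install -e '.[test,gpu]' step and the flax install from git may resolve a different jax version without touching jaxlib. A "jaxlib.xla_extension has no attribute 'GpuAllocatorConfig'" error is consistent with such a jax/jaxlib mismatch, though that is an assumption here rather than something this thread confirms. The sketch below is a minimal build-time check; check_jax.py and the function names are hypothetical. Treating any version difference as a failure is a conservative simplification, since JAX tolerates some skew between the two packages.

```python
from importlib import metadata
from typing import Optional


def base_version(version: str) -> str:
    # Drop the local label, e.g. "0.4.1+cuda11.cudnn86" -> "0.4.1".
    return version.split("+", 1)[0]


def version_mismatch(jax_version: Optional[str], jaxlib_version: Optional[str]) -> Optional[str]:
    # Return a warning string if jax and jaxlib differ (conservative assumption:
    # identical versions are the safest setup), otherwise None.
    if jax_version is None or jaxlib_version is None:
        return "jax or jaxlib is not installed"
    if base_version(jax_version) != base_version(jaxlib_version):
        return f"jax {jax_version} does not match jaxlib {jaxlib_version}"
    return None


if __name__ == "__main__":
    versions = {}
    for pkg in ("jax", "jaxlib"):
        try:
            versions[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            versions[pkg] = None
    warning = version_mismatch(versions["jax"], versions["jaxlib"])
    if warning:
        # SystemExit with a message exits non-zero, which fails a Docker RUN step.
        raise SystemExit(warning)
    print(f"jax and jaxlib both at {versions['jax']}")
```

Adding something like RUN python check_jax.py as the last Dockerfile step would surface a mismatch at build time instead of as a silent CPU fallback during training.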

I'd recommend using one of our pre-built JAX Toolbox containers, which are validated by nightly CI on NVIDIA GPUs.