google-deepmind / deepmind-research

This repository contains implementations and illustrative code to accompany DeepMind publications.



Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'

caihongch opened this issue

Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'

How can I solve this problem?

From: https://pypi.org/project/jax/

pip install --upgrade pip

CUDA 12 installation

  • Note: wheels are only available on Linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

CUDA 11 installation

  • Note: wheels are only available on Linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
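As a rough illustration of the rule behind the two commands above (CUDA 12 uses the `cuda12_pip` extra, CUDA 11 uses `cuda11_pip`, and the wheels are Linux-only), here is a small Python sketch that builds the matching command. The helper name `jax_gpu_install_command` is hypothetical, and it only knows about the two extras listed above; newer JAX releases may name their extras differently, so treat the current install guide as authoritative.

```python
import sys

# Find-links URL copied from the install commands above.
JAX_CUDA_RELEASES = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"


def jax_gpu_install_command(cuda_major: int, platform: str = sys.platform) -> str:
    """Hypothetical helper: return the pip command for a given CUDA major version."""
    # The CUDA wheels referenced above are only published for Linux.
    if not platform.startswith("linux"):
        raise ValueError("JAX CUDA wheels are only available on Linux")
    # Only the two extras mentioned in this thread are handled.
    if cuda_major not in (11, 12):
        raise ValueError(f"no pip extra known for CUDA {cuda_major}")
    extra = f"cuda{cuda_major}_pip"
    return f'pip install --upgrade "jax[{extra}]" -f {JAX_CUDA_RELEASES}'


print(jax_gpu_install_command(12, platform="linux"))
```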

Try this to check that it works:

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
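In more recent JAX releases, the public functions `jax.default_backend()` and `jax.devices()` report the same information without going through `jax.lib.xla_bridge`. A minimal sketch, written so it also runs when JAX is not installed (the helper `describe_jax_backend` is hypothetical):

```python
def describe_jax_backend():
    """Hypothetical helper: return (default backend, device list), or None without JAX."""
    try:
        import jax
    except ImportError:
        return None
    # jax.default_backend() names the platform JAX picked, e.g. "cpu" or "gpu".
    # jax.devices() lists the devices on that backend.
    return jax.default_backend(), [str(d) for d in jax.devices()]


info = describe_jax_backend()
if info is None:
    print("JAX is not installed")
else:
    backend, devices = info
    print("default backend:", backend)
    print("devices:", devices)
```

If the output says `cpu` on a machine with an NVIDIA GPU, the CUDA backend failed to initialize and JAX fell back silently, which is exactly the situation reported below.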

Hi, I tried your suggestion: I installed the GPU version of JAX with pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html and ran the code below:

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

It shows:

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
cpu
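The warning itself suggests setting `TF_CPP_MIN_LOG_LEVEL=0` to see why the GPU backend failed to initialize. A minimal sketch, assuming the variable is set before JAX is imported; running `TF_CPP_MIN_LOG_LEVEL=0 python your_script.py` from the shell (where `your_script.py` is a placeholder name) is the simplest way to guarantee that ordering.

```python
import os

# Set before importing jax so the backend's logging sees it when it initializes.
# (Assumption: setting it after jax is already imported may not take effect.)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

try:
    import jax

    # With the log level lowered, backend initialization errors are printed in full.
    print(jax.devices())
except ImportError:
    print("JAX is not installed")
```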

My GPU specification is as follows:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  Off |
|  0%   49C    P8    24W / 450W |    376MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2020      G   /usr/lib/xorg/Xorg                140MiB |
|    0   N/A  N/A      2151      G   /usr/bin/gnome-shell               83MiB |
|    0   N/A  N/A      3362      G   ...0/usr/lib/firefox/firefox      150MiB |
+-----------------------------------------------------------------------------+
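Note that `CUDA Version: 12.0` in the `nvidia-smi` header is the newest CUDA version the installed driver supports, not the version of a CUDA toolkit installed on the system. As a sketch, the driver version and that maximum CUDA version can be pulled out of the header line as shown here. The helper `parse_nvidia_smi_header` is hypothetical and its regexes assume the header layout shown above. Comparing the result against the minimum driver version listed in the JAX installation guide for the chosen wheel is a reasonable first check.

```python
import re


def parse_nvidia_smi_header(line: str) -> dict:
    """Hypothetical helper: extract the driver version and max supported CUDA version."""
    driver = re.search(r"Driver Version:\s*([\d.]+)", line)
    cuda = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", line)
    if driver is None or cuda is None:
        raise ValueError("not an nvidia-smi header line")
    return {
        "driver": driver.group(1),
        # Highest CUDA version the driver can run, as (major, minor).
        "max_cuda": (int(cuda.group(1)), int(cuda.group(2))),
    }


header = "| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |"
print(parse_nvidia_smi_header(header))  # {'driver': '525.125.06', 'max_cuda': (12, 0)}
```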
I am using Python 3.11.0.
I also installed some additional dependencies:
numpy>=1.16.4
jax>=0.2.6
jaxlib>=0.1.69
flax>=0.2.2
opencv-python>=4.4.0
Pillow>=7.2.0
pyyaml>=5.3.1
scipy>=1.4.1
tensorboard>=2.4.0
tensorflow>=2.3.1
tensorflow-hub>=0.11.0
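One thing worth ruling out (an assumption, not a confirmed diagnosis) is that installing a list like this afterwards, especially with `--upgrade`, replaced the CUDA-enabled `jaxlib` with a CPU-only build from PyPI. The CUDA wheels from `jax_cuda_releases.html` carried a local version suffix such as `+cuda12.cudnn89`, so the installed version string is a quick hint. Newer JAX releases instead ship CUDA support in separate plugin packages (for example `jax-cuda12-plugin`), so a plain `jaxlib` version alone does not prove a CPU-only install. The helper names below are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def jaxlib_build_kind(version_string: str) -> str:
    """Hypothetical helper: classify a jaxlib version string by its local suffix."""
    # "0.4.13+cuda12.cudnn89" -> local part "cuda12.cudnn89"
    local = version_string.partition("+")[2]
    return "cuda" if local.startswith("cuda") else "cpu-or-plugin"


def installed_versions(packages=("jax", "jaxlib", "jax-cuda12-plugin")) -> dict:
    """Return {package: installed version or None} for the given distributions."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


versions = installed_versions()
print(versions)
if versions["jaxlib"] is not None:
    print("jaxlib build:", jaxlib_build_kind(versions["jaxlib"]))
```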

Why is JAX not detecting my GPU/TPU? I am running Ubuntu 22.04.

Thank you, this worked, though I first had to install the following:
CUDA Toolkit: 11.8
cuDNN: 8.9.23.28
TensorRT: the version for CUDA 11.8
Then:
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
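The `_pip` extras pull the CUDA libraries in through pip, so a system-wide toolkit is not always required. When one is installed, as in this reply, it can help to confirm that its release matches the extra chosen (11.8 with `cuda11_pip` here). A minimal sketch, assuming `nvcc` is on `PATH` and that its `--version` output contains the usual `release X.Y` text; both helper names are hypothetical.

```python
import re
import shutil
import subprocess


def parse_nvcc_release(output: str):
    """Hypothetical helper: return (major, minor) from `nvcc --version` output, or None."""
    match = re.search(r"release (\d+)\.(\d+)", output)
    return (int(match.group(1)), int(match.group(2))) if match else None


def installed_toolkit_version():
    """Run nvcc if it is on PATH and return its release, else None."""
    if shutil.which("nvcc") is None:
        return None  # toolkit not installed, or its bin directory is not on PATH
    out = subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout
    return parse_nvcc_release(out)


print(installed_toolkit_version())
```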