google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home page: http://jax.readthedocs.io/

Jax doesn't see my GPU, even though Pytorch does

asemic-horizon opened this issue

Jax sounds like an impressive project, thanks for working on it.

That said, on Ubuntu 18.04 this happens:

➜  python
Python 3.6.9 (default, Oct  8 2020, 12:12:24) 
[GCC 8.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.

>>> import torch, jax; print(torch.cuda.is_available()); print(jax.devices())
True
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
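A slightly more explicit version of the one-liner above is sketched below. The helper names `jax_gpu_report` and `has_gpu` are hypothetical; the sketch treats a device as a GPU when its `platform` attribute is `"gpu"` (what JAX reports for CUDA devices) and also accepts `"cuda"` as an assumption in case other versions differ. It returns `None` instead of raising if `jax` or `jaxlib` cannot be imported.

```python
def jax_gpu_report():
    """Return (default_backend, [device platforms]), or None if JAX won't import."""
    try:
        import jax  # also fails if jaxlib is missing
    except ImportError:
        return None
    platforms = [d.platform for d in jax.devices()]
    return jax.default_backend(), platforms


def has_gpu(report):
    """True if any device in the report looks like a CUDA GPU."""
    if report is None:
        return False
    _, platforms = report
    # "gpu" is JAX's platform name for CUDA devices; "cuda" is accepted as an assumption.
    return any(p in ("gpu", "cuda") for p in platforms)


if __name__ == "__main__":
    report = jax_gpu_report()
    print("report:", report, "| GPU visible:", has_gpu(report))
```

Run it as a plain script rather than in a notebook so any backend warnings end up on the terminal.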

I first tried to pip install jax and got various errors; the error messages said this was common with old versions of pip that don't support newer wheel formats and told me to upgrade pip, which I did (from 9 to 20). Now JAX seems to be installed (at least the NumPy-compatible functions work), but it doesn't see my GPU, an Nvidia GeForce laptop card.

I'm not sure what system diagnostics I can provide to help. This is the name of my card:

✗  lspci | grep NVIDIA
01:00.0 3D controller: NVIDIA Corporation GP108M [GeForce MX150] (rev a1)

As far as I understand, these are the loaded drivers:

✗  lsmod | grep nvidia
nvidia_uvm            970752  0
nvidia_drm             53248  3
nvidia_modeset       1212416  2 nvidia_drm
nvidia              27643904  103 nvidia_uvm,nvidia_modeset
drm_kms_helper        184320  2 nvidia_drm,i915
drm                   491520  8 drm_kms_helper,nvidia_drm,i915
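Loaded `nvidia*` kernel modules show that the driver is present, but the jaxlib CUDA wheels described in the reply below also expect existing CUDA and cuDNN installations. A stdlib-only sketch that parses `lsmod` output like the above; the function names are hypothetical and the parsing assumes the usual `name size count used_by` column layout.

```python
SAMPLE = """\
nvidia_uvm            970752  0
nvidia_drm             53248  3
nvidia_modeset       1212416  2 nvidia_drm
nvidia              27643904  103 nvidia_uvm,nvidia_modeset
drm_kms_helper        184320  2 nvidia_drm,i915
drm                   491520  8 drm_kms_helper,nvidia_drm,i915
"""


def parse_lsmod(text):
    """Map module name -> list of modules that use it, from `lsmod` output."""
    modules = {}
    for line in text.splitlines():
        fields = line.split()
        # Skip blank lines and the "Module  Size  Used by" header.
        if len(fields) < 3 or fields[0] == "Module":
            continue
        used_by = fields[3].split(",") if len(fields) > 3 else []
        modules[fields[0]] = [m for m in used_by if m]
    return modules


def nvidia_modules(text):
    """Sorted names of loaded modules whose name starts with 'nvidia'."""
    return sorted(name for name in parse_lsmod(text) if name.startswith("nvidia"))


if __name__ == "__main__":
    print(nvidia_modules(SAMPLE))
```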

Thanks for reading this anyway.

Hello 👋 Just to confirm - did you follow the Linux-specific installation instructions from the README? Also, have you tried installing JAX in a separate virtual environment that excludes PyTorch? 🤷‍♀️

https://github.com/google/jax#installation

On Linux, it is often necessary to first update pip to a version that supports manylinux2010 wheels.

If you want to install JAX with both CPU and GPU support, using existing CUDA and CUDNN7 installations on your machine (for example, preinstalled on your cloud VM), you can run

pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.57+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

The jaxlib version must correspond to the version of the existing CUDA installation you want to use, with cuda110 for CUDA 11.0, cuda102 for CUDA 10.2, and cuda101 for CUDA 10.1. You can find your CUDA version with:

nvcc --version
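The version-to-suffix rule above is mechanical enough to script. A minimal sketch, assuming the `release X.Y` line that `nvcc --version` prints (as in a later comment in this thread); it only reproduces the naming pattern quoted above, and a wheel with the resulting suffix may not exist for every CUDA release. The function names are hypothetical.

```python
import re
import shutil
import subprocess


def cuda_version_from_nvcc(output):
    """Extract (major, minor) from `nvcc --version` output, e.g. 'release 11.0'."""
    match = re.search(r"release (\d+)\.(\d+)", output)
    if match is None:
        raise ValueError("no 'release X.Y' found in nvcc output")
    return int(match.group(1)), int(match.group(2))


def jaxlib_cuda_suffix(output):
    """Old-style jaxlib tag: CUDA 11.0 -> 'cuda110', 10.2 -> 'cuda102'."""
    major, minor = cuda_version_from_nvcc(output)
    return f"cuda{major}{minor}"


if __name__ == "__main__":
    if shutil.which("nvcc"):
        out = subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout
        print(jaxlib_cuda_suffix(out))
    else:
        print("nvcc not found on PATH")
```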

Note that some GPU functionality expects the CUDA installation to be at /usr/local/cuda-X.X, where X.X should be replaced with the CUDA version number (e.g. cuda-10.2). If CUDA is installed elsewhere on your system, you can either create a symlink:

sudo ln -s /path/to/cuda /usr/local/cuda-X.X

Or set the following environment variable before importing JAX:

XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda
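The same environment variable can be set from Python, as long as it happens before JAX is imported. A sketch that picks the first directory containing `nvvm/libdevice` and appends the flag to any existing `XLA_FLAGS`; the candidate paths and helper names are hypothetical, and checking only for `nvvm/libdevice` is an assumption about what a usable CUDA directory looks like.

```python
import os

# Hypothetical candidate locations; adjust for your system.
CANDIDATES = ["/usr/local/cuda", "/usr/lib/nvidia-cuda-toolkit"]


def find_cuda_data_dir(candidates=CANDIDATES):
    """Return the first candidate that contains an nvvm/libdevice directory."""
    for path in candidates:
        if os.path.isdir(os.path.join(path, "nvvm", "libdevice")):
            return path
    return None


def with_cuda_data_dir(existing_flags, cuda_dir):
    """Append --xla_gpu_cuda_data_dir to whatever XLA_FLAGS already holds."""
    flag = f"--xla_gpu_cuda_data_dir={cuda_dir}"
    return f"{existing_flags} {flag}".strip()


if __name__ == "__main__":
    cuda_dir = find_cuda_data_dir()
    if cuda_dir is not None:
        # Must run before `import jax`, as the instructions above say.
        os.environ["XLA_FLAGS"] = with_cuda_data_dir(os.environ.get("XLA_FLAGS", ""), cuda_dir)
    print("XLA_FLAGS =", os.environ.get("XLA_FLAGS"))
```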

I had the same issue but managed to solve it. It seems PyTorch bundles its own CUDA, which is why it sees your GPU (and nvidia-smi works) without a separate CUDA install. Installing CUDA for your GPU following these instructions solved the issue for me: https://developer.nvidia.com/cuda-downloads

@GJBoth That's awesome. Can you confirm the following (and please correct me if I'm wrong):

  • Having JAX in a separate env should technically help identify if JAX can detect your CUDA (and not PyTorch's bundled one)
  • The current instructions assume that you've taken care of your CUDA installation (see extract below) but maybe it would help to nudge the users to go to https://developer.nvidia.com/cuda-downloads and install CUDA, if they haven't already.

"... using existing CUDA and CUDNN7 installations on your machine (for example, preinstalled on your cloud VM)..."
...
"Note that some GPU functionality expects the CUDA installation to be at /usr/local/cuda-X.X, where X.X should be replaced with the CUDA version number (e.g. cuda-10.2). If CUDA is installed elsewhere on your system..."

WDYT? @asemic-horizon @GJBoth

I have no experience installing CUDA in a specific env. It seems that the symbolic link wouldn't work; see this thread:
https://discuss.pytorch.org/t/where-is-cudatoolkit-path-when-installed-via-conda/47791/5

I think many new JAX users will come from PyTorch, so it may help to add a nudge à la 'If you're coming from PyTorch, make sure to install CUDA separately, if you haven't yet.' Two more observations: I could run nvidia-smi but not nvcc, so that might be a nice check for whether you have PyTorch's CUDA or a system-wide one. Also, Jupyter notebooks tend to die silently with these issues, so running things as a script gives you much more info.
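The nvidia-smi versus nvcc check suggested here can be scripted with the standard library and run as a plain script. A minimal sketch; `cuda_toolchain_report` is a hypothetical helper that only searches `PATH`, so a toolkit that is installed but not on `PATH` will show as missing.

```python
import shutil


def cuda_toolchain_report(tools=("nvidia-smi", "nvcc")):
    """Map each tool name to its resolved path on PATH, or None if missing."""
    return {tool: shutil.which(tool) for tool in tools}


if __name__ == "__main__":
    report = cuda_toolchain_report()
    for tool, path in report.items():
        print(f"{tool:12s} {path or 'NOT FOUND'}")
    if report["nvidia-smi"] and not report["nvcc"]:
        # Driver present, but no system-wide CUDA toolkit on PATH.
        print("Driver found, but no CUDA toolkit on PATH.")
```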

I had the same problem with JAX not recognizing the GPU. I did the following two steps:

pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.57+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

and made a link as follows:

ln -s /usr/lib/nvidia-cuda-toolkit /usr/local/cuda-10.1

After this, JAX still didn't recognize the GPU. Then I did the following steps, hinted at by JAX's warning message about the GPU:

cd /usr/lib/nvidia-cuda-toolkit
mkdir nvvm
cd nvvm
sudo ln -s /usr/lib/nvidia-cuda-toolkit/libdevice libdevice

You need to use "sudo" for the above steps. After these, JAX recognizes my GPU.
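The same `nvvm/libdevice` layout can also be created under a directory you own and then passed to XLA with `--xla_gpu_cuda_data_dir`, which avoids writing into system paths. A sketch, assuming that for this particular error XLA only needs `nvvm/libdevice` from that directory (other CUDA tools may still have to be found elsewhere); `mirror_libdevice` is a hypothetical helper and the demo uses throwaway temporary directories.

```python
import os
import tempfile


def mirror_libdevice(libdevice_src, dest_root):
    """Create dest_root/nvvm/libdevice as a symlink to libdevice_src.

    Equivalent to `mkdir nvvm && ln -s <src> nvvm/libdevice`, but under a
    directory you choose, so no sudo is needed if dest_root is writable.
    """
    nvvm = os.path.join(dest_root, "nvvm")
    os.makedirs(nvvm, exist_ok=True)
    link = os.path.join(nvvm, "libdevice")
    if not os.path.lexists(link):  # don't fail if the link already exists
        os.symlink(libdevice_src, link)
    return link


if __name__ == "__main__":
    # Demo with temporary directories instead of real CUDA paths.
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "toolkit", "libdevice")
        os.makedirs(src)
        link = mirror_libdevice(src, os.path.join(tmp, "cuda_mirror"))
        print(link, "->", os.readlink(link))
```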

sudo ln -s /usr/lib/nvidia-cuda-toolkit/libdevice libdevice

I am remotely connected to a Slurm cluster and do not have sudo rights. In fact, I do not even have permission to create a symbolic link.
Plus, my environment has no GPU; the GPU is assigned via the sbatch job file using the directive "#SBATCH --gres=gpu:1".
This is way too complicated. Yet PyTorch seems to work perfectly well.
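Since the GPU only exists inside the batch job, any check has to run there rather than on the login node. Slurm typically exports `CUDA_VISIBLE_DEVICES` for jobs allocated GPUs via `--gres=gpu` (an assumption: this depends on the cluster's configuration), so printing it from the job script is a quick way to confirm the job actually received a GPU before debugging JAX. `visible_gpu_ids` is a hypothetical helper.

```python
import os


def visible_gpu_ids(env=None):
    """Parse CUDA_VISIBLE_DEVICES into a list of device IDs.

    Returns None if the variable is unset (no restriction) and [] if it is
    set but empty (no GPUs visible).
    """
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return None
    return [part.strip() for part in value.split(",") if part.strip()]


if __name__ == "__main__":
    print("CUDA_VISIBLE_DEVICES ->", visible_gpu_ids())
```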

Hi, I have the same problem while connected remotely to a Slurm cluster. Did you solve this issue? If so, how?

Did you solve this issue?

Unfortunately not. I don't have sudo access on the cluster, and this makes it hard. The best fix would be for JAX to ship a CUDA bundle with the installation, similar to PyTorch.

FYI #6581 bundles libdevice.10.bc with jaxlib wheels, which hopefully will help avoid this particular problem. If you are feeling motivated, you could try patching in that PR and building a jaxlib from source to see if it fixes your problems.

Similar to myjr52, I was able to solve this simply by replacing this:

pip install --upgrade jax jaxlib

with this (you'll have to change cuda111 based on the output of nvcc --version; mine is 11.1):

pip install --upgrade jax jaxlib==0.1.69+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

I didn't need to do any of the additional steps mentioned by myjr52.
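With the wheel scheme used in this thread, the CUDA build is encoded in jaxlib's local version tag (for example `0.1.69+cuda111`), while a plain `pip install jaxlib` from PyPI gives a CPU-only build without one. A sketch for checking which one is installed, assuming Python 3.8+ for `importlib.metadata`; the helper names are hypothetical, and later JAX releases packaged CUDA support differently, so the check is only meaningful for wheels like these.

```python
from importlib import metadata


def local_tag(version):
    """'0.1.69+cuda111' -> 'cuda111'; '0.3.22' -> None."""
    _, _, local = version.partition("+")
    return local or None


def installed_jaxlib_tag():
    """Local tag of the installed jaxlib, or None (not installed or no tag)."""
    try:
        return local_tag(metadata.version("jaxlib"))
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("jaxlib local tag:", installed_jaxlib_tag())
```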

This may be a bit late, but I got mine fixed by specifying the exact .whl link found in https://storage.googleapis.com/jax-releases/jax_releases.html. I just needed CUDA 11.0.

The one I used was:

pip uninstall jax jaxlib -y
pip install https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.71+cuda110-cp38-none-manylinux2010_x86_64.whl
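The wheel URL above follows a visible pattern: a CUDA-tag directory, then `jaxlib-<version>+<cuda>-<python>-<abi>-<platform>.whl`. A sketch that assembles such a URL for the running interpreter; the defaults copy the 0.1.71 link above, later releases changed the ABI and platform tags (compare the `cp310-cp310-manylinux2014_x86_64` wheels further down), and whether a given combination exists on the server is not checked. `jaxlib_wheel_url` and `current_py_tag` are hypothetical helpers.

```python
import sys

BASE = "https://storage.googleapis.com/jax-releases"


def current_py_tag():
    """CPython tag for this interpreter, e.g. 'cp38'."""
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def jaxlib_wheel_url(version, cuda, py_tag=None, abi_tag="none",
                     platform_tag="manylinux2010_x86_64", base=BASE):
    """Assemble a jaxlib wheel URL following the pattern of the link above."""
    py_tag = current_py_tag() if py_tag is None else py_tag
    filename = f"jaxlib-{version}+{cuda}-{py_tag}-{abi_tag}-{platform_tag}.whl"
    return f"{base}/{cuda}/{filename}"


if __name__ == "__main__":
    print(jaxlib_wheel_url("0.1.71", "cuda110"))
```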

@morawi Curious to know if this is solved for you yet since I'm going through the same thing with JAX on a slurm cluster

I just stopped using it.

I solved this problem quite easily by following this issue, some googling (which led me to two SO questions), and the README of this project.

I had Nvidia drivers installed on my laptop through the Pop OS store, installed nvidia-cuda-toolkit through apt, and had installed PyTorch earlier.

My cuda version is 11.2.

I did not have to do any other installation.

  1. I upgraded pip.
  2. I installed jax[cuda11] instead of just jax.
  3. Followed the other generic instructions from the README.
  4. Created two symlinks: one for nvcc and another for CUDA.

Then it started working great.
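After steps like these, a small jitted computation is a quick end-to-end check that the backend JAX reports can actually compile and run code. A minimal sketch using only `jax.jit`, `jax.numpy` and `jax.default_backend()`; the name `smoke_test` is hypothetical, and it returns `None` if JAX cannot be imported.

```python
def smoke_test(n=256):
    """Run a jitted matmul; return (backend, result) or None if JAX is missing."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return None
    x = jnp.ones((n, n))
    # Every entry of x @ x equals n, so the total is n**3 (exact in float32 for n=256).
    total = jax.jit(lambda a: (a @ a).sum())(x)
    return jax.default_backend(), float(total)


if __name__ == "__main__":
    print(smoke_test())
```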

If someone else stumbles into this, the CUDA wheel releases are now stored on https://storage.googleapis.com/jax-releases/jax_cuda_releases.html for some reason.

Same issue, using

FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
...
RUN python3 -m pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
... 

JAX doesn't see the GPU:

import jax
print(jax.devices())

only CPU devices are listed.
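In a container, two separate things have to line up: the container must be started with GPU access (for example `docker run --gpus all`, which relies on the NVIDIA Container Toolkit), and the CUDA pieces jaxlib uses must exist in the image. A `-runtime` base image may leave out compiler tools such as `ptxas` that the CUDA jaxlib wheels of this period relied on; treat that as an assumption to verify. A sketch of checks to run inside the container; `container_gpu_checks` is a hypothetical helper.

```python
import glob
import shutil


def container_gpu_checks():
    """Basic checks for GPU access from inside a container."""
    return {
        # /dev/nvidia* device nodes appear when the container has GPU access.
        "device_nodes": sorted(glob.glob("/dev/nvidia*")),
        "nvidia-smi": shutil.which("nvidia-smi"),
        # ptxas ships with the CUDA toolkit; runtime-only images may lack it.
        "ptxas": shutil.which("ptxas"),
    }


if __name__ == "__main__":
    for key, value in container_gpu_checks().items():
        print(f"{key:12s} {value or 'NOT FOUND'}")
```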

Hello,

Same issue here! Torch can find my GPU, JAX cannot!

nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Jun__8_16:49:14_PDT_2022
Cuda compilation tools, release 11.7, V11.7.99
Build cuda_11.7.r11.7/compiler.31442593_0

pacman -Q cudnn
cudnn 8.5.0.96-1
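The `cudnn82` in the extra's name appears to mean the wheel was built against cuDNN 8.2 (an assumption read off the naming). A sketch comparing that against the version string pacman reports; that a newer 8.x minor release satisfies an 8.2 build is also an assumption, and the helper names are hypothetical.

```python
def parse_version(text, parts=2):
    """'8.5.0.96-1' -> (8, 5); drops a distro release suffix after '-'."""
    numbers = text.split("-", 1)[0].split(".")
    return tuple(int(n) for n in numbers[:parts])


def cudnn_at_least(installed, required):
    """True if the installed cuDNN major.minor is >= the required one."""
    return parse_version(installed) >= parse_version(required)


if __name__ == "__main__":
    print(cudnn_at_least("8.5.0.96-1", "8.2"))
```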

I followed the installation guide here:

When I run
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
it first asks me for a storage.googleapis.com username (WTF??), then when I just press Enter it returns an access error:

Collecting jaxlib==0.3.22+cuda11.cudnn82
User for storage.googleapis.com:   WARNING: 401 Error, Credentials not correct for https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.22%2Bcuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl
  ERROR: HTTP error 401 while getting https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.22%2Bcuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl (from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html)
ERROR: Could not install requirement jaxlib==0.3.22+cuda11.cudnn82 from https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.22%2Bcuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl (from jax[cuda]) because of HTTP error 401 Client Error: Unauthorized for url: https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.22%2Bcuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl for URL https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.22%2Bcuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl (from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html)
FAIL: 1


If I run

pip install --upgrade "jax[cuda]" or pip install "jax[cuda11_cudnn82]"

Then pip backtracks through older versions until it reaches one that doesn't even provide the cuda extra! (WTF^2)

....
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.25.tar.gz (786 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.24.tar.gz (786 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.22-py3-none-any.whl
WARNING: jax 0.2.22 does not provide the extra 'cuda'
Installing collected packages: jax
  Attempting uninstall: jax
    Found existing installation: jax 0.3.22
    Uninstalling jax-0.3.22:
      Successfully uninstalled jax-0.3.22
Successfully installed jax-0.2.22
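In the log above, pip walks back to jax 0.2.22, which has no `cuda` extra, instead of failing. Pinning the exact jax version removes that escape route: if the matching CUDA jaxlib cannot be fetched, pip stops with a resolver error instead. A sketch that builds such a command; `pinned_install_cmd` is a hypothetical helper, and the version and extra are the ones from this comment.

```python
import sys

FIND_LINKS = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"


def pinned_install_cmd(extra, version, find_links=FIND_LINKS):
    """Build a pip command that pins jax so pip cannot backtrack to old releases."""
    requirement = f"jax[{extra}]=={version}"
    # `python -m pip` installs into the same interpreter that runs this script.
    return [sys.executable, "-m", "pip", "install", "--upgrade", requirement,
            "-f", find_links]


if __name__ == "__main__":
    cmd = pinned_install_cmd("cuda11_cudnn82", "0.3.22")
    print(" ".join(cmd))
    # To actually run it: subprocess.run(cmd, check=True)
```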

Running pip install --upgrade "jax[all]" works fine (Successfully installed jax-0.3.22),
but GPU access is still not available (see the top).

Thanks for the help

Update:

If I download the appropriate wheel

[cuda11/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-manylinux2014_x86_64.whl](https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-manylinux2014_x86_64.whl)

and run:

pip install --upgrade "jax[cuda11_cudnn82]" -f ~/Downloads/jax
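One detail in the log below may matter: the downloaded wheel is tagged `cp39`, while the virtualenv paths point to `python3.10`. pip only considers wheels whose Python tag matches the running interpreter, so a mismatched wheel in the `-f` directory is skipped. A stdlib-only sketch for checking this; the helper names are hypothetical, and the parsing relies on the Python tag being the third `-`-separated field from the end of a wheel filename.

```python
import sys


def wheel_python_tag(filename):
    """Return the Python tag (e.g. 'cp39') from a wheel filename or URL."""
    name = filename.rsplit("/", 1)[-1]
    if not name.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename}")
    # ...-{python tag}-{abi tag}-{platform tag}.whl
    return name[: -len(".whl")].split("-")[-3]


def interpreter_tag():
    """CPython tag for this interpreter, e.g. 'cp310'."""
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def wheel_matches_interpreter(filename, tag=None):
    """True if the wheel's Python tag (possibly 'py2.py3'-style) includes tag."""
    tag = interpreter_tag() if tag is None else tag
    return tag in wheel_python_tag(filename).split(".")


if __name__ == "__main__":
    wheel = "jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-manylinux2014_x86_64.whl"
    print(wheel_python_tag(wheel), interpreter_tag(), wheel_matches_interpreter(wheel))
```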

Here we go again:

Looking in links: /home/cocconat/Downloads/jax
Requirement already satisfied: jax[cuda11_cudnn82] in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (0.3.22)
Requirement already satisfied: absl-py in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (from jax[cuda11_cudnn82]) (1.2.0)
Requirement already satisfied: opt-einsum in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (from jax[cuda11_cudnn82]) (3.3.0)
Requirement already satisfied: numpy>=1.20 in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (from jax[cuda11_cudnn82]) (1.23.3)
Requirement already satisfied: etils[epath] in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (from jax[cuda11_cudnn82]) (0.8.0)
Requirement already satisfied: typing-extensions in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (from jax[cuda11_cudnn82]) (4.3.0)
Requirement already satisfied: scipy>=1.5 in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (from jax[cuda11_cudnn82]) (1.9.1)
Collecting jax[cuda11_cudnn82]
  Using cached jax-0.3.22-py3-none-any.whl
  Using cached jax-0.3.21.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.20.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.19.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.17-py3-none-any.whl
  Using cached jax-0.3.16.tar.gz (1.0 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.15.tar.gz (1.0 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.14.tar.gz (990 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.13.tar.gz (951 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.12.tar.gz (947 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.11.tar.gz (947 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.10.tar.gz (939 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.9.tar.gz (937 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.8.tar.gz (935 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.7.tar.gz (944 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.6.tar.gz (936 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.5.tar.gz (946 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.4.tar.gz (924 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.3.tar.gz (924 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.2.tar.gz (926 kB)
^C  Preparing metadata (setup.py) ... canceled
ERROR: Operation cancelled by user
FAIL: 1


@aquaresima Please open a new issue, please don't ping long-closed issues.