google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/


Add support for cuda 11.3

VanWieren opened this issue

As far as I can tell, the only GPU support is for CUDA <= 11.2.

Actually, good news! The CUDA 11.1 jaxlib wheel should also work with CUDA 11.2 and 11.3. We're going to document this in the next jaxlib release.

With the released CUDA 11.1 jaxlib wheel, you may need to set the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/, but that should be fixed with the next release.
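If you launch JAX from a Python script, here is a minimal sketch of applying that workaround from inside the program. It assumes the CUDA toolkit lives at /usr/local/cuda/ as in the comment above, and the helper name `add_xla_flag` is hypothetical. The point is only that XLA_FLAGS should be set before jax is imported, because XLA reads the variable when it initializes, and that any flags already in the variable are kept. Exporting the variable in your shell before starting Python has the same effect.

```python
import os

# Assumed CUDA toolkit location (taken from the comment above); change it if
# your toolkit is installed somewhere else.
CUDA_DIR = "/usr/local/cuda/"


def add_xla_flag(flag: str) -> str:
    """Hypothetical helper: append `flag` to XLA_FLAGS without dropping
    flags that are already set, and without adding duplicates."""
    flags = os.environ.get("XLA_FLAGS", "").split()
    if flag not in flags:
        flags.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(flags)
    return os.environ["XLA_FLAGS"]


# Set the flag first; XLA only picks it up if it is present at initialization.
add_xla_flag(f"--xla_gpu_cuda_data_dir={CUDA_DIR}")

# import jax  # import jax only after XLA_FLAGS has been set
```

Appending instead of overwriting matters if you already pass other XLA flags through the same variable.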

Please try it out and let us know if it works!

We just released jaxlib 0.1.66, which should support CUDA 11.3 without any workarounds. Install the cuda111 variant of the wheel; no special setup (such as the XLA_FLAGS setting above) should be needed.
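To confirm which jaxlib build you ended up with, a small sketch like the one below can help. It assumes the GPU wheel's version string carries a local tag such as `0.1.66+cuda111` (the CPU-only wheel has no `+cuda…` suffix), and it only parses plain numeric releases, not pre-releases like `rc1`. The helper names `jaxlib_build_info` and `looks_cuda113_ready` are hypothetical, not part of JAX.

```python
from __future__ import annotations

from importlib import metadata


def jaxlib_build_info(version: str) -> tuple[tuple[int, ...], str]:
    """Split a version like '0.1.66+cuda111' into ((0, 1, 66), 'cuda111').
    Assumes a plain numeric release; pre-release suffixes are not handled."""
    public, _, local = version.partition("+")
    return tuple(int(part) for part in public.split(".")), local


def looks_cuda113_ready(version: str) -> bool:
    """Hypothetical check: jaxlib 0.1.66 or newer, built as the cuda111 variant."""
    release, local = jaxlib_build_info(version)
    return release >= (0, 1, 66) and local == "cuda111"


try:
    installed = metadata.version("jaxlib")
except metadata.PackageNotFoundError:
    installed = None

if installed is None:
    print("jaxlib is not installed")
elif looks_cuda113_ready(installed):
    print(installed, "looks like the cuda111 build of jaxlib 0.1.66+")
else:
    print(installed, "is not the cuda111 build of jaxlib 0.1.66+")
```

This only inspects package metadata; it does not prove that JAX can see your GPU at runtime.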

Hope that helps!