google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Provide Windows binaries on PyPI

cool-RR opened this issue

I'm on Windows and I don't know how to build packages. I'd like to run pip install jaxlib and have it work.

I saw there's discussion on #438, but that ticket was about supporting Windows. It seems that Windows is partially supported now, so this is a separate ticket for providing a Windows binary on PyPI.

Will this ticket cover conda installs on windows as well?

+1

Here are a few precompiled Windows wheels for jaxlib (unsupported):
https://github.com/erwincoumans/jax/releases/tag/winwhl-0.1.61

I compiled the latest main-branch jaxlib on Windows and, surprisingly, ran into no problems at all. My setup:

  • Bazel 4.1
  • Python 3.9.6
  • CUDA 11.4

I maintain a public Windows builder at https://github.com/cloudhan/jax-windows-builder via GitHub Actions.

Disclaimer: I am a Microsoft employee, but this builder is not an effort from Microsoft.

@cloudhan Amazing. Thank you for providing this!

I just wanted to chime in and say that while it's great that someone has created a GitHub repo with Windows wheels, we still constantly get error reports from users who can't install our Python library on Windows because there are no jaxlib wheels on PyPI.

I know it might not be a priority, but considering there's already an open-source fork that does it, maybe it's worth prioritising. I am noticing this issue across many of Google's open-source ML libraries: there are either no Windows wheels or no Arm64 wheels for Linux.

This means that all downstream libraries that depend on Google ML libraries essentially have no Windows or arm64 Linux (think Docker containers on Apple Silicon) support out of the box.

To work around this, we either compile our own wheels and feed them into the container build process, or use a bootstrapping tool that detects unsupported platforms and provides instructions on how to install these third-party forks.
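The bootstrapping approach described in that comment could be sketched roughly as below. This is a minimal illustration, not the commenter's actual tool: the `UNSUPPORTED_PLATFORMS` set and the helpers `needs_workaround` and `install_hint` are hypothetical, and the platform list only encodes the two cases mentioned here (Windows and arm64 Linux) at the time of the comment, not the current state of PyPI.

```python
import platform

# Hypothetical set of (system, machine) pairs assumed to lack official jaxlib
# wheels on PyPI, based only on the two cases named in this thread.
# Values are lowercased so comparisons ignore case differences.
UNSUPPORTED_PLATFORMS = {
    ("windows", "amd64"),
    ("linux", "aarch64"),  # e.g. Docker containers on Apple Silicon
    ("linux", "arm64"),
}


def needs_workaround(system: str, machine: str) -> bool:
    """Return True if (system, machine) is in the hypothetical unsupported set."""
    return (system.lower(), machine.lower()) in UNSUPPORTED_PLATFORMS


def install_hint(system: str, machine: str) -> str:
    """Return a human-readable install hint for the given platform."""
    if needs_workaround(system, machine):
        return (
            f"No official jaxlib wheel assumed for {system}/{machine}: "
            "use a pre-built wheel from your container build or a community builder."
        )
    return "Run: pip install jaxlib"


if __name__ == "__main__":
    # platform.system() gives e.g. "Windows" or "Linux";
    # platform.machine() gives e.g. "AMD64", "x86_64" or "aarch64".
    print(install_hint(platform.system(), platform.machine()))
```

Keying on `platform.system()` and `platform.machine()` keeps the check dependency-free, so it can run before jaxlib itself is installed.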

I'm not sure whether this helps simplify things, but there are a lot of CPU use cases for JAX. If CUDA support is the obstacle, what about shipping official CPU-only binaries to start with? Besides convenience of installation, having versioned binaries is essential for reproducibility.

At this point I cannot recommend JAX or teach with it, since many students (and I, for that matter) use Windows. WSL is not a solution, nor is Colab, when PyTorch and others work perfectly fine across various OSs.

Good news! We just shipped experimental jaxlib CPU-only wheels on pypi for JAX 0.4.13 for Python versions 3.9-3.11. It will probably take a few releases to mature fully and for us to shake out all the bugs, so please report problems.

We'd like to thank everyone from the community who has helped get us to this point! (e.g., @cloudhan, @mlxd, and I'm sure there are others I'm forgetting.)

JAX on Windows is a community-supported effort, but we now have regular CI testing and we're hoping to release CPU-only wheels regularly from now on. PRs to make it even better are very welcome!