google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

JAX 0.3.5 stalls on TPU Pods

wilson1yan opened this issue

This tutorial does not work with the latest JAX version, 0.3.5. Specifically, the code hangs whenever jax.device_count() or jax.local_device_count() is called.

The code prints the following and then stalls.

>>> import jax
>>> jax.local_device_count()
E0410 21:04:57.403257842   30715 f758.cc:310]                no server name supplied in dns URI
E0410 21:04:57.403295131   30715 f872.cc:77]                 channel stack builder failed: {"created":"@1649624697.403284903","description":"the target uri is not valid: dns:","file":"f814.cc","file_line":1090}

This issue does not happen if I install 0.3.4, or run the same code (with 0.3.5) on a non-pod instance like v2-8.

Can you please verify that you have the same libtpu-nightly and jaxlib versions installed on all VMs in the TPU pod?

Yes, all VMs have jaxlib==0.3.5 and libtpu-nightly==0.1.dev20220407.
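For anyone repeating this check, here is a minimal sketch of one way to list the relevant package versions on a VM. It reads installed package metadata instead of importing jax, so it should not trigger the hang described above. The helper name installed_versions is hypothetical, and the package names are simply the ones mentioned in this thread.

```python
from importlib import metadata

# Package names as mentioned in this thread; adjust if your install differs.
PACKAGES = ("jax", "jaxlib", "libtpu-nightly")


def installed_versions(packages=PACKAGES):
    """Return {package name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            # Reads the installed distribution's metadata; does not import it.
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```

Running the script on every VM in the pod and comparing the output line by line should show whether any worker has a mismatched version.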

No updates yet, but we can reproduce the problem and are looking into it.

We just released jax 0.3.6 with a new libtpu to fix this issue. Please upgrade to jax 0.3.6!
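To confirm that a VM actually picked up the upgrade, something like the following sketch could be used. It assumes only what the thread states, namely that 0.3.6 is the first release with the fix; parse_release and has_fix are hypothetical helper names, and the parser deliberately ignores pre-release suffixes such as "rc1".

```python
from importlib import metadata

# First release reported in this thread to contain the fix.
FIXED = (0, 3, 6)


def parse_release(version):
    """Parse the leading numeric parts of 'X.Y.Z...' into a tuple of ints."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def has_fix(version, fixed=FIXED):
    """True if the version's numeric release is at least the fixed release."""
    return parse_release(version) >= fixed


if __name__ == "__main__":
    try:
        # Metadata lookup only; jax itself is not imported.
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        print("jax is not installed")
    else:
        status = "has the fix" if has_fix(installed) else "should be upgraded"
        print(f"jax {installed} {status}")
```

Comparing integer tuples instead of raw strings keeps versions such as 0.3.10 ordered after 0.3.6.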