JAX 0.3.5 stalls on TPU Pods
wilson1yan opened this issue
This tutorial does not work when using the latest jax version, 0.3.5. Specifically, code will hang whenever jax.device_count() or jax.local_device_count() is called.
The code prints the following and then stalls.
>>> import jax
>>> jax.local_device_count()
E0410 21:04:57.403257842 30715 f758.cc:310] no server name supplied in dns URI
E0410 21:04:57.403295131 30715 f872.cc:77] channel stack builder failed: {"created":"@1649624697.403284903","description":"the target uri is not valid: dns:","file":"f814.cc","file_line":1090}
This issue does not happen if I install 0.3.4, or run the same code (with 0.3.5) on a non-pod instance like v2-8.
Can you please verify you have the same libtpu_nightly and jaxlib versions installed on all VMs in the TPU pod?
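One way to check this without triggering the hang is to read the installed package metadata instead of importing jax. The following is a minimal sketch, assuming Python 3.8+ and that the distributions are named jax, jaxlib, and libtpu-nightly; the helper name installed_versions is hypothetical and not part of JAX.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "libtpu-nightly")):
    """Return {package: version string, or None if not installed}.

    Reads pip metadata only; it does not import jax, so it should not
    touch the TPU runtime. The default package names are assumptions.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Run this on every VM in the pod and compare the printed lines.
    for name, version in installed_versions().items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```

Running the script on each VM and diffing the output should show whether any host has a mismatched version.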
Yes, all VMs have jaxlib==0.3.5
and libtpu-nightly==0.1.dev20220407
No updates yet, but we can reproduce the problem and are looking into it.
We just released jax 0.3.6 with a new libtpu to fix this issue. Please upgrade to jax 0.3.6!
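After upgrading, a quick sanity check is to confirm that the installed jax version is at least 0.3.6 before calling jax.device_count() again. This is only a sketch: the helpers parse_release and has_fix are hypothetical, and the parsing assumes plain numeric X.Y.Z version strings.

```python
from importlib import metadata

# First jax release reported in this thread to include the fix.
FIXED_IN = (0, 3, 6)


def parse_release(version):
    """Parse the leading numeric 'X.Y.Z' part of a version string into ints."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits or 0))
    return tuple(parts)


def has_fix(version):
    """True if the given version string is 0.3.6 or newer."""
    return parse_release(version) >= FIXED_IN


if __name__ == "__main__":
    try:
        installed = metadata.version("jax")
        print(f"jax {installed}: fix included = {has_fix(installed)}")
    except metadata.PackageNotFoundError:
        print("jax is not installed")
```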