TPUv3 CHECK failure at all_gather_emitter.cc:1731: hlo_instruction_->operand(0)->shape().element_type() == hlo_instruction_->shape().element_type()
zhangqiaorjc opened this issue
Qiao Zhang commented
This tracks a Cloud TPU issue encountered by one of our customers. The root cause is in XLA:TPU.
We saw the following error on TPUv3 with jax==0.3.6 and jaxlib==0.3.5 while running a ResNet-18 model:
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 288, in wrapped
    out = pjit_p.bind(*args_flat, **params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 657, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 631, in _pjit_call_impl
    compiled = _pjit_lower(
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 2346, in compile
    self._executable = MeshExecutable.from_hlo(
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 2456, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(backend, computation, compile_options)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 664, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 618, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: INTERNAL: RET_CHECK failure (platforms/xla/service/ba16c7433/lowering/all_gather_emitter.cc:1731) hlo_instruction_->operand(0)->shape().element_type() == hlo_instruction_->shape().element_type()
Peter Hawkins commented
We tracked this issue down to a number of (semi-private) performance flags that were being passed to libtpu. Those flags do not apply to TPUv3 and do not work there. The fix is to stop passing them when running on TPUv3.
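The thread does not say which flags were involved or how they reached libtpu. A common way to hand flags to libtpu is the `LIBTPU_INIT_ARGS` environment variable (a space-separated list of `--name=value` flags). Assuming that mechanism, here is a minimal sketch of "don't pass them on TPUv3". The flag names, the `build_libtpu_init_args` helper, and the `tpu_generation` argument are all hypothetical placeholders. The sketch also assumes the caller already knows which TPU generation it targets.

```python
import os

# HYPOTHETICAL placeholder names: the issue does not name the actual flags.
# Replace these with the flags you actually pass that are not meant for TPUv3.
HYPOTHETICAL_V4_ONLY_FLAGS = {
    "--hypothetical_v4_only_flag_a",
    "--hypothetical_v4_only_flag_b",
}


def build_libtpu_init_args(flags, tpu_generation,
                           unsupported_on_v3=HYPOTHETICAL_V4_ONLY_FLAGS):
    """Join `flags` into one string, dropping flags listed as unsupported on v3.

    `flags` is a list like ["--name=value", ...]. `tpu_generation` is a string
    such as "v3" or "v4" that the caller supplies (assumed known up front).
    """
    kept = []
    for flag in flags:
        name = flag.split("=", 1)[0]  # compare on the flag name, ignore its value
        if tpu_generation == "v3" and name in unsupported_on_v3:
            continue  # skip flags that do not apply to TPUv3
        kept.append(flag)
    return " ".join(kept)


if __name__ == "__main__":
    # "--hypothetical_other_flag" is also a placeholder.
    requested = ["--hypothetical_v4_only_flag_a=true", "--hypothetical_other_flag=1"]
    # Set this before the TPU backend is initialized (safest: before importing jax).
    os.environ["LIBTPU_INIT_ARGS"] = build_libtpu_init_args(requested, "v3")
    print(os.environ["LIBTPU_INIT_ARGS"])
```

Filtering by flag name rather than by the whole `--name=value` string keeps the check independent of the value each flag is set to.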