google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

TPUv3 CHECK failure at all_gather_emitter.cc:1731) hlo_instruction_->operand(0)->shape().element_type() == hlo_instruction_->shape().element_type()

zhangqiaorjc opened this issue

This issue tracks a Cloud TPU problem faced by one of our customers. The root cause is in XLA:TPU.

We saw the following error on TPUv3 with jax==0.3.6 and jaxlib==0.3.5 on a ResNet-18 model:

  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 288, in wrapped
    out = pjit_p.bind(*args_flat, **params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 657, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 631, in _pjit_call_impl
    compiled = _pjit_lower(
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 2346, in compile
    self._executable = MeshExecutable.from_hlo(
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 2456, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(backend, computation, compile_options)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 664, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 618, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: INTERNAL: RET_CHECK failure (platforms/xla/service/ba16c7433/lowering/[all_gather_emitter.cc:1731](http://all_gather_emitter.cc:1731/)) hlo_instruction_->operand(0)->shape().element_type() == hlo_instruction_->shape().element_type()
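
For context, the RET_CHECK fires during pjit's first-call compilation (the `dispatch.compile_or_get_cached` → `backend.compile` frames above), not during tracing. Below is a minimal sketch of that code path using the jax 0.3.x-era pjit API; it is not the customer's ResNet-18 reproducer, and the function, mesh axis name, and shapes are illustrative assumptions only.

```python
import numpy as np
import jax
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit, PartitionSpec

def f(x):
  # Stand-in computation; the customer's model was ResNet-18.
  return x @ x.T

# One-dimensional device mesh over all local TPU cores (e.g. 8 on a v3-8).
mesh = Mesh(np.asarray(jax.devices()), ('devices',))

sharded_f = pjit(
    f,
    in_axis_resources=PartitionSpec('devices', None),
    out_axis_resources=PartitionSpec('devices', None),
)

with mesh:
  # XLA:TPU compilation happens on this first call; with the bad libtpu
  # flags set, this is where the RET_CHECK failure surfaced.
  out = sharded_f(np.ones((8, 128), np.float32))
```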

We tracked this issue down to a number of (semi-private) performance flags that were being passed to libtpu; those flags do not apply to TPUv3 and do not work there. The fix is simply to stop passing them on TPUv3.
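
The issue does not record which flags were involved, so the snippet below only sketches the general mechanism, assuming the flags were handed to libtpu via the LIBTPU_INIT_ARGS environment variable (the usual route for passing flags to the TPU runtime). The flag name shown is hypothetical.

```python
import os

# libtpu reads LIBTPU_INIT_ARGS when jax initializes the TPU backend, so the
# environment must be cleaned up before jax first touches the TPU.
# Hypothetical flag name; the actual semi-private flags are not listed here.
# os.environ['LIBTPU_INIT_ARGS'] = '--some_semi_private_perf_flag=true'  # breaks TPUv3

# The fix: ensure no such flags are set when running on TPUv3.
os.environ.pop('LIBTPU_INIT_ARGS', None)

import jax  # TPU backend now initializes without the unsupported flags
print(jax.devices())
```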