google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Low-level error when calling `lax.convert_element_type_p` with uncanonicalized `new_dtype`

romanngg opened this issue · comments

Tiny repro: https://colab.research.google.com/gist/icml2022anon/a368de2bbbd7ed6a6b4ea12a6966683c/jax_type_error.ipynb

from jax import numpy as np, lax, jacfwd

a = np.ones((1,), dtype=np.float16)
f = lambda x: lax.convert_element_type_p.bind(x, new_dtype=np.float64, weak_type=False)

jacfwd(f)(a)

causes

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-1-b663e19f60d4> in <module>()
      5 
----> 6 jacfwd(f)(a)

27 frames
UnfilteredStackTrace: RuntimeError: UNKNOWN: <unknown>:0: error: type of return operand 0 ('tensor<1x1xf64>') doesn't match function result type ('tensor<1x1xf32>') in function @main
<unknown>:0: note: see current operation: "func.return"(%0) : (tensor<1x1xf64>) -> ()

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in transpose(a, axes)
    494   _check_arraylike("transpose", a)
    495   axes = np.arange(ndim(a))[::-1] if axes is None else axes
--> 496   return lax.transpose(a, axes)
    497 
    498 

RuntimeError: UNKNOWN: <unknown>:0: error: type of return operand 0 ('tensor<1x1xf64>') doesn't match function result type ('tensor<1x1xf32>') in function @main
<unknown>:0: note: see current operation: "func.return"(%0) : (tensor<1x1xf64>) -> ()

As @mattjj told me in the chat, this can be worked around by:

  1. calling lax.convert_element_type (or even lax._convert_element_type if you really need to control weak_type) rather than bind,
  2. passing new_dtype=jax.dtypes.canonicalize_dtype(np.float64), or
  3. enabling the x64 flag,

but it is still considered a bug.
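A sketch of workarounds 1 and 2, reusing the shape and dtypes from the repro above (the comments describe the default jax_enable_x64=False configuration):

```python
import numpy as np
import jax
from jax import lax, jacfwd
from jax import numpy as jnp

a = jnp.ones((1,), dtype=jnp.float16)

# Workaround 1: the public wrapper canonicalizes new_dtype for you,
# so float64 becomes float32 under the default config.
f1 = lambda x: lax.convert_element_type(x, np.float64)
j1 = jacfwd(f1)(a)

# Workaround 2: canonicalize the dtype yourself before binding the primitive.
canonical = jax.dtypes.canonicalize_dtype(np.float64)
f2 = lambda x: lax.convert_element_type_p.bind(x, new_dtype=canonical, weak_type=False)
j2 = jacfwd(f2)(a)
```

Both variants produce a (1, 1) Jacobian with the canonicalized dtype instead of crashing at lowering time.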

Thanks for the report! Here's a slightly simpler repro:

import jax.numpy as jnp
from jax import lax

y = lax.convert_element_type_p.bind(jnp.ones((2, 3)), new_dtype=jnp.float64, weak_type=False)
z = y.T

The issue here is that calling the primitive directly sidesteps the X64 flag, so y ends up with dtype float64. The dtype rule for lax.transpose then returns the wrong type, meaning the aval and the underlying buffer no longer match.
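For context, canonicalization is the step the raw bind call skips — under the default configuration, jax.dtypes.canonicalize_dtype maps 64-bit dtypes to their 32-bit counterparts (a quick check, assuming jax_enable_x64 is left at its default of False):

```python
import numpy as np
import jax

# With the default config (jax_enable_x64=False), 64-bit dtypes are
# canonicalized down to 32 bits before they reach any primitive.
print(jax.dtypes.canonicalize_dtype(np.float64))  # float32
print(jax.dtypes.canonicalize_dtype(np.int64))    # int32
```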

You can achieve the same thing this way:

import jax.experimental
import jax.numpy as jnp

with jax.experimental.enable_x64():
  y = jnp.ones((2, 3), dtype='float64')
z = y.T

Essentially, any standard lax primitive that uses _input_dtype for its dtype rule will hit this same problem when given a 64-bit input while the X64 flag is set to False.
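One way to see this dtype rule in action without triggering the crash is jax.eval_shape, which runs only the abstract shape/dtype rules — a small sketch (nothing here is specific to the bug; it just shows transpose propagating its input dtype unchanged):

```python
import jax
import jax.numpy as jnp

# transpose's output dtype is simply its input dtype (the _input_dtype rule),
# so a non-canonical input dtype would propagate straight through.
spec = jax.eval_shape(jnp.transpose, jax.ShapeDtypeStruct((2, 3), jnp.float32))
print(spec.shape, spec.dtype)  # (3, 2) float32
```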

Incidentally, it's exactly this kind of difficulty with the x64 context manager that's prevented us from making it a supported API, and is further convincing us that the whole X64 mode idea should be removed.

This is similar in spirit to the errors reported in #5982

You've just found a clever way to enable X64 mode without the context manager 😁

Thanks for opening this issue, Roman. After thinking about this, I think we should consider it intended behavior for internal APIs. It is a sharp edge related to x64, but fortunately not one in public APIs!