Low-level error when calling `lax.convert_element_type_p` with uncanonicalized `new_dtype`
romanngg opened this issue · comments
Tiny repro: https://colab.research.google.com/gist/icml2022anon/a368de2bbbd7ed6a6b4ea12a6966683c/jax_type_error.ipynb
from jax import numpy as np, lax, jacfwd
a = np.ones((1,), dtype=np.float16)
f = lambda x: lax.convert_element_type_p.bind(x, new_dtype=np.float64, weak_type=False)
jacfwd(f)(a)
causes
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-1-b663e19f60d4> in <module>()
5
----> 6 jacfwd(f)(a)
27 frames
UnfilteredStackTrace: RuntimeError: UNKNOWN: <unknown>:0: error: type of return operand 0 ('tensor<1x1xf64>') doesn't match function result type ('tensor<1x1xf32>') in function @main
<unknown>:0: note: see current operation: "func.return"(%0) : (tensor<1x1xf64>) -> ()
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in transpose(a, axes)
494 _check_arraylike("transpose", a)
495 axes = np.arange(ndim(a))[::-1] if axes is None else axes
--> 496 return lax.transpose(a, axes)
497
498
RuntimeError: UNKNOWN: <unknown>:0: error: type of return operand 0 ('tensor<1x1xf64>') doesn't match function result type ('tensor<1x1xf32>') in function @main
<unknown>:0: note: see current operation: "func.return"(%0) : (tensor<1x1xf64>) -> ()
As @mattjj told me in the chat, this can be fixed via
- call lax.convert_element_type (or even lax._convert_element_type if you really need to control weak_type) rather than bind
- pass new_dtype=jax.dtypes.canonicalize_dtype(np.float64)
- enable the x64 flag
but is still considered a bug.
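A minimal sketch of the first two workarounds combined, assuming the default configuration (x64 flag off). It goes through the public `lax.convert_element_type` wrapper and canonicalizes the requested dtype up front; the names `target`, `f`, and `jac` are only illustrative. It deliberately avoids `bind`, since the parameters of the internal primitive may differ between JAX versions.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import jacfwd, lax

a = jnp.ones((1,), dtype=jnp.float16)

# What float64 canonicalizes to under the current x64 setting
# (float32 when the x64 flag is off, which is the default).
target = jax.dtypes.canonicalize_dtype(np.float64)

# Use the public wrapper rather than convert_element_type_p.bind.
f = lambda x: lax.convert_element_type(x, target)

jac = jacfwd(f)(a)
print(jac.shape, jac.dtype)
```

With the dtype canonicalized before it reaches the primitive, the abstract value and the compiled result should agree, so the Jacobian computes without the MLIR type mismatch.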
Thanks for the report! Here's a slightly simpler repro:
import jax.numpy as jnp
from jax import lax
y = lax.convert_element_type_p.bind(jnp.ones((2, 3)), new_dtype=jnp.float64, weak_type=False)
z = y.T
The issue here is that calling the primitive directly sidesteps the X64 flag, so that y has type float64. Then the dtype rule for lax.transpose will return the wrong type, meaning the aval and the buffer do not match.
You can achieve the same thing this way:
import jax.experimental
with jax.experimental.enable_x64():
  y = jnp.ones((2, 3), dtype='float64')
z = y.T  # runs outside the context, i.e. on a float64 input with X64 disabled
Essentially any standard lax primitive that uses _input_dtype for its dtype rule is going to have this same problem when faced with a 64-bit input while the X64 flag is set to False.
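One way to spot such inputs before they reach a lax primitive is to compare an array's dtype with its canonicalized form. The helper below is hypothetical (not part of JAX) and is only a sketch built on `jax.dtypes.canonicalize_dtype`:

```python
import numpy as np
import jax
import jax.numpy as jnp

def has_canonical_dtype(x):
    """Hypothetical helper: True if x's dtype is unchanged by canonicalization."""
    return x.dtype == jax.dtypes.canonicalize_dtype(x.dtype)

x32 = jnp.ones((2, 3), dtype=jnp.float32)
# A NumPy array keeps float64 regardless of the X64 flag.
x64 = np.ones((2, 3), dtype=np.float64)

print(has_canonical_dtype(x32))
print(has_canonical_dtype(x64))
```

With the X64 flag off, the second check should report a non-canonical dtype, which is the situation the dtype rules above are not prepared for.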
Incidentally, it's exactly this kind of difficulty with the x64 context manager that's prevented us from making it a supported API, and is further convincing us that the whole X64 mode idea should be removed.
This is similar in spirit to the errors reported in #5982.
You've just found a clever way to enable X64 mode without the context manager 😁
Thanks for opening this issue, Roman. After thinking about this, I think we should consider it intended behavior for internal APIs. It is a sharp edge related to x64, but fortunately not one in public APIs!