`KeyError: <class 'bfloat16'>` after using `lax._convert_element_type`
romanngg opened this issue · comments
from jax._src.lax import lax
from jax import numpy as np
a = lax._convert_element_type(np.ones((), np.float32), new_dtype=np.bfloat16, weak_type=True)
np.where(np.ones((), a.dtype), a, np.ones((), a.dtype))
gives
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-33-cebcbf9d106e> in <module>()
4 a = lax._convert_element_type(np.ones((), np.float32), new_dtype=np.bfloat16, weak_type=True)
----> 5 np.where(np.ones((), a.dtype), a, np.ones((), a.dtype))
23 frames
UnfilteredStackTrace: KeyError: <class 'bfloat16'>
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
KeyError Traceback (most recent call last)
google3/third_party/py/jax/_src/numpy/util.py in _promote_dtypes(*args)
251 return args
252 else:
--> 253 to_dtype, weak_type = dtypes._lattice_result_type(*args)
254 to_dtype = dtypes.canonicalize_dtype(to_dtype)
255 return [lax_internal._convert_element_type(x, to_dtype, weak_type) for x in args]
KeyError: <class 'bfloat16'>
As suggested offline by @jakevdp and @hawkinsp, I tried replacing np.bfloat16 with np.dtype(np.bfloat16) and np.dtype('bfloat16'), but I get the same error.
Thanks for the report. I think this is related to misuse of non-public APIs. Although weak_type is an annotation on any dtype, in the normal course of things it's only generated as an attribute on float64, int64, and complex128 (or their X64-demoted equivalents). JAX type promotion doesn't recognize "weak bfloat16" and errors. If you can find a way to cause this error (i.e. generate a weakly-typed bfloat16 value) without using a private API, it would be a bug.
In any case, we could fix this by adding another check to the _jax_type utility function here:
Lines 245 to 248 in 7008b32
The problem is that jnp.dtype('bfloat16').type(0).item() does not return a Python float as it does for other floating point dtypes.
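To make the failure mode concrete, here is a toy model in plain Python. It is not JAX's code: FakeFloat32, FakeBfloat16, WEAK_NODES, weak_node, and weak_node_checked are all hypothetical names, and the "extra check" shown is only one possible shape such a check could take.

```python
# Toy model only: none of these names are JAX APIs.

class FakeFloat32:
    """Stand-in for a scalar type whose item() returns a Python float."""
    def __init__(self, value=0):
        self.value = float(value)

    def item(self):
        return self.value


class FakeBfloat16:
    """Stand-in for a scalar type whose item() returns itself, not a float."""
    def __init__(self, value=0):
        self.value = float(value)

    def item(self):
        return self  # mirrors the bfloat16 behavior described above


# Hypothetical table of weak nodes, keyed by Python scalar type.
WEAK_NODES = {int: "weak int", float: "weak float", complex: "weak complex"}


def weak_node(scalar_type):
    """Find the weak node by looking up type(x.item()); no fallback."""
    return WEAK_NODES[type(scalar_type(0).item())]


def weak_node_checked(scalar_type):
    """Same lookup, but fall back to the strong type when no weak node exists."""
    key = type(scalar_type(0).item())
    return WEAK_NODES.get(key, scalar_type)


print(weak_node(FakeFloat32))          # found: item() gave a Python float
try:
    weak_node(FakeBfloat16)
except KeyError as err:
    print("KeyError:", err)            # the key is the scalar class itself
print(weak_node_checked(FakeBfloat16)) # falls back instead of raising
```

In this sketch the KeyError carries the scalar class as its key, which matches the shape of the `KeyError: <class 'bfloat16'>` in the report.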
I think this should be fixed by the bfloat16 item() change that @pschuh is working on.
Looks like this is fixed as of jaxlib version 0.3.14