google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home page: http://jax.readthedocs.io/

`KeyError: <class 'bfloat16'>` after using `lax._convert_element_type`

romanngg opened this issue:

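from jax._src.lax import lax  # private module, not public API
from jax import numpy as np

# Create a weakly-typed bfloat16 scalar through a private helper.
a = lax._convert_element_type(np.ones((), np.float32), new_dtype=np.bfloat16, weak_type=True)
# jnp.where promotes the dtypes of its arguments, which raises here.
np.where(np.ones((), a.dtype), a, np.ones((), a.dtype))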

gives

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-33-cebcbf9d106e> in <module>()
      4 a = lax._convert_element_type(np.ones((), np.float32), new_dtype=np.bfloat16, weak_type=True)
----> 5 np.where(np.ones((), a.dtype), a, np.ones((), a.dtype))

UnfilteredStackTrace: KeyError: <class 'bfloat16'>

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
google3/third_party/py/jax/_src/numpy/util.py in _promote_dtypes(*args)
    251     return args
    252   else:
--> 253     to_dtype, weak_type = dtypes._lattice_result_type(*args)
    254     to_dtype = dtypes.canonicalize_dtype(to_dtype)
    255     return [lax_internal._convert_element_type(x, to_dtype, weak_type) for x in args]

KeyError: <class 'bfloat16'>

As suggested offline by @jakevdp and @hawkinsp, I tried replacing np.bfloat16 with np.dtype(np.bfloat16) and with np.dtype('bfloat16'), but I get the same error.

Thanks for the report. I think this is related to misuse of non-public APIs. Although weak_type can be attached to any dtype, in the normal course of things it is only produced for float64, int64, and complex128 (or their X64-demoted equivalents), i.e. the dtypes that Python scalars map to. JAX type promotion doesn't recognize "weak bfloat16" and errors. If you can find a way to trigger this error (i.e. generate a weakly-typed bfloat16 value) without using a private API, that would be a bug.
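
The difference between weak and strong types can be observed with public APIs alone. A minimal sketch, assuming a recent JAX release with default settings: a Python scalar such as 2.0 is weakly typed, so it defers to the dtype of the bfloat16 array it is combined with, whereas a float32 array is strongly typed and wins promotion against bfloat16. Neither operation produces a weakly-typed bfloat16 value, which is consistent with the reply above.

```python
import jax.numpy as jnp

# A strongly typed bfloat16 array.
x = jnp.ones(3, dtype=jnp.bfloat16)

# Python scalars are weakly typed: the result keeps the array's dtype.
y = x * 2.0
print(y.dtype)  # bfloat16

# A float32 array is strongly typed: bfloat16 promotes to float32.
z = x * jnp.ones(3, dtype=jnp.float32)
print(z.dtype)  # float32
```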

In any case, we could fix this by adding another check to the _jax_type utility function here:

jax/_src/dtypes.py in the jax repository, lines 245 to 248 at commit 7008b32:

def _jax_type(dtype, weak_type):
  """Return the jax type for a dtype and weak type."""
  return type(dtype.type(0).item()) if (weak_type and dtype != bool) else dtype
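
For the standard NumPy dtypes, the weak-type branch of `_jax_type` maps each dtype to a Python builtin type by building a zero scalar and calling `.item()` on it. The standalone sketch below reproduces only that expression; the helper name `python_scalar_type` is hypothetical. The `KeyError: <class 'bfloat16'>` in the traceback is consistent with this expression yielding a bfloat16 scalar class instead of `float` on the affected jaxlib version.

```python
import numpy as np

def python_scalar_type(dtype):
  """Hypothetical helper: the Python type _jax_type returns for a weak dtype."""
  return type(np.dtype(dtype).type(0).item())

# float16 -> float, float32 -> float, int32 -> int, complex64 -> complex
for dt in (np.float16, np.float32, np.int32, np.complex64):
  print(np.dtype(dt).name, "->", python_scalar_type(dt).__name__)
```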

The problem is that jnp.dtype('bfloat16').type(0).item() does not return a Python float, as it does for the other floating-point dtypes.
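
One way the extra check in `_jax_type` could look is sketched below. This is an illustration under stated assumptions, not the change that was actually made: the function name `jax_type_sketch` is hypothetical, and it assumes `jnp.issubdtype` classifies bfloat16 as a floating dtype (plain `np.issubdtype` may not, because bfloat16 is not a built-in NumPy type). Instead of relying on `.item()`, it picks the Python scalar type from the dtype's category.

```python
import numpy as np
import jax.numpy as jnp

def jax_type_sketch(dtype, weak_type):
  """Hypothetical variant of _jax_type that does not depend on .item()."""
  dtype = np.dtype(dtype)
  if not weak_type or dtype == bool:
    return dtype
  # jnp.issubdtype understands JAX's extra float types such as bfloat16.
  if jnp.issubdtype(dtype, jnp.complexfloating):
    return complex
  if jnp.issubdtype(dtype, jnp.floating):
    return float
  if jnp.issubdtype(dtype, jnp.integer):
    return int
  # Anything else is left unchanged.
  return dtype

print(jax_type_sketch(jnp.bfloat16, True))   # <class 'float'>
print(jax_type_sketch(jnp.bfloat16, False))  # bfloat16
```

The category check keeps the result independent of how each scalar type implements `.item()`.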

I think this should be fixed by the bfloat16 item() change that @pschuh is working on.

Looks like this is fixed as of jaxlib version 0.3.14.