NonConcreteBooleanIndexError on call to jnp.unique
fonnesbeck opened this issue
I'm getting the following error:
```
/usr/local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _expand_bool_indices(idx, shape)
   3875 if not type(abstract_i) is ConcreteArray:
   3876 # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
-> 3877 raise errors.NonConcreteBooleanIndexError(abstract_i)
   3878 elif _ndim(i) == 0:
   3879 raise TypeError("JAX arrays do not support boolean scalar indices")

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[4897])
```
as the result of a call to `jnp.unique`:

```
    202 """
    203
--> 204 ages_pred = jnp.unique(age_idx)
```

where `age_idx` is an int `DeviceArray`:

```
DeviceArray([24, 25, 26, ..., 6, 14, 10], dtype=int32)
```
It's not immediately clear why this error would be raised in this context. Any ideas appreciated.
JAX's JIT currently only supports static shapes, while the return value of `unique` has a shape that depends on the input values.
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.unique.html
> Because the size of the output of `unique` is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional `size` argument which must be specified statically for `jnp.unique` to be used within some of JAX's transformations.
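A minimal sketch of the `size` route, assuming a JAX version whose `jnp.unique` accepts `size` and `fill_value`. The function name `unique_fixed`, the parameter `max_unique`, and the sample ages are hypothetical. Because `size` must be a static Python int, it is marked with `static_argnames` so the output shape is known at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: max_unique is static, so the output shape is fixed
# when the function is traced and jnp.unique can run under jit.
@partial(jax.jit, static_argnames="max_unique")
def unique_fixed(x, max_unique):
    # Unused slots are padded with fill_value (-1 here).
    return jnp.unique(x, size=max_unique, fill_value=-1)


ages = jnp.array([24, 25, 26, 24, 6, 14, 10, 6], dtype=jnp.int32)
result = unique_fixed(ages, max_unique=8)
print(result)
```

Slots beyond the true number of unique values hold `fill_value`, so downstream code has to mask them out; choosing `size` smaller than the number of unique values truncates the result instead.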
This is working as intended, but we should definitely improve this error to point more directly to the documentation of the issue. Can you share the code that led to a `NonConcreteBooleanIndexError`? When I try this, I find a `ConcretizationTypeError`:
```python
from jax import jit
import jax.numpy as jnp
jit(jnp.unique)(jnp.arange(5))
# ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=0/1)>
# The error arose for the first argument of jnp.unique()
# While tracing the function unique at /usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:4215 for jit, this concrete value was not available in Python because it depends on the value of the argument 'ar'.
```
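If the input array is already available before entering the jitted function (an assumption about the reporter's model, which isn't shown), another option is to call `jnp.unique` eagerly, where the values are concrete, and pass its fixed-shape result into the jitted code. In this sketch `count_per_age` and the sample `age_idx` values are hypothetical.

```python
import jax
import jax.numpy as jnp

age_idx = jnp.array([24, 25, 26, 24, 6, 14, 10, 6], dtype=jnp.int32)

# Called outside any jit trace: age_idx is concrete here, so the
# data-dependent output shape of jnp.unique is not a problem.
ages_pred = jnp.unique(age_idx)


@jax.jit
def count_per_age(idx, ages):
    # Hypothetical downstream computation: by the time this is traced,
    # `ages` already has a fixed shape, so nothing here is data-dependent.
    return (idx[:, None] == ages[None, :]).sum(axis=0)


counts = count_per_age(age_idx, ages_pred)
print(ages_pred, counts)
```

The jitted function is retraced whenever `ages_pred` changes length, so this fits best when the set of unique values is stable across calls.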
The code is part of a much bigger model that I can't share, but I will try to create a smaller, reproducible example.
Thanks - don't worry about the reproduction, I think I understand where it's coming from now, and I can work on improving the error message. The root cause is attempting to JIT-compile `jnp.unique`, which returns an array with data-dependent shape. If you want to use this within JIT, you'll need to statically specify the `size` argument to `jnp.unique`.
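The traceback points at `_expand_bool_indices`, i.e. indexing an array with a boolean mask whose values are only known at trace time. How the reporter's code reached that path is not known; the sketch below only shows the general pattern and a fixed-shape alternative using `jnp.where`, assuming a JAX version that exposes the error as `jax.errors.NonConcreteBooleanIndexError`. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def positive_sum_indexed(x):
    # x[x > 0] has a data-dependent shape, so under jit this boolean
    # index is a tracer and raises NonConcreteBooleanIndexError.
    return x[x > 0].sum()


@jax.jit
def positive_sum_masked(x):
    # jnp.where keeps x's shape and zeroes out the masked entries instead.
    return jnp.where(x > 0, x, 0).sum()


x = jnp.array([-1.0, 2.0, -3.0, 4.0])
masked_total = positive_sum_masked(x)

bool_index_error_raised = False
try:
    positive_sum_indexed(x)
except jax.errors.NonConcreteBooleanIndexError:
    bool_index_error_raised = True

print(masked_total, bool_index_error_raised)
```

Replacing the variable-length selection with a masked, fixed-shape computation is the same trade-off as passing `size` to `jnp.unique`: the output shape no longer depends on the data.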