google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

NonConcreteBooleanIndexError on call to jnp.unique

fonnesbeck opened this issue

I'm getting the following error:

/usr/local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _expand_bool_indices(idx, shape)
   3875       if not type(abstract_i) is ConcreteArray:
   3876         # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
-> 3877         raise errors.NonConcreteBooleanIndexError(abstract_i)
   3878       elif _ndim(i) == 0:
   3879         raise TypeError("JAX arrays do not support boolean scalar indices")

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[4897])

as the result of a call to jnp.unique:

/var/repos/pie_live/research/projections/pitchers/stuff_proj.py
    202     """
    203 
--> 204     ages_pred = jnp.unique(age_idx)

where age_idx is an int DeviceArray:

DeviceArray([24, 25, 26, ...,  6, 14, 10], dtype=int32)

It's not immediately clear why this error would be raised in this context. Any ideas appreciated.

JAX's JIT currently only supports static shapes, while the return value of unique has a shape that depends on the input values.
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.unique.html

Because the size of the output of unique is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument which must be specified statically for jnp.unique to be used within some of JAX’s transformations.
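A small sketch of both points above: eagerly, two inputs of the same shape give `jnp.unique` outputs of different shapes, and under `jit` a static `size` (here `size=5`, `fill_value=0`, both chosen only for illustration) fixes the output shape so tracing succeeds.

```python
import jax
import jax.numpy as jnp

# Eagerly (no jit), the output shape of jnp.unique depends on the data:
# two inputs with the same shape give outputs with different shapes.
a = jnp.array([1, 1, 1, 1])
b = jnp.array([1, 2, 3, 4])
print(jnp.unique(a).shape)  # (1,)
print(jnp.unique(b).shape)  # (4,)

# Under jit, a static `size` fixes the output shape. Unused slots are
# filled with `fill_value`; unique values beyond `size` would be dropped.
unique5 = jax.jit(lambda v: jnp.unique(v, size=5, fill_value=0))
x = jnp.array([24, 25, 26, 25, 24, 6])
print(unique5(x))  # values: 6, 24, 25, 26, 0
```

`size` has to be a plain Python int at trace time; passing a traced value there would reintroduce the same problem.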

This is working as intended, but we should definitely improve this error to more directly point to the documentation of the issue. Can you share the code that led to a NonConcreteBooleanIndexError? When I try this, I find a ConcretizationTypeError:

from jax import jit
import jax.numpy as jnp
jit(jnp.unique)(jnp.arange(5))
# ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=0/1)>
# The error arose for the first argument of jnp.unique()
# While tracing the function unique at /usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:4215 for jit, this concrete value was not available in Python because it depends on the value of the argument 'ar'.

@jakevdp
mask can be abstract even when ar is concrete:

def _unique(ar, axis, return_index=False, return_inverse=False, return_counts=False,
            size=None, fill_value=None, return_true_size=False):
  """
  Find the unique elements of an array along a particular axis.
  """
  if ar.shape[axis] == 0 and size and fill_value is None:
    raise ValueError(
      "jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified")
  aux, mask, perm = _unique_sorted_mask(ar, axis)
  ind = mask if size is None else nonzero(mask, size=size)[0]
  result = aux[ind] if aux.size else aux
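The `aux[ind]` line above is where a boolean mask is used as an index. The same error can be reproduced without `jnp.unique`; this is a minimal sketch with hypothetical function names, showing that a mask computed inside `jit` is abstract and cannot index an array, while a shape-preserving `jnp.where` works.

```python
import jax
import jax.numpy as jnp

@jax.jit
def select_boolean(v):
    mask = v > 1    # abstract (traced) bool array during tracing
    return v[mask]  # output length depends on the data, so tracing fails

@jax.jit
def select_where(v):
    # Keep the shape static: replace entries that fail the mask with 0.
    return jnp.where(v > 1, v, 0)

v = jnp.array([3, 1, 3, 2])
try:
    select_boolean(v)
except jax.errors.NonConcreteBooleanIndexError as e:
    print("raised:", type(e).__name__)
print(select_where(v))  # values: 3, 0, 3, 2
```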

The code is part of a much bigger model that I can't share, but I will try to create a smaller, reproducible example.

Thanks - don't worry about the reproduction, I think I understand where it's coming from now, and I can work on improving the error message. The root cause is attempting to JIT-compile jnp.unique, which returns an array with data-dependent shape. If you want to use this within JIT, you'll need to statically specify the size argument to jnp.unique.
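When the number of unique values is not known ahead of time, one option (a sketch, not the only approach) is to use the input length as an upper bound for `size`, since shapes are static under `jit`, and return the count of distinct values alongside the padded result. The function name `padded_unique`, the `-1` fill value, and the example data are assumptions for illustration; `-1` must be a value that never appears in the input.

```python
import jax
import jax.numpy as jnp

@jax.jit
def padded_unique(x):
    """Hypothetical helper: unique values of a non-empty 1-D array, padded to len(x).

    Returns the padded unique values and the number of distinct values.
    """
    # x.shape is static during tracing, so len(x) is a valid `size`.
    n = x.shape[0]
    # Assumes -1 never occurs in x, so padding is distinguishable.
    uniq = jnp.unique(x, size=n, fill_value=-1)
    # Count distinct values using only fixed-shape operations.
    s = jnp.sort(x)
    n_distinct = 1 + jnp.sum(s[1:] != s[:-1])
    return uniq, n_distinct

age_idx = jnp.array([24, 25, 26, 25, 24])  # hypothetical example data
uniq, n_distinct = padded_unique(age_idx)
# uniq holds 24, 25, 26, -1, -1 and n_distinct is 3
```

Slicing off the padding with `uniq[:int(n_distinct)]` has to happen outside `jit`, because inside it `n_distinct` is a traced value.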