google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Adding the corresponding indexes in `jax.ops.segment_max`?

felixchalumeau opened this issue

Hi!

Would it be possible to add an option to jax.ops.segment_max that also returns the indices of the maximal values? I am facing a situation where I need them, and they are not easy to reconstruct from the output of jax.ops.segment_max (unless I missed an easy way).

I guess that, to stay compatible with jit, it would have to return a single, arbitrarily chosen index when several elements in the same segment tie for the maximum.

Has this already been discussed? Is there a technical limitation that prevents it? Happy to get any insight, and happy to provide more motivation if necessary!

Here is a short example of the requested feature:

Current behavior

>>> data = jnp.arange(6) * 2
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_max(data, segment_ids)
DeviceArray([2, 6, 10], dtype=int32)

Desired behavior

>>> data = jnp.arange(6) * 2
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_max(data, segment_ids, retrieve_indexes=True)
(DeviceArray([2, 6, 10], dtype=int32), DeviceArray([1, 3, 5], dtype=int32))

(Also, this option could be added to the other functions of the jax.ops package for consistency.)

Thanks!

If we were to add such an API, I'd suggest that it should be called jax.ops.segment_argmax rather than being an extension of the current segment_max.

The bigger issue, though, is that I don't believe this API is easy to implement with the current lax.scatter primitive. segment_max is essentially just a scatter-max call, and it's not obvious to me that we can implement scatter-argmax without extending XLA to support a "variadic" scatter that lets us look at both keys and values in the same scatter operation. That seems like a reasonably large undertaking.
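To make the "segment_max is essentially a scatter-max" point concrete, here is a minimal sketch for 1-D integer or floating-point data. The helper name segment_max_via_scatter is hypothetical, and this is an illustration of the idea rather than the actual jax.ops implementation:

```python
import jax
import jax.numpy as jnp


def segment_max_via_scatter(data, segment_ids, num_segments):
    """Hypothetical helper: segment max written as an explicit scatter-max."""
    # Start every segment at the identity element for max.
    if jnp.issubdtype(data.dtype, jnp.floating):
        init = -jnp.inf
    else:
        init = jnp.iinfo(data.dtype).min
    out = jnp.full((num_segments,), init, dtype=data.dtype)
    # Scatter-max: each element updates the slot of its own segment.
    return out.at[segment_ids].max(data)


data = jnp.arange(6) * 2
segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
result = segment_max_via_scatter(data, segment_ids, 3)
reference = jax.ops.segment_max(data, segment_ids, num_segments=3)
```

The scatter only ever sees the values being combined, not the positions they came from, which is why an argmax does not fall out of the same single operation.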

Was this issue ever resolved, or was a workaround found? I am also finding that segment_max is exactly what I need, except that I also need the indices of the maxima. Thanks!
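One possible workaround, sketched under assumptions rather than offered as an official API: use two segment reductions instead of a single variadic scatter. First compute the per-segment max, then mark the elements that equal their segment's max, and finally take the segment-wise minimum over the positions of those elements. The name segment_argmax below is hypothetical (it is not part of jax.ops), the sketch assumes 1-D data without NaNs, and ties resolve to the first occurrence:

```python
import jax
import jax.numpy as jnp


def segment_argmax(data, segment_ids, num_segments):
    """Hypothetical helper: per-segment max and the index of its first occurrence."""
    seg_max = jax.ops.segment_max(data, segment_ids, num_segments=num_segments)
    # Broadcast each segment's max back to its elements and mark the winners.
    is_max = data == seg_max[segment_ids]
    # Non-winners get a sentinel that is larger than any valid index.
    n = data.shape[0]
    candidates = jnp.where(is_max, jnp.arange(n), n)
    # The smallest winning position in each segment is the argmax.
    seg_argmax = jax.ops.segment_min(candidates, segment_ids, num_segments=num_segments)
    return seg_max, seg_argmax


data = jnp.arange(6) * 2
segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
values, indices = segment_argmax(data, segment_ids, 3)

# A case with ties: the first maximal element of each segment is reported.
tied = jnp.array([5, 5, 1, 7, 7, 7])
tied_ids = jnp.array([0, 0, 1, 1, 1, 1])
tied_values, tied_indices = segment_argmax(tied, tied_ids, 2)
```

To use this under jit, num_segments needs to be static, for example jax.jit(segment_argmax, static_argnums=2). Segments that receive no elements end up with whatever fill value segment_max and segment_min produce for empty segments, so their index entry is not meaningful.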