Adding the corresponding indexes in `jax.ops.segment_max`?
felixchalumeau opened this issue
Hi!
Would it be possible to add an option to `jax.ops.segment_max` that also returns the indices of the maximal values? I am facing a situation where I need them, and it's actually not that easy to reconstruct them from the output of `jax.ops.segment_max` (unless I missed an easy way).
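One partial route, sketched here under the assumption that a per-element boolean mask is useful on its own: gather each segment's maximum back to the element positions with `max_vals[segment_ids]` and compare it against `data`. This marks every position that attains its segment's maximum, including all tied positions, but it does not yet pick a single index per segment. (Calling `jnp.nonzero` on the mask would give indices, but its output shape depends on the data, so it is not usable under `jit` without a fixed `size`.)

```python
import jax.numpy as jnp
from jax.ops import segment_max

data = jnp.arange(6) * 2
segment_ids = jnp.array([0, 0, 1, 1, 2, 2])

# Per-segment maxima, shape (num_segments,).
max_vals = segment_max(data, segment_ids, num_segments=3)

# Gather each element's own segment maximum back to element positions and
# compare: True wherever an element attains its segment's maximum.
is_max = data == max_vals[segment_ids]
```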
I guess that for jitting, it would be necessary to return only one arbitrarily chosen index when several elements of the same segment share the maximum value.
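The tie-breaking choice can be made deterministic with the existing segment reductions. A minimal sketch, assuming "first" or "last" tied index are acceptable conventions: mask out non-maximal positions with a sentinel, then take `segment_min` (smallest tied index) or `segment_max` (largest tied index) over the positions. All shapes depend only on `num_segments`, so this stays `jit`-compatible when `num_segments` is static.

```python
import jax.numpy as jnp
from jax.ops import segment_max, segment_min

# Example with ties: segments 0 and 2 each contain a repeated maximum.
data = jnp.array([3, 3, 1, 5, 5, 5])
segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
num_segments = 3

max_vals = segment_max(data, segment_ids, num_segments=num_segments)
is_max = data == max_vals[segment_ids]
positions = jnp.arange(data.shape[0])

# Smallest tied index: non-max positions get a sentinel larger than any valid index.
first_idx = segment_min(jnp.where(is_max, positions, data.shape[0]),
                        segment_ids, num_segments=num_segments)

# Largest tied index: non-max positions get -1, smaller than any valid index.
last_idx = segment_max(jnp.where(is_max, positions, -1),
                       segment_ids, num_segments=num_segments)
```

Here `first_idx` is `[0, 3, 4]` and `last_idx` is `[1, 3, 5]`.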
Is this something that has already been discussed? Is there any technical limitation to doing so? I'm happy to get any insight and to provide more motivation if necessary!
Here is a short example of the requested feature:
Current behavior:
```python
>>> import jax.numpy as jnp
>>> from jax.ops import segment_max
>>> data = jnp.arange(6) * 2
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_max(data, segment_ids)
DeviceArray([2, 6, 10], dtype=int32)
```
Desired behavior:
```python
>>> data = jnp.arange(6) * 2
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_max(data, segment_ids, retrieve_indexes=True)
(DeviceArray([2, 6, 10], dtype=int32), DeviceArray([1, 3, 5], dtype=int32))
```
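Until such an option exists, the desired output can be approximated in user code. The helper below, `segment_max_with_indices`, is a hypothetical name (not part of `jax.ops`); it is a sketch that returns the smallest index among ties and requires `num_segments` to be static so it can be jitted.

```python
import jax
import jax.numpy as jnp
from jax.ops import segment_max, segment_min


def segment_max_with_indices(data, segment_ids, num_segments):
    """Hypothetical helper: per-segment max plus the smallest index attaining it."""
    max_vals = segment_max(data, segment_ids, num_segments=num_segments)
    # True where an element equals its own segment's maximum.
    is_max = data == max_vals[segment_ids]
    positions = jnp.arange(data.shape[0])
    # Keep positions of maxima; push every other position above any valid index.
    candidates = jnp.where(is_max, positions, data.shape[0])
    arg = segment_min(candidates, segment_ids, num_segments=num_segments)
    return max_vals, arg


data = jnp.arange(6) * 2
segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
# num_segments (positional argument 2) must be static under jit.
vals, idx = jax.jit(segment_max_with_indices, static_argnums=2)(data, segment_ids, 3)
```

This costs a gather and a second segment reduction on top of `segment_max`, which is the price of not having a single combined scatter.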
(Also, this option could be added to the other functions of the `jax.ops` package for consistency.)
Thanks!
If we were to add such an API, I'd suggest that it be called `jax.ops.segment_argmax` rather than being an extension of the current `segment_max`.
However, the bigger issue is that I don't believe it is easy to implement this API using the current `lax.scatter` primitive. `segment_max` is essentially just a scatter-max call, and it's not obvious to me whether we can implement a scatter-argmax without extending XLA to support a "variadic" scatter that lets us look at both keys and values in the same scatter operation. That seems like a reasonably large undertaking.
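To illustrate the point that `segment_max` is a scatter-max, the same result can be obtained with JAX's indexed-update syntax `.at[...].max()`. The initial fill value `jnp.iinfo(...).min` is an assumed identity for integer max, chosen so that it never wins against real data.

```python
import jax.numpy as jnp
from jax.ops import segment_max

data = jnp.arange(6) * 2
segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
num_segments = 3

# Start every output slot at the smallest representable value, then
# scatter-max each element into the slot named by its segment id.
identity = jnp.iinfo(data.dtype).min
via_scatter = jnp.full(num_segments, identity, dtype=data.dtype).at[segment_ids].max(data)

via_segment = segment_max(data, segment_ids, num_segments=num_segments)
```

The scatter combines values only: nothing records which element "won" in each slot. Carrying the index along would require combining (value, index) pairs inside one scatter, which is the variadic scatter described above; the alternative is extra passes (a gather and a second scatter) over the data.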
Was this issue ever resolved, or was a workaround found? I am also finding that the `segment_max` function is exactly what I need, but I also need the indices of the maxima. Thanks!
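A possible user-side workaround, sketched under the assumption that a few extra segment reductions per call are acceptable. The name `segment_argmax` follows the suggestion earlier in the thread but is hypothetical here, not a JAX API. It returns the first index of each segment's maximum and `-1` for segments that contain no elements, using a per-segment count to detect empty segments explicitly.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.ops import segment_max, segment_min, segment_sum


@partial(jax.jit, static_argnames="num_segments")
def segment_argmax(data, segment_ids, num_segments):
    """Hypothetical sketch: first index of each segment's max, or -1 if empty."""
    n = data.shape[0]
    max_vals = segment_max(data, segment_ids, num_segments=num_segments)
    # Mark positions that equal their own segment's maximum.
    is_max = data == max_vals[segment_ids]
    # Non-max positions get sentinel n, which is larger than any valid index.
    candidates = jnp.where(is_max, jnp.arange(n), n)
    arg = segment_min(candidates, segment_ids, num_segments=num_segments)
    # Count elements per segment so empty segments can be reported as -1.
    counts = segment_sum(jnp.ones_like(segment_ids), segment_ids,
                         num_segments=num_segments)
    return jnp.where(counts > 0, arg, -1)


data = jnp.array([4.0, 1.0, 7.0, 7.0])
segment_ids = jnp.array([0, 0, 2, 2])  # segment 1 has no elements
idx = segment_argmax(data, segment_ids, num_segments=3)
```

One caveat of this approach: if a segment's maximum is NaN, `data == max_vals[segment_ids]` is never true, so that segment falls back to the sentinel instead of a valid index.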