psum_scatter does not allow scatter_dimension to be negative
jewillco opened this issue
Jeremiah Willcock commented
Description
In jax.lax.psum_scatter, I tried to pass -1 as scatter_dimension to indicate that the last dimension should be used. However, this triggers the error "expects scatter_dimension >= 0". Negative axis indices are normally supported in JAX, and in Python indexing in general.
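Until negative values are accepted, one possible workaround is to normalize the axis yourself before calling psum_scatter. The sketch below assumes a pmap-based setup; the helpers normalize_axis and psum_scatter_last are hypothetical names introduced for illustration, not part of JAX. It uses tiled=True so the last dimension is split into one tile per device rather than removed.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def normalize_axis(dim: int, ndim: int) -> int:
    # Map a possibly negative axis (e.g. -1 for the last axis) to the
    # equivalent non-negative index, following the usual Python convention.
    if not -ndim <= dim < ndim:
        raise ValueError(f"axis {dim} is out of range for an array with ndim={ndim}")
    return dim % ndim


def psum_scatter_last(x, axis_name):
    # Hypothetical helper: reduce-scatter along the last dimension.
    # The negative axis is normalized here because psum_scatter rejected
    # scatter_dimension < 0 in the version described in this issue.
    dim = normalize_axis(-1, x.ndim)
    return lax.psum_scatter(x, axis_name, scatter_dimension=dim, tiled=True)


n = jax.local_device_count()
# Each device holds a block of shape (2, 4 * n); after the sum, the last
# dimension is split into n tiles of width 4, one tile per device.
x = jnp.arange(n * 2 * 4 * n, dtype=jnp.float32).reshape(n, 2, 4 * n)
out = jax.pmap(lambda v: psum_scatter_last(v, "i"), axis_name="i")(x)

# Reference result: sum the per-device blocks, then give device d tile d.
summed = np.asarray(x).sum(axis=0)
expected = np.stack([summed[:, 4 * d : 4 * d + 4] for d in range(n)])
```

The modulo in normalize_axis mirrors how Python resolves negative indices, so -1 maps to ndim - 1 while out-of-range values still raise.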
System info (python version, jaxlib version, accelerator, etc.)
N/A
rajasekharporeddy commented
Jeremiah Willcock commented
Duplicate of #20125