google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

psum_scatter does not allow scatter_dimension to be negative

jewillco opened this issue

Description

In jax.lax.psum_scatter, I tried to pass -1 as scatter_dimension to select the last dimension. However, this raises an error: "expects scatter_dimension >= 0". Negative axis indices are normally supported elsewhere in JAX, and in Python in general.
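Until negative values are accepted, one possible workaround is to resolve the index yourself before calling psum_scatter. The sketch below assumes a hypothetical helper, normalize_axis (not a JAX API), that applies the usual Python/NumPy convention where -1 means the last axis. It uses jax.vmap with axis_name to provide the named axis, so it runs on a single CPU without any multi-device setup.

```python
import jax
import jax.numpy as jnp
from jax import lax


def normalize_axis(axis: int, ndim: int) -> int:
    """Map a possibly negative axis index into the range [0, ndim).

    Hypothetical helper, not part of JAX: -1 means the last axis,
    -2 the one before it, and so on.
    """
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for rank {ndim}")
    return axis % ndim


def sum_and_scatter_last(x):
    # psum_scatter rejects scatter_dimension=-1, so resolve the negative
    # index against the per-example rank (x.ndim inside vmap) first.
    dim = normalize_axis(-1, x.ndim)
    # tiled=True splits the summed dimension into equal chunks, one per
    # member of the named axis "i".
    return lax.psum_scatter(x, axis_name="i", scatter_dimension=dim, tiled=True)


# vmap maps over the leading axis and names it "i" (size 2 here).
x = jnp.arange(8.0).reshape(2, 4)
out = jax.vmap(sum_and_scatter_last, axis_name="i")(x)
print(out)
```

Here the two rows are summed elementwise to [4, 6, 8, 10], and each member of axis "i" keeps one half of that result along the last dimension.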

System info (python version, jaxlib version, accelerator, etc.)

N/A

Hi @jewillco

Thanks for reporting. A related issue, #20125, is still open; it may be worth referencing it.

Duplicate of #20125