google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Implement `at` method for `_IndexUpdateRef`

Edenhofer opened this issue

Currently the following does not work:

```python
from jax import numpy as jnp

a = jnp.arange(3)
a = a.at[1:].at[:-1].set(0)  # fails: the `_IndexUpdateRef` returned by `a.at[1:]` has no `at` attribute
```

though IMHO the intent is clear. The equivalent NumPy expression `a[1:][:-1] = 0` behaves as one would intuitively expect and sets the second element to 0.

This is an interesting idea, and all else being equal I think it would make sense to support it. That said, XLA's scatter/gather semantics are quite complicated, and I don't know of any mechanism in XLA that would let us "stack" gathers and scatters in this manner, so implementing it would be non-trivial. I'll leave it open as a P3 enhancement.

I hope I am not oversimplifying things, but one should be able to merge successive indices into a single index reasonably well before reaching any of the lower-level XLA calls. At the very least, it should be possible to merge consecutive slices, or consecutive integer indices, across chained `at` calls. Things may get more involved once the user starts mixing slices and integer indices.
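A minimal sketch of merging two consecutive slices into one, assuming a 1-D axis whose length `n` is static (JAX array shapes are known at trace time). `compose_slices` is a hypothetical helper, not a JAX API; it leans on Python's `range` slicing to normalize negative bounds, clipping, and steps.

```python
def compose_slices(outer: slice, inner: slice, n: int) -> slice:
    """Return one slice s with seq[outer][inner] == seq[s] whenever len(seq) == n.

    Hypothetical helper for illustration; assumes a single axis of static length n.
    """
    # Slicing a range applies Python/NumPy slice semantics and yields another
    # range with normalized start, stop, and step.
    r = range(n)[outer][inner]
    # An empty result can carry an out-of-range (even negative) start,
    # so return a canonical empty slice instead.
    if len(r) == 0:
        return slice(0, 0)
    # With a negative step, a range's stop may be negative (e.g. range(2, -1, -1)),
    # but a negative slice stop would wrap around, so use None to mean "run to the start".
    stop = r.stop if r.stop >= 0 else None
    return slice(r.start, stop, r.step)


print(compose_slices(slice(1, None), slice(None, -1), 3))  # slice(1, 2, 1)
```

With JAX, the merged slice could then be used in a single update, e.g. `a.at[compose_slices(slice(1, None), slice(None, -1), a.shape[0])].set(0)`, which keeps the update a slice rather than an explicit list of positions.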

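The code that converts an indexing expression into an XLA gather/scatter is complicated; it is built around JAX's internal helper:

```python
def _index_to_gather(x_shape, idx, normalize_indices=True):
    ...  # body omitted here; see the JAX source
```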

It may be possible, and indeed easier, to merge the Python-level indexing objects rather than trying to compose the XLA-level gather/scatter objects.
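One way to read "merge the Python-level indexing objects" in the 1-D case is to apply each index in turn to a list of positions, so that any chain collapses into a single integer-array index. The helpers below (`apply_index`, `compose_indices`, `set_at`) are hypothetical pure-Python sketches, not JAX APIs, and cover only slices, integer lists, and boolean masks on a single axis.

```python
from typing import List, Sequence, Union

# A 1-D index: a slice, a list of (possibly negative) ints, or a boolean mask.
Index = Union[slice, Sequence[int], Sequence[bool]]


def apply_index(positions: List[int], idx: Index) -> List[int]:
    """Apply one 1-D index to a list of positions, NumPy-style (hypothetical helper)."""
    if isinstance(idx, slice):
        return positions[idx]
    idx = list(idx)
    if idx and all(isinstance(i, bool) for i in idx):
        # Boolean mask: must match the current length; keeps the True entries.
        if len(idx) != len(positions):
            raise IndexError("boolean mask length does not match")
        return [p for p, keep in zip(positions, idx) if keep]
    # Integer array: gathers positions; negative entries wrap as in NumPy.
    return [positions[i] for i in idx]


def compose_indices(n: int, *indices: Index) -> List[int]:
    """Collapse a chain of 1-D indices into one list of positions into range(n)."""
    positions = list(range(n))
    for idx in indices:
        positions = apply_index(positions, idx)
    return positions


def set_at(values: Sequence, positions: List[int], new_value) -> list:
    """Functional update, loosely mimicking x.at[positions].set(new_value)."""
    out = list(values)
    for p in positions:
        out[p] = new_value
    return out


print(set_at([0, 1, 2], compose_indices(3, slice(1, None), slice(None, -1)), 0))  # [0, 0, 2]
```

The design choice here is to give up structure: every chain becomes a gather/scatter over explicit positions, even in the slice-only case that could stay a slice. Boolean masks also need concrete values, which JAX generally cannot provide inside `jit`.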

Don't forget that even at the Python level, your index composition would have to account for indices containing integers, slices, np.newaxis, ellipses, (broadcasted) arrays of indices, (broadcasted) boolean masks, and every combination thereof. I think it would take some work to get right!

And even in NumPy, this kind of composition doesn't always work as you might expect. For example:

```
In [1]: import numpy as np

In [2]: idx = np.array([2, 4, 5])

In [3]: x = np.arange(10)

In [4]: x[idx][:3] = 999

In [5]: print(x)
[0 1 2 3 4 5 6 7 8 9]
```

The reason is that while `x[idx] = 999` calls `x.__setitem__`, `x[idx]` without an associated assignment calls `x.__getitem__`, which for an integer-array (advanced) index returns a copy rather than a view; the second index expression then modifies that copy instead of the original array. So if we were to add such composition, we'd have to think carefully about when and where to mimic or diverge from NumPy.
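A minimal pure-Python sketch of how Python desugars the chained assignment: the outer object only ever receives a `__getitem__` call, and `__setitem__` goes to whatever that call returned. `Recorder` is a hypothetical stand-in for an array, not part of NumPy.

```python
class Recorder:
    """Hypothetical array stand-in that logs which dunder method Python calls."""

    def __init__(self, log):
        self.log = log

    def __getitem__(self, idx):
        self.log.append(("getitem", idx))
        # NumPy would return a view or a copy here; we return a fresh recorder
        # that shares the same log.
        return Recorder(self.log)

    def __setitem__(self, idx, value):
        self.log.append(("setitem", idx, value))


log = []
x = Recorder(log)
x[[2, 4, 5]][:3] = 999
print(log)  # [('getitem', [2, 4, 5]), ('setitem', slice(None, 3, None), 999)]
```

Whether the original data changes therefore depends entirely on whether the intermediate `__getitem__` result is a view (basic slicing) or a copy (advanced indexing).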

Good point. I overlooked all the other ways to index an array. This would start to get messy very quickly.

I am not sure about your example, though. JAX's `at` method already behaves differently from NumPy in that respect, so I am not sure nesting it would add any more complexity.

The ndindex library can do this sort of calculation with `as_subindex`: https://quansight-labs.github.io/ndindex/api.html#ndindex.ndindex.NDIndex.as_subindex

At a quick glance, this seems to be implemented only for relatively simple cases. Quoting from your link:

i.as_subindex(j) is currently only implemented when j is a slice with positive steps and nonnegative start and stop, or a Tuple of the same. To use it with slices with negative start or stop, call reduce() with a shape first.