cornellius-gp / linear_operator

A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch

[Bug] Failing to slice the CatLinearOperator when indexes are negative or when using boolean array

MoiseRousseau opened this issue

πŸ› Bug

When slicing the CatLinearOperator with a negative index, the final shape of the slice does not match the expected shape and an error is raised. This fails for at least ToeplitzLinearOperator, DiagLinearOperator, and IdentityLinearOperator.

To reproduce

Code snippet to reproduce

from linear_operator.operators import IdentityLinearOperator as Ops
from linear_operator.operators import cat as cat_ops

N = 8
base = cat_ops([Ops(N) for _ in range(2)], dim=1)
print(base.shape) #should be 8,16
print(base[:,3:base.shape[-1]-3].shape) #should be 8,10
print(base[:,3:-3].shape) #fail...

Stack trace/error message

torch.Size([8, 16])
torch.Size([8, 10])
Traceback (most recent call last):
  File "/home/moise/Program/moise/linear_operator/debug.py", line 8, in <module>
    print(base[:,3:-3].shape) #fail...
  File "/home/moise/Program/moise/linear_operator/linear_operator/operators/_linear_operator.py", line 2870, in __getitem__
    raise RuntimeError(
RuntimeError: CatLinearOperator.__getitem__ failed! Expected a final shape of size torch.Size([8, 10]), got torch.Size([8, 5]). This is a bug with LinearOperator, or your custom LinearOperator.

Expected Behavior

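The slice should behave the same way it does with positive indices, i.e. base[:, 3:-3] should return the same torch.Size([8, 10]) result as base[:, 3:base.shape[-1]-3].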
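To illustrate the expected behavior, here is a minimal plain-Python sketch of how a slice over the concatenated dimension could be mapped onto the individual blocks. It assumes a unit step and uses a hypothetical helper, `split_slice_across_blocks`, which is not part of linear_operator. The key step is resolving negative bounds with `slice.indices` against the total length (16 here) before splitting, so that `3:-3` and `3:13` select the same columns.

```python
def split_slice_across_blocks(s, block_sizes):
    """Map a slice over a concatenated dimension to per-block local slices.

    Hypothetical helper (not linear_operator API). Negative bounds are
    resolved against the *total* length first, then the resulting global
    range is intersected with each block. Only step == 1 is handled.
    """
    total = sum(block_sizes)
    # e.g. slice(3, -3).indices(16) -> (3, 13, 1)
    start, stop, step = s.indices(total)
    if step != 1:
        raise NotImplementedError("this sketch only handles step == 1")

    pieces = []
    offset = 0  # global position where the current block starts
    for i, size in enumerate(block_sizes):
        lo = max(start, offset)
        hi = min(stop, offset + size)
        if lo < hi:
            # Convert the global range back to block-local coordinates.
            pieces.append((i, slice(lo - offset, hi - offset)))
        offset += size
    return pieces


pieces = split_slice_across_blocks(slice(3, -3), [8, 8])
print(pieces)  # [(0, slice(3, 8, None)), (1, slice(0, 5, None))]
print(sum(p.stop - p.start for _, p in pieces))  # 10
```

Each block contributes 5 columns, for a total of 10, which matches the torch.Size([8, 10]) result of the positive-index slice.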

System information

LinearOperator Version 0.5.2
PyTorch Version 2.0.1
Ubuntu 22.04

New findings: I also get a similar error when slicing with a boolean array, even without the CatLinearOperator, for example:

from linear_operator.operators import IdentityLinearOperator as Ops

N = 4
cond = [True,False,False,True]
ops = Ops(N)
print(ops.shape)
ops[:,cond]

Which gives:

torch.Size([4, 4])
Traceback (most recent call last):
  File "/home/moise/Program/moise/linear_operator/debug.py", line 8, in <module>
    ops[:,cond]
  File "/home/moise/Program/moise/linear_operator/linear_operator/operators/_linear_operator.py", line 2870, in __getitem__
    raise RuntimeError(
RuntimeError: IdentityLinearOperator.__getitem__ failed! Expected a final shape of size torch.Size([4, 4]), got torch.Size([4, 2]). This is a bug with LinearOperator, or your custom LinearOperator.

Looks like there may be a number of places where negative indexing isn't properly supported. I'll put up a fix for the CatLinearOperator case, but this should probably be audited more comprehensively.

I also don't think we've given much thought to supporting boolean indexing with linear_operator. @gpleiss, is that right?

Boolean indexing sounds tricky with linear operators. @MoiseRousseau do you have a good use case?

I found a workaround: convert the mask with torch.argwhere(bool_array) and then slice using the resulting indices. I was just reporting the error. Maybe this could be a way to implement it (even if it is suboptimal)?
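The argwhere workaround boils down to converting the mask into integer positions before indexing. Below is a plain-Python sketch; `mask_to_indices` is a hypothetical helper, not a linear_operator function. With PyTorch, `torch.argwhere` on a 1-D boolean tensor returns a tensor of shape (num_true, 1), so it needs flattening (for example with `.squeeze(-1)`) before it is used as an index.

```python
def mask_to_indices(mask, size):
    """Convert a boolean mask over a dimension of length `size` into integer indices.

    Hypothetical helper, not part of linear_operator: it mirrors the
    argwhere workaround in plain Python.
    """
    if len(mask) != size:
        raise IndexError(
            f"boolean mask of length {len(mask)} does not match dimension of size {size}"
        )
    # Keep the position of every True entry.
    return [i for i, keep in enumerate(mask) if keep]


cond = [True, False, False, True]
idx = mask_to_indices(cond, 4)
print(idx)       # [0, 3]
print(len(idx))  # 2 columns survive, so ops[:, cond] should be 4 x 2
```

Under the usual NumPy/PyTorch boolean-indexing semantics, the indexed dimension shrinks to the number of True entries, so torch.Size([4, 2]) is the shape one would expect for ops[:, cond]. The reported expected shape of torch.Size([4, 4]) suggests that the expected-shape computation is what mishandles the boolean list.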