[Bug] Failing to slice the CatLinearOperator when indexes are negative or when using boolean array
MoiseRousseau opened this issue
🐛 Bug
When slicing a CatLinearOperator with a negative index, the final shape of the slice does not match the expected shape and an error is raised. This fails at least for ToeplitzLinearOperator, DiagLinearOperator, and IdentityLinearOperator.
To reproduce
** Code snippet to reproduce **
from linear_operator.operators import IdentityLinearOperator as Ops
from linear_operator.operators import cat as cat_ops
N = 8
base = cat_ops([Ops(N) for _ in range(2)], dim=1)
print(base.shape) #should be 8,16
print(base[:,3:base.shape[-1]-3].shape) #should be 8,10
print(base[:,3:-3].shape) #fail...
** Stack trace/error message **
torch.Size([8, 16])
torch.Size([8, 10])
Traceback (most recent call last):
File "/home/moise/Program/moise/linear_operator/debug.py", line 8, in <module>
print(base[:,3:-3].shape) #fail...
File "/home/moise/Program/moise/linear_operator/linear_operator/operators/_linear_operator.py", line 2870, in __getitem__
raise RuntimeError(
RuntimeError: CatLinearOperator.__getitem__ failed! Expected a final shape of size torch.Size([8, 10]), got torch.Size([8, 5]). This is a bug with LinearOperator, or your custom LinearOperator.
Expected Behavior
Slicing with negative indexes should behave the same way as slicing with the equivalent positive indexes.
System information
LinearOperator Version 0.5.2
PyTorch Version 2.0.1
Ubuntu 22.04
New findings: I also get a similar error when slicing with a boolean array, without involving CatLinearOperator at all:
from linear_operator.operators import IdentityLinearOperator as Ops
N = 4
cond = [True,False,False,True]
ops = Ops(N)
print(ops.shape)
ops[:,cond]
Which gives:
torch.Size([4, 4])
Traceback (most recent call last):
File "/home/moise/Program/moise/linear_operator/debug.py", line 8, in <module>
ops[:,cond]
File "/home/moise/Program/moise/linear_operator/linear_operator/operators/_linear_operator.py", line 2870, in __getitem__
raise RuntimeError(
RuntimeError: IdentityLinearOperator.__getitem__ failed! Expected a final shape of size torch.Size([4, 4]), got torch.Size([4, 2]). This is a bug with LinearOperator, or your custom LinearOperator.
Looks like there may be a number of places where negative indexing isn't properly supported. I'll put up a fix for the CatLinearOperator case, but this should probably be audited more comprehensively.
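For illustration only, here is a minimal sketch of one way negative slice bounds can be normalized before any shape arithmetic is done. It relies solely on Python's built-in `slice.indices`. The helper name `normalize_slice` is hypothetical and is not part of linear_operator, and this is not necessarily how the actual fix is implemented.

```python
def normalize_slice(index: slice, size: int) -> slice:
    """Return an equivalent slice whose bounds are non-negative and in range.

    Hypothetical helper for illustration; assumes a positive step.
    """
    start, stop, step = index.indices(size)
    if step < 0:
        # With a negative step, slice.indices can return stop == -1, which
        # would be misread as "last element" if wrapped in a slice again.
        raise NotImplementedError("this sketch only handles positive steps")
    return slice(start, stop, step)


num_cols = 16  # number of columns of `base` in the report above
print(normalize_slice(slice(3, -3), num_cols))  # slice(3, 13, 1)
print(len(range(num_cols)[3:-3]))               # 10
```

After normalization, `3:-3` and `3:base.shape[-1]-3` describe the same ten columns, so code that only handles non-negative bounds would see identical input in both cases.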
I also don't think we've given much thought to supporting boolean indexing with linear_operator - @gpleiss is that right?
Boolean indexing sounds tricky with linear operators. @MoiseRousseau do you have a good use case?
I found a workaround: call torch.argwhere(bool_array) and then slice using the resulting indices. I was just reporting the error. Maybe this could be a way to implement it (even if it is suboptimal)?
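The idea behind this workaround can be sketched without any dependencies: turn the boolean mask into integer positions, then index with those positions. The snippet below is a pure-Python stand-in for the `torch.argwhere` step, applied to a plain nested-list identity matrix rather than a real LinearOperator. The helper names `mask_to_indices` and `select_columns` are hypothetical.

```python
from typing import List, Sequence


def mask_to_indices(mask: Sequence[bool]) -> List[int]:
    """Positions where the mask is True (the role torch.argwhere plays above)."""
    return [i for i, keep in enumerate(mask) if keep]


def select_columns(rows: List[List[float]], indices: Sequence[int]) -> List[List[float]]:
    """Keep only the given columns of a row-major nested list."""
    return [[row[j] for j in indices] for row in rows]


cond = [True, False, False, True]
identity = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]

idx = mask_to_indices(cond)
picked = select_columns(identity, idx)
print(idx)                          # [0, 3]
print(len(picked), len(picked[0]))  # 4 2
```

The result has shape 4 x 2, which matches the shape the operator actually produced in the traceback above; only the expected shape in the check was 4 x 4.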