cornellius-gp / linear_operator

A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch

[Bug] torch.cat fails for linear operators

chrisyeh96 opened this issue

πŸ› Bug

torch.cat fails for linear operators.

To reproduce

Code snippet to reproduce

from linear_operator.operators import DiagLinearOperator
import torch

D = DiagLinearOperator(torch.randn(2, 3, 100))  # Represents a 2 x 3 batch of 100 x 100 diagonal operators (shape 2 x 3 x 100 x 100)
torch.cat([D, D], dim=-2)

Stack trace/error message

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/miniconda3/envs/env/lib/python3.12/site-packages/linear_operator/operators/_linear_operator.py", line 2948, in __torch_function__
    raise NotImplementedError(f"torch.{name}({arg_classes}, {kwarg_classes}) is not implemented.")
NotImplementedError: torch.cat(list, dim=int) is not implemented.

Expected Behavior

According to the documentation, torch.cat should work on linear operators.

System information

  • LinearOperator version: 0.5.2
  • PyTorch Version: 2.2.1
  • OS: Ubuntu 20.04.6 LTS

Hmm, interesting. Yeah, not sure why the docs contain this; it seems like this was never implemented. Support is there for many unary and binary operators, but torch.cat operates on a list of objects rather than on a LinearOperator directly. @gpleiss, have you considered this and similar operators in the past?

There is a CatLinearOperator that implements what is needed. I wish linear_operator implemented LinearOperator so that calling torch.cat on a list of LinearOperators automatically creates a CatLinearOperator.

Yes, that makes a lot of sense and would be great to have. I'm not sure whether that is easy to do with the __torch_function__ setup we use for this dispatching under the hood. Let me see if I can get some intel on this.
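The traceback above shows that `torch.cat(list, dim=int)` already reaches `LinearOperator.__torch_function__`, so PyTorch does dispatch on a list of operators. What is missing is a handler that turns the whole list into one lazy object. The minimal sketch below follows the registry pattern from PyTorch's "Extending torch" documentation. `ToyOperator`, `ToyCatOperator`, `implements`, and `HANDLED_FUNCTIONS` are hypothetical stand-ins, not linear_operator's actual internals.

```python
import torch

# Hypothetical registry mapping torch functions to handlers (pattern from the
# PyTorch "Extending torch" docs, not linear_operator's implementation).
HANDLED_FUNCTIONS = {}


def implements(torch_function):
    """Register a handler for `torch_function` (hypothetical helper)."""
    def decorator(handler):
        HANDLED_FUNCTIONS[torch_function] = handler
        return handler
    return decorator


class ToyOperator:
    """Hypothetical stand-in for a LinearOperator, backed by a dense tensor."""

    def __init__(self, dense):
        self.dense = dense

    def to_dense(self):
        return self.dense

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func not in HANDLED_FUNCTIONS:
            # Returning NotImplemented lets PyTorch raise its usual TypeError.
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


class ToyCatOperator(ToyOperator):
    """Hypothetical lazy concatenation, playing the role of CatLinearOperator."""

    def __init__(self, ops, dim):
        self.ops = list(ops)
        self.cat_dim = dim

    def to_dense(self):
        # Materialize the concatenation only when asked for it.
        return torch.cat([op.to_dense() for op in self.ops], dim=self.cat_dim)


@implements(torch.cat)
def _cat(tensors, dim=0):
    # torch.cat hands the whole list over as its first positional argument,
    # so the handler (not any single operator) builds the lazy result.
    return ToyCatOperator(tensors, dim)


a = ToyOperator(torch.eye(2))
b = ToyOperator(2 * torch.eye(2))
c = torch.cat([a, b], dim=-2)  # dispatched to _cat via __torch_function__
print(type(c).__name__, c.to_dense().shape)
```

A list that mixes plain tensors and operators would need extra handling in `_cat` (for example, wrapping the tensors first); the sketch leaves that out.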