Pythonic Sum is broken with LinearOperators
wbeardall opened this issue
Using Python's built-in sum() on an arbitrary iterable of LinearOperators raises a TypeError, even though adding two of them with the + operator works. This is because sum() starts from the integer 0 and adds the elements one at a time, so the very first operation is 0 + a:
sum([a, b, ..., z]) = 0 + a + b + ... + z
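A minimal, dependency-free sketch of that failure mode. The `Op` class is a hypothetical stand-in for a LinearOperator (it is not part of gpytorch or linear_operator): it can add another `Op`, but has no rule for the integer 0, so `sum()` fails on its very first step, `0 + a`.

```python
class Op:
    """Hypothetical stand-in for a LinearOperator: supports Op + Op only."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __add__(self, other):
        if isinstance(other, Op):
            return Op(f"({self.name} + {other.name})")
        # Anything else (including int 0) is unsupported.
        return NotImplemented

    # No __radd__ is defined, so `0 + op` has no fallback and raises TypeError.


a, b, c = Op("a"), Op("b"), Op("c")

# Chaining + directly never touches 0, so it works.
print((a + b + c).name)  # ((a + b) + c)

# sum() evaluates ((0 + a) + b) + c, and 0 + a fails immediately.
try:
    sum([a, b, c])
except TypeError as exc:
    print("sum() failed:", exc)
```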
Minimal reproducible example (with workaround)
```python
import gpytorch
import torch
import traceback

x1 = torch.rand([16, 1])
x2 = torch.rand([16, 1])
k1 = gpytorch.kernels.RBFKernel()(x1)
k2 = gpytorch.kernels.RBFKernel()(x2)

# This works
k3 = k1 + k2
print("'+' operator works fine.")

# This does not
try:
    k4 = sum([k1, k2])
except Exception:
    traceback.print_exc()

# A workaround for now: functools.reduce works, but is a little unsightly
from functools import reduce
import operator

k5 = reduce(operator.add, [k1, k2])
```
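The reduce() workaround avoids the problem because, without an initializer, functools.reduce seeds the fold with the first element rather than with 0. A sketch using the same kind of hypothetical `Op` stand-in (not a real gpytorch class). One behavioural difference from sum() is worth knowing: reduce() raises TypeError on an empty iterable unless an initial value is supplied, whereas sum([]) returns 0.

```python
import operator
from functools import reduce


class Op:
    """Hypothetical stand-in for a LinearOperator: supports Op + Op only."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __add__(self, other):
        if isinstance(other, Op):
            return Op(f"({self.name} + {other.name})")
        return NotImplemented


ops = [Op("a"), Op("b"), Op("c")]

# reduce(operator.add, [a, b, c]) computes (a + b) + c; no int 0 is involved.
total = reduce(operator.add, ops)
print(total.name)  # ((a + b) + c)

# Caveat: with no initializer, an empty iterable is an error.
try:
    reduce(operator.add, [])
except TypeError as exc:
    print("empty input:", exc)
```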
A Fix
This can be fixed by having LinearOperator handle the integer 0 the same way it already handles a ZeroLinearOperator, namely by returning self:
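```python
import numbers

# Treat a plain numeric zero (such as sum()'s implicit start value) as the
# additive identity. The isinstance check runs first, so other == 0 is only
# evaluated for numbers; any other type falls through unchanged.
if isinstance(other, numbers.Number) and other == 0:
    return self
```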
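To make the scope of that guard concrete, here is a small check of which values it accepts. `is_numeric_zero` is a hypothetical helper that simply mirrors the condition above; it is not part of linear_operator. Note that any zero-valued `numbers.Number` passes, including `False`, `0j`, `Fraction(0)` and `Decimal(0)`, while non-numbers such as the string `"0"` are rejected before `== 0` is evaluated.

```python
import numbers
from decimal import Decimal
from fractions import Fraction


def is_numeric_zero(other) -> bool:
    """Hypothetical helper mirroring the proposed guard in __add__."""
    # `and` short-circuits, so `other == 0` only runs for numbers.
    return isinstance(other, numbers.Number) and other == 0


candidates = [0, 0.0, 0j, False, Fraction(0), Decimal(0), 1, "0", None, []]
for value in candidates:
    print(f"{value!r:>15}: {is_numeric_zero(value)}")
```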
I'd recommend adding this case at the end of the conditions in LinearOperator.__add__ (line 2571), so that it is only checked after the more common use cases:
```python
def __add__(self, other: Union[torch.Tensor, LinearOperator, float]) -> LinearOperator:
    from torch import Tensor
    from .added_diag_linear_operator import AddedDiagLinearOperator
    from .dense_linear_operator import to_linear_operator
    from .diag_linear_operator import DiagLinearOperator
    from .root_linear_operator import RootLinearOperator
    from .sum_linear_operator import SumLinearOperator
    from .zero_linear_operator import ZeroLinearOperator
    import numbers  # numbers.Number is a class, so import the module itself

    if isinstance(other, ZeroLinearOperator):
        return self
    elif isinstance(other, DiagLinearOperator):
        return AddedDiagLinearOperator(self, other)
    elif isinstance(other, RootLinearOperator):
        return self.add_low_rank(other.root)
    elif isinstance(other, Tensor):
        other = to_linear_operator(other)
        shape = torch.broadcast_shapes(self.shape, other.shape)
        new_self = self if self.shape[:-2] == shape[:-2] else self._expand_batch(shape[:-2])
        new_other = other if other.shape[:-2] == shape[:-2] else other._expand_batch(shape[:-2])
        return SumLinearOperator(new_self, new_other)
    elif isinstance(other, numbers.Number) and other == 0:
        # New case: a numeric zero (e.g. sum()'s start value) is the additive identity
        return self
    else:
        return SumLinearOperator(self, other)
```
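An end-to-end sketch of why this branch is enough for sum(): `0 + op` first tries `int.__add__`, which returns NotImplemented, and Python then calls the right operand's `__radd__(0)`. The toy `Op` class below is hypothetical (not the real linear_operator class); its `__radd__` delegates to `__add__` explicitly, and the sketch assumes the real LinearOperator's reflected addition likewise ends up in `__add__`.

```python
import numbers
import operator
from functools import reduce


class Op:
    """Hypothetical LinearOperator stand-in patched with the proposed branch."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __add__(self, other):
        if isinstance(other, Op):
            return Op(f"({self.name} + {other.name})")
        # Proposed branch: a numeric zero is the additive identity.
        if isinstance(other, numbers.Number) and other == 0:
            return self
        return NotImplemented

    def __radd__(self, other):
        # Only reached for non-Op left operands (e.g. the int 0 from sum()),
        # where operand order does not matter, so reuse __add__.
        return self.__add__(other)


a, b, c = Op("a"), Op("b"), Op("c")

# sum() now works: 0 + a returns a itself, then (a + b) + c as before.
total = sum([a, b, c])
print(total.name)  # ((a + b) + c)

# The result matches the reduce() workaround from the issue.
assert total.name == reduce(operator.add, [a, b, c]).name
```

Non-zero numbers still fall through to NotImplemented in this toy, so `1 + a` keeps raising TypeError rather than being silently ignored.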