cornellius-gp / linear_operator

A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch


Pythonic Sum is broken with LinearOperators

wbeardall opened this issue · comments

Calling Python's built-in sum() on an arbitrary iterable of LinearOperators raises a TypeError. This is because the builtin sum() starts from int(0) and adds the elements iteratively:

sum([a, b, ..., z]) = 0 + a + b + ... + z
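
For reference, the builtin behaves roughly like the sketch below (builtin_sum_sketch is only an illustration, not part of the library): the start value defaults to int(0), so the very first addition performed is 0 + the first LinearOperator, which is what raises.

# Rough sketch of the builtin sum(); `start` defaults to 0, so the first
# addition performed is 0 + <first LinearOperator>, which raises TypeError.
def builtin_sum_sketch(iterable, start=0):
    total = start
    for item in iterable:
        total = total + item
    return total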

Minimal reproducible example (with workaround)

import traceback

import gpytorch
import torch

x1 = torch.rand([16, 1])
x2 = torch.rand([16, 1])

k1 = gpytorch.kernels.RBFKernel()(x1)
k2 = gpytorch.kernels.RBFKernel()(x2)

# The '+' operator works fine.
k3 = k1 + k2
print("'+' operator works fine.")

# The builtin sum() does not.
try:
    k4 = sum([k1, k2])
except Exception:
    traceback.print_exc()

# A workaround: functools.reduce works, but is a little unsightly.
from functools import reduce
import operator

k5 = reduce(operator.add, [k1, k2])
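
Another workaround, not shown above, is to pass an explicit start value to sum() so that int(0) never enters the chain of additions (this assumes the list is non-empty):

# Alternative workaround: give sum() an explicit start value so the builtin
# never adds int(0). Assumes `ops` is non-empty.
ops = [k1, k2]
k6 = sum(ops[1:], ops[0])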

A Fix

This can be fixed by having LinearOperator.__add__ treat int(0) the same way it treats a ZeroLinearOperator, i.e. by returning self:

import numbers

# Only short-circuit when `other` is a plain number equal to 0; any other
# number still falls through to the existing handling.
if isinstance(other, numbers.Number) and other == 0:
    return self

I'd recommend adding this case at the end of the conditions in LinearOperator.__add__, on line 2571, so that it's only checked after the more common use cases, as follows:

def __add__(self, other: Union[torch.Tensor, LinearOperator, float]) -> LinearOperator:
    import numbers

    from torch import Tensor

    from .added_diag_linear_operator import AddedDiagLinearOperator
    from .dense_linear_operator import to_linear_operator
    from .diag_linear_operator import DiagLinearOperator
    from .root_linear_operator import RootLinearOperator
    from .sum_linear_operator import SumLinearOperator
    from .zero_linear_operator import ZeroLinearOperator

    if isinstance(other, ZeroLinearOperator):
        return self
    elif isinstance(other, DiagLinearOperator):
        return AddedDiagLinearOperator(self, other)
    elif isinstance(other, RootLinearOperator):
        return self.add_low_rank(other.root)
    elif isinstance(other, Tensor):
        other = to_linear_operator(other)
        shape = torch.broadcast_shapes(self.shape, other.shape)
        new_self = self if self.shape[:-2] == shape[:-2] else self._expand_batch(shape[:-2])
        new_other = other if other.shape[:-2] == shape[:-2] else other._expand_batch(shape[:-2])
        return SumLinearOperator(new_self, new_other)
    elif isinstance(other, numbers.Number) and other == 0:
        # Treat a literal 0 (e.g. the start value of the builtin sum())
        # like a ZeroLinearOperator and return self unchanged.
        return self
    else:
        return SumLinearOperator(self, other)
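
For completeness, a quick sanity check once the branch is in place; this is a sketch that assumes 0 + k1 is routed through LinearOperator's reflected addition into __add__ (so the new branch is actually reached), and that to_dense() is available to densify both results:

# Sanity check (a sketch, not part of the fix). Relies on 0 + k1 reaching the
# new branch via reflected addition, and on .to_dense() for the comparison.
k_sum = sum([k1, k2])
assert torch.allclose(k_sum.to_dense(), (k1 + k2).to_dense())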