[Bug] Indexing `ConstantMulLinearOperator` with a `SumBatchLinearOperator` base operator
j-wilson opened this issue
🐛 Bug
To reproduce
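```python
import linear_operator.operators as ops
from torch import rand

A = ops.DenseLinearOperator(rand(4, 3, 2, 2))
B = ops.SumBatchLinearOperator(A, block_dim=-3)  # sums over the size-3 dim -> (4, 2, 2)
C = ops.ConstantMulLinearOperator(B, rand([]))
C[:, -1:, :].to_dense()  # raises RuntimeError
```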
Running this raises:

```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-18-e4d937430d51> in <module>
----> 1 C[:, -1:, :].to_dense()
/mnt/xarfuse/uid-22150/d091ed77-seed-nspid4026533386_cgpid5352405-ns-4026533383/linear_operator/operators/sum_batch_linear_operator.py in to_dense(self)
59
60 def to_dense(self):
---> 61 return self.base_linear_op.to_dense().sum(dim=-3) # BlockLinearOperators always use dim3 for the block_dim
/mnt/xarfuse/uid-22150/d091ed77-seed-nspid4026533386_cgpid5352405-ns-4026533383/linear_operator/utils/memoize.py in g(self, *args, **kwargs)
57 kwargs_pkl = pickle.dumps(kwargs)
58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59 return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
61
/mnt/xarfuse/uid-22150/d091ed77-seed-nspid4026533386_cgpid5352405-ns-4026533383/linear_operator/operators/constant_mul_linear_operator.py in to_dense(self)
164 def to_dense(self):
165 res = self.base_linear_op.to_dense()
--> 166 return res * self.expanded_constant
167
168 @cached(name="root_decomposition")
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 1
```
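Reading the traceback, the indexed result appears to be a `SumBatchLinearOperator` whose `base_linear_op` is a `ConstantMulLinearOperator`. The constant in that operator seems to have been expanded for `C`'s batch shape `(4,)`, while the operator it multiplies is the indexed, not-yet-summed `A` with batch shape `(4, 3)`. The following is a minimal sketch of that broadcasting mismatch using plain tensors. It assumes that `expanded_constant` is the constant with two trailing singleton dims; that is a guess about the internals and is not confirmed by the issue.

```python
import torch

# Shapes from the reproduction: after C[:, -1:, :], the un-summed A would be (4, 3, 1, 2).
indexed_base = torch.rand(4, 3, 1, 2)

# Assumption: the constant is expanded to C's batch shape (4,) and then given two
# trailing singleton dims, giving shape (4, 1, 1).
constant = torch.rand([]).expand(4)
expanded_constant = constant.unsqueeze(-1).unsqueeze(-1)

# Broadcasting right-aligns the shapes: (4, 3, 1, 2) vs (1, 4, 1, 1),
# so dim 1 compares 3 against 4 and the multiplication fails.
failed, message = False, ""
try:
    indexed_base * expanded_constant
except RuntimeError as err:
    failed, message = True, str(err)

# Summing over the block dim first gives (4, 1, 2), which broadcasts cleanly
# against (4, 1, 1).
result = indexed_base.sum(dim=-3) * expanded_constant
```

The mismatch is 3 against 4 at dimension 1, just as in the traceback. That agreement is the main reason for this reading. Multiplying after the sum lines the batch dimensions up again.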
Expected Behavior
`C[:, -1:, :].to_dense()` should behave like its dense analogue, i.e. return the same values as `C.to_dense()[:, -1:, :]`.
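For reference, here is a minimal sketch of that dense analogue using plain PyTorch tensors in place of the operators. It assumes that `SumBatchLinearOperator(A, block_dim=-3)` densifies to `A.sum(dim=-3)`, which matches the `to_dense` shown in the traceback.

```python
import torch

# Same shapes as the reproduction; plain tensors stand in for the linear operators.
A = torch.rand(4, 3, 2, 2)
c = torch.rand([])

# SumBatchLinearOperator(A, block_dim=-3) ~ summing over dim -3;
# ConstantMulLinearOperator(B, c) ~ elementwise multiplication by c.
C_dense = A.sum(dim=-3) * c  # shape (4, 2, 2)

# The result the indexed operator should densify to.
expected = C_dense[:, -1:, :]  # shape (4, 1, 2)
```

Once the bug is fixed, `C[:, -1:, :].to_dense()` should match `expected` (up to floating-point tolerance) when `C` is built from the same `A` and `c`.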
The problem seems to stem from `ConstantMulLinearOperator._getitem`. The following version appears to work, but I am not sure what its runtime profile looks like compared with the existing implementation. We still index into `base_linear_op` and `constant` directly. If I'm not mistaken, this new version may even be faster in the particular case considered here, since the constant now multiplies the indexed `SumBatchLinearOperator` instance rather than that instance's `base_linear_op`.
```python
def _getitem(self, row_index, col_index, *batch_indices):
    # NOTE TO FUTURE SELF:
    # This custom __getitem__ is actually very important!
    # It prevents constructing an InterpolatedLinearOperator when one isn't needed
    # This affects runtimes by up to 5x on simple exact GPs
    # Run __getitem__ on the base_linear_op and the constant
    base_linear_op = self.base_linear_op._getitem(row_index, col_index, *batch_indices)
    # Expand the constant to the full batch shape so the batch indices select the same entries as in base_linear_op
    constant = self._constant.expand(self.batch_shape)[batch_indices]
    return type(self)(base_linear_op=base_linear_op, constant=constant)
```
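The fix relies on indexing commuting with multiplication by a per-batch constant, provided the constant is indexed with the same batch indices as the base operator. Below is a minimal dense check of that identity. `base` and `constant` are hypothetical stand-ins for `base_linear_op.to_dense()` and the constant, so this does not exercise `linear_operator` itself.

```python
import torch

# Hypothetical dense stand-ins: `base` for base_linear_op.to_dense() (batch shape (4,)),
# `constant` for one scalar per batch element.
base = torch.rand(4, 2, 2)
constant = torch.rand(4)


def index_commutes(batch_indices, row_index, col_index):
    full_index = (*batch_indices, row_index, col_index)
    # Scale first, then index: what the dense analogue computes.
    scaled_then_indexed = (base * constant[:, None, None])[full_index]
    # Index first, then scale by the constant indexed with the same batch indices:
    # what the proposed _getitem does.
    indexed_constant = constant.expand(base.shape[:-2])[batch_indices]
    indexed_then_scaled = base[full_index] * indexed_constant[..., None, None]
    return torch.allclose(scaled_then_indexed, indexed_then_scaled)


# The slice case from the reproduction, plus a tensor batch index.
print(index_commutes((slice(None),), slice(-1, None), slice(None)))
print(index_commutes((torch.tensor([0, 2]),), slice(-1, None), slice(None)))
```

Both calls should print `True`. The design point is that the constant is indexed only with `batch_indices`, never with the row or column indices, because it is constant across each matrix in the batch.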
Your fix seems reasonable, and I also suspect that it is faster :) Want to throw up a PR for this?
Fixed by #37.