cornellius-gp / linear_operator

A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch

[Feature Request] Block matrix support

hughsalimbeni opened this issue

🚀 Feature Request

Represent [TN, TM] tensors as a T×T grid of N×M lazy tensors. Block matrices are already supported, but they are only represented efficiently when there is diagonal structure over the T dimensions.
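To make the proposed layout concrete: a dense [TN, TM] tensor and a [T, T, N, M] grid of blocks hold the same entries, and one can be converted into the other with a reshape and a permute. The helper names `to_blocks` and `from_blocks` below are hypothetical and only illustrate the indexing; they are not part of linear_operator.

```python
import torch


def to_blocks(dense: torch.Tensor, T: int, N: int, M: int) -> torch.Tensor:
    """Split a [T*N, T*M] tensor into a [T, T, N, M] grid of blocks (hypothetical helper)."""
    # Row t1*N + n and column t2*M + m become index [t1, n, t2, m] after the reshape;
    # the permute then moves both block indices to the front: [t1, t2, n, m].
    return dense.reshape(T, N, T, M).permute(0, 2, 1, 3)


def from_blocks(blocks: torch.Tensor) -> torch.Tensor:
    """Inverse of to_blocks: reassemble a [T1, T2, N, M] grid into a [T1*N, T2*M] tensor."""
    T1, T2, N, M = blocks.shape
    return blocks.permute(0, 2, 1, 3).reshape(T1 * N, T2 * M)


T, N, M = 2, 4, 3
dense = torch.rand(T * N, T * M)
blocks = to_blocks(dense, T, N, M)

# Block (1, 0) is the slice with rows N..2N and columns 0..M of the dense tensor.
assert torch.equal(blocks[1, 0], dense[N : 2 * N, 0:M])
assert torch.equal(from_blocks(blocks), dense)
```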


Here is an example that linear_operator cannot currently handle efficiently:

import torch
import itertools

T, N, M = 2, 4, 3
As = [torch.rand(N, M) for _ in range(T)]
Bs = [[torch.rand(M, M) for _ in range(T)] for _ in range(T)]
Cs = [torch.rand(N, N) for _ in range(T)]
L = torch.rand(T, T)

A_bl = torch.zeros((N * T, M * T))  # BlockDiag (non-square)
B_bl = torch.zeros((M * T, M * T))  # Dense
C_bl = torch.zeros((N * T, N * T))  # BlockDiag
L_bl = torch.kron(L, torch.eye(N))  # Kronecker

for t in range(T):
    A_bl[N * t : N * (t + 1), M * t : M * (t + 1)] = As[t]
    C_bl[N * t : N * (t + 1), N * t : N * (t + 1)] = Cs[t]

for t1, t2 in itertools.product(range(T), range(T)):
    B_bl[M * t1 : M * (t1 + 1), M * t2 : M * (t2 + 1)] = Bs[t1][t2]

# Desired calculation
print("inefficient method")
print(torch.diag(L_bl @ (C_bl + A_bl @ B_bl @ A_bl.T) @ L_bl.T))
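Written out in block notation, the desired quantity decomposes as follows. This is a sketch of the algebra using the symbols from the code above, with $X$ standing for the middle matrix and $\delta$ for the Kronecker delta. It shows why only the diagonals of the individual $N \times N$ blocks are needed:

```latex
X = C_{bl} + A_{bl} B_{bl} A_{bl}^\top, \qquad
X_{t_1 t_2} = \delta_{t_1 t_2}\, C_{t_1} + A_{t_1} B_{t_1 t_2} A_{t_2}^\top

\bigl[(L \otimes I_N)\, X\, (L \otimes I_N)^\top\bigr]_{t t}
  = \sum_{i_1, i_2} L_{t i_1} L_{t i_2}\, X_{i_1 i_2}

\operatorname{diag}\bigl(A_{i_1} B_{i_1 i_2} A_{i_2}^\top\bigr)_n
  = \sum_{m} (A_{i_1})_{n m}\, \bigl(B_{i_1 i_2} A_{i_2}^\top\bigr)_{m n}
```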

This calculation turns up in some multi-output GP models. It has a straightforward efficient implementation:

M_diag = {}
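# We only need the diagonal of each [N, N] block of the middle matrix C_bl + A_bl @ B_bl @ A_bl.T
for t1, t2 in itertools.product(range(T), range(T)):
    # diag(As[t1] @ Bs[t1][t2] @ As[t2].T), computed without forming the [N, N] product
    r = (As[t1].T * (Bs[t1][t2] @ As[t2].T)).sum(0)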
    if t1 == t2:
        r += torch.diag(Cs[t1])
    M_diag[(t1, t2)] = r

# The rotation is applied blockwise due to the kron structure
R = {}
for t in range(T):  # we don't need the off-diag blocks
    r = 0
    for i1, i2 in itertools.product(range(T), range(T)):
        r += L[t, i1] * M_diag[(i1, i2)] * L[t, i2]
    R[t] = r

print("fast way")
print(torch.concat([R[t] for t in range(T)]))
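The same blockwise computation can also be written without Python loops using `torch.einsum`. This is a sketch, not code from the issue or from linear_operator: the function name `rotated_block_diag` and the choice to pass the blocks as stacked tensors (`As` as [T, N, M], `Bs` as [T, T, M, M], `Cs` as [T, N, N]) are assumptions made for illustration. It checks itself against the dense formula.

```python
import torch


def rotated_block_diag(As, Bs, Cs, L):
    """diag(L_bl @ (C_bl + A_bl @ B_bl @ A_bl.T) @ L_bl.T) as a flat [T * N] tensor.

    Assumed shapes: As [T, N, M], Bs [T, T, M, M], Cs [T, N, N], L [T, T].
    No [T * N, T * N] matrix is formed.
    """
    T = L.shape[0]
    # D[i, j, n] = diag(As[i] @ Bs[i, j] @ As[j].T)[n]
    D = torch.einsum("inm,ijmk,jnk->ijn", As, Bs, As)
    # Add diag(Cs[i]) on the diagonal blocks only (i == j).
    C_diag = torch.diagonal(Cs, dim1=-2, dim2=-1)  # [T, N]
    D = D + torch.eye(T, dtype=D.dtype)[:, :, None] * C_diag[:, None, :]
    # R[t, n] = sum_{i, j} L[t, i] * D[i, j, n] * L[t, j]  (Kronecker rotation)
    R = torch.einsum("ti,ijn,tj->tn", L, D, L)
    # Flattening R row by row matches the ordering of the dense diagonal.
    return R.reshape(-1)


torch.manual_seed(0)
T, N, M = 2, 4, 3
dt = torch.float64
As = torch.rand(T, N, M, dtype=dt)
Bs = torch.rand(T, T, M, M, dtype=dt)
Cs = torch.rand(T, N, N, dtype=dt)
L = torch.rand(T, T, dtype=dt)

# Dense reference, equivalent to the block matrices built in the first snippet.
A_bl = torch.block_diag(*As)
C_bl = torch.block_diag(*Cs)
B_bl = Bs.permute(0, 2, 1, 3).reshape(T * M, T * M)
L_bl = torch.kron(L, torch.eye(N, dtype=dt))
dense = torch.diag(L_bl @ (C_bl + A_bl @ B_bl @ A_bl.T) @ L_bl.T)

fast = rotated_block_diag(As, Bs, Cs, L)
assert torch.allclose(fast, dense)
```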

Currently, this calculation could be implemented with linear_operator like this:

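from linear_operator import to_linear_operator
from linear_operator.operators import (
    BlockDiagLinearOperator,
    BlockLinearOperator,
    IdentityLinearOperator,
    KroneckerProductLinearOperator,
    MatmulLinearOperator,
)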

# Reuse BlockDiagLinearOperator's shape and indexing logic while allowing non-square [N, M] blocks
class BlockDiagLinearOperatorNonSquare(BlockLinearOperator):
    _add_batch_dim = BlockDiagLinearOperator._add_batch_dim
    _remove_batch_dim = BlockDiagLinearOperator._remove_batch_dim
    _get_indices = BlockDiagLinearOperator._get_indices
    _size = BlockDiagLinearOperator._size
    num_blocks = BlockDiagLinearOperator.num_blocks

    def __init__(self, base_linear_op, block_dim=-3):
        super().__init__(base_linear_op, block_dim)

A_lo = BlockDiagLinearOperatorNonSquare(torch.stack(As, 0))
B_lo = to_linear_operator(B_bl)
C_lo = BlockDiagLinearOperator(to_linear_operator(torch.stack(Cs, 0)))
L_lo = KroneckerProductLinearOperator(L, IdentityLinearOperator(N))

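# Middle term A @ B @ A^T (named M_lo so it does not shadow the block size M)
M_lo = MatmulLinearOperator(A_lo, MatmulLinearOperator(B_lo, A_lo.T))

print("using linear operator, with to_dense()")
print(torch.diag(MatmulLinearOperator(L_lo, MatmulLinearOperator(C_lo + M_lo, L_lo.T)).to_dense()))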

Removing the to_dense() gives an error, however.


Add a block linear operator class that keeps track of the [T, T] block structure, represented as T^2 lazy tensors of the same shape. Matrix multiplication between block matrices would then be implemented as the appropriate linear operators on the blocks.
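A minimal sketch of this idea, assuming the blocks are stored as a nested Python list and that every block supports `@`, `+` and `.T`. The `BlockMatrix` class is hypothetical and not part of linear_operator. Plain tensors are used as blocks here so the sketch runs on its own; in the proposal each block would be a LinearOperator, so the block products could stay lazy instead of being computed densely.

```python
import torch


class BlockMatrix:
    """Hypothetical grid of blocks; blocks[i][j] only needs `@`, `+` and `.T`."""

    def __init__(self, blocks):
        self.blocks = blocks  # nested list: blocks[i][j]

    @property
    def grid_shape(self):
        return len(self.blocks), len(self.blocks[0])

    def __matmul__(self, other):
        # (A @ B)[i][j] = sum_k A[i][k] @ B[k][j], built only from block-level ops.
        rows, inner = self.grid_shape
        inner_other, cols = other.grid_shape
        if inner != inner_other:
            raise ValueError("block grids are not compatible")
        out = []
        for i in range(rows):
            row = []
            for j in range(cols):
                acc = self.blocks[i][0] @ other.blocks[0][j]
                for k in range(1, inner):
                    acc = acc + self.blocks[i][k] @ other.blocks[k][j]
                row.append(acc)
            out.append(row)
        return BlockMatrix(out)

    @property
    def T(self):
        # Transpose the grid and each individual block.
        rows, cols = self.grid_shape
        return BlockMatrix([[self.blocks[i][j].T for i in range(rows)] for j in range(cols)])

    def to_dense(self):
        # Concatenate each block row along columns, then stack the rows.
        return torch.cat([torch.cat(row, dim=-1) for row in self.blocks], dim=-2)


torch.manual_seed(0)
T, N, M = 2, 4, 3
A = BlockMatrix([[torch.rand(N, M) for _ in range(T)] for _ in range(T)])
B = BlockMatrix([[torch.rand(M, M) for _ in range(T)] for _ in range(T)])
prod = A @ B @ A.T  # a T x T grid of N x N blocks
```

Checking `prod.to_dense()` against `A.to_dense() @ B.to_dense() @ A.to_dense().T` confirms that the blockwise product matches the dense one.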

As a work-around, I have written manual implementations of specific cases, such as the one above.

I'm willing to work on a PR for this.

Additional context


Thanks for the suggestion, @hughsalimbeni! @gpleiss, @jacobrgardner and I have talked in the past about expanding linear_operator beyond the current focus on square (really, symmetric PSD) matrices.

The BlockDiagLinearOperatorNonSquare extending BlockDiagLinearOperator seems like a nifty way of realizing this without a ton of refactoring, but ideally we'd rethink the inheritance structure so that we have something general like:

LinearOperator -> BlockLinearOperator -> BlockDiagLinearOperator -> DiagLinearOperator

where operators are not assumed to be square (there could simply be an is_square property computed from the trailing two dimensions), symmetric, or positive definite (those could also be properties).
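A small sketch of what such properties might look like. The class `OperatorShapeSketch` and its constructor flags are hypothetical and not linear_operator's API; the point is only that squareness is derived from the trailing two dimensions, while symmetry and positive definiteness are declared rather than assumed.

```python
class OperatorShapeSketch:
    """Hypothetical base class illustrating shape-derived structural properties."""

    def __init__(self, shape, is_symmetric=False, is_psd=False):
        self.shape = tuple(shape)  # (*batch_shape, rows, cols)
        self._is_symmetric = is_symmetric
        self._is_psd = is_psd

    @property
    def is_square(self):
        # Computed from the trailing two dimensions; batch dimensions are ignored.
        return self.shape[-2] == self.shape[-1]

    @property
    def is_symmetric(self):
        # A non-square operator can never be symmetric.
        return self.is_square and self._is_symmetric

    @property
    def is_psd(self):
        # Positive (semi)definiteness is only meaningful for symmetric operators here.
        return self.is_symmetric and self._is_psd


batched_nonsquare = OperatorShapeSketch((5, 8, 6))  # e.g. a batch of 8 x 6 blocks
kernel_like = OperatorShapeSketch((8, 8), is_symmetric=True, is_psd=True)
```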

This would of course be a major redesign of the whole library and so is presumably out of scope for what you're trying to achieve here. But adding your suggestion could be a step on the way to a more general setup, and could inform / be absorbed in a larger rewrite down the road. So I'm happy to help review a PR for this.

Looks like a great addition. The key question is what functions need to be implemented to make this a reality. From the library description, we must implement:

I'm not sure what else makes sense. It seems like we might want