[Feature Request] Block matrix support
hughsalimbeni opened this issue
🚀 Feature Request
Represent [TN, TM] tensors as T x T blocks of N x M lazy tensors. Block matrices are already supported, but they are only represented efficiently when there is diagonal structure over the T dimension.
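The following minimal sketch uses plain PyTorch to show the target layout: a dense [TN, TM] matrix viewed as a [T, T] grid of N x M blocks, which is the structure a block operator would need to track. The helper names `to_blocks` and `from_blocks` are hypothetical and are not part of linear_operator.

```python
import torch


def to_blocks(X: torch.Tensor, T: int) -> torch.Tensor:
    """View a dense (T*N, T*M) matrix as a (T, T, N, M) grid of blocks.

    blocks[i, j] is the N x M block at block-row i, block-column j.
    """
    TN, TM = X.shape[-2:]
    N, M = TN // T, TM // T
    # Split rows into (T, N) and columns into (T, M), then move both block indices to the front.
    return X.reshape(T, N, T, M).permute(0, 2, 1, 3)


def from_blocks(blocks: torch.Tensor) -> torch.Tensor:
    """Inverse of to_blocks: assemble a (T1, T2, N, M) grid into a (T1*N, T2*M) matrix."""
    T1, T2, N, M = blocks.shape
    return blocks.permute(0, 2, 1, 3).reshape(T1 * N, T2 * M)


T, N, M = 2, 4, 3
X = torch.rand(T * N, T * M)
blocks = to_blocks(X, T)
assert blocks.shape == (T, T, N, M)
assert torch.equal(blocks[1, 0], X[N : 2 * N, 0:M])  # block-row 1, block-column 0
assert torch.equal(from_blocks(blocks), X)  # round trip is exact
```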
Motivation
Here is an example that linear_operator cannot deal with:
import torch
import itertools
T, N, M = 2, 4, 3
As = [torch.rand(N, M) for _ in range(T)]
Bs = [[torch.rand(M, M) for _ in range(T)] for _ in range(T)]
Cs = [torch.rand(N, N) for _ in range(T)]
L = torch.rand(T, T)
A_bl = torch.zeros((N * T, M * T)) # BlockDiag (non-square)
B_bl = torch.zeros((M * T, M * T)) # Dense
C_bl = torch.zeros((N * T, N * T)) # BlockDiag
L_bl = torch.kron(L, torch.eye(N)) # Kronecker
for t in range(T):
    A_bl[N * t : N * (t + 1), M * t : M * (t + 1)] = As[t]
    C_bl[N * t : N * (t + 1), N * t : N * (t + 1)] = Cs[t]
for t1, t2 in itertools.product(range(T), range(T)):
    B_bl[M * t1 : M * (t1 + 1), M * t2 : M * (t2 + 1)] = Bs[t1][t2]
# Desired calculation
print("inefficient method")
print(torch.diag(L_bl @ (C_bl + A_bl @ B_bl @ A_bl.T) @ L_bl.T))
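Written blockwise, the quantity printed above reduces to a few small pieces. In the notation below (introduced here, not from the issue), A_i = As[i], B_ij = Bs[i][j], C_i = Cs[i], and P = C + A B Aᵀ is the middle matrix `C_bl + A_bl @ B_bl @ A_bl.T` with N x N blocks P_ij. Because A and C are block diagonal, and the Kronecker factor mixes whole blocks, only the length-N diagonals of the P_ij are ever needed:

```latex
\begin{aligned}
P_{ij} &= \delta_{ij}\,C_i + A_i B_{ij} A_j^\top, \\
\operatorname{diag}\!\left(A_i B_{ij} A_j^\top\right)_n &= \sum_{m=1}^{M} (A_i)_{nm}\,\big(B_{ij} A_j^\top\big)_{mn}, \\
\operatorname{diag}\!\left[(L \otimes I_N)\,P\,(L \otimes I_N)^\top\right]_{(t)} &= \sum_{i=1}^{T}\sum_{j=1}^{T} L_{ti}\,L_{tj}\,\operatorname{diag}\!\left(P_{ij}\right).
\end{aligned}
```

Here the subscript (t) denotes the t-th length-N segment of the output diagonal. The second line is what `(As[t1].T * (Bs[t1][t2] @ As[t2].T)).sum(0)` computes, and the third is the double loop over `i1, i2`.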
This calculation turns up in some multi-output GP models. It has a straightforward efficient implementation:
M_diag = {}
# We only need the diagonal of each block of the middle matrix C + A B A^T
for t1, t2 in itertools.product(range(T), range(T)):
    r = (As[t1].T * (Bs[t1][t2] @ As[t2].T)).sum(0)
    if t1 == t2:
        r += torch.diag(Cs[t1])
    M_diag[(t1, t2)] = r
# The rotation is applied blockwise due to the kron structure
R = {}
for t in range(T): # we don't need the off-diag blocks
    r = 0
    for i1, i2 in itertools.product(range(T), range(T)):
        r += L[t, i1] * M_diag[(i1, i2)] * L[t, i2]
    R[t] = r
print("fast way")
print(torch.concat([R[t] for t in range(T)]))
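The two loops can also be collapsed into `torch.einsum` calls once the per-task factors are stacked. This sketch assumes every block is a dense tensor of equal size (As stacked to [T, N, M], Bs to [T, T, M, M], Cs to [T, N, N]); the function name `fast_diag` is hypothetical.

```python
import torch


def fast_diag(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    """diag(kron(L, I) @ (blockdiag(C) + blockdiag(A) @ B @ blockdiag(A).T) @ kron(L, I).T).

    A: [T, N, M] diagonal blocks of the non-square block-diagonal factor
    B: [T, T, M, M] blocks of the dense middle matrix
    C: [T, N, N] diagonal blocks of the block-diagonal term
    L: [T, T] mixing matrix
    Returns a vector of length T * N.
    """
    T, N, _ = A.shape
    # P_diag[i, j, n] = diag(A_i B_ij A_j^T)[n] = sum_{m,k} A_i[n, m] B_ij[m, k] A_j[n, k]
    P_diag = torch.einsum("inm,ijmk,jnk->ijn", A, B, A)
    # Add diag(C_i) on the diagonal blocks (i == j) only.
    idx = torch.arange(T)
    P_diag[idx, idx] += torch.diagonal(C, dim1=-2, dim2=-1)
    # R[t, n] = sum_{i,j} L[t, i] L[t, j] P_diag[i, j, n]
    R = torch.einsum("ti,tj,ijn->tn", L, L, P_diag)
    # Concatenate the T segments, matching torch.concat([R[t] for t in range(T)]).
    return R.reshape(T * N)
```

With the lists from the issue, `fast_diag(torch.stack(As), torch.stack([torch.stack(row) for row in Bs]), torch.stack(Cs), L)` should agree with the "fast way" output up to floating-point rounding.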
Currently, this calculation can be expressed with linear_operator as follows:
from linear_operator.operators import (
to_linear_operator,
IdentityLinearOperator,
BlockDiagLinearOperator,
BlockLinearOperator,
MatmulLinearOperator,
KroneckerProductLinearOperator,
)
class BlockDiagLinearOperatorNonSquare(BlockLinearOperator):
    # Reuse BlockDiagLinearOperator's block-diagonal bookkeeping
    _add_batch_dim = BlockDiagLinearOperator._add_batch_dim
    _remove_batch_dim = BlockDiagLinearOperator._remove_batch_dim
    _get_indices = BlockDiagLinearOperator._get_indices
    _size = BlockDiagLinearOperator._size
    num_blocks = BlockDiagLinearOperator.num_blocks

    def __init__(self, base_linear_op, block_dim=-3):
        super().__init__(base_linear_op, block_dim)


A_lo = BlockDiagLinearOperatorNonSquare(torch.stack(As, 0))
B_lo = to_linear_operator(B_bl)
C_lo = BlockDiagLinearOperator(to_linear_operator(torch.stack(Cs, 0)))
L_lo = KroneckerProductLinearOperator(L, IdentityLinearOperator(N))
# A B A^T (named M_lo so it does not shadow the block size M)
M_lo = MatmulLinearOperator(A_lo, MatmulLinearOperator(B_lo, A_lo.T))
print("using linear operator, with to_dense()")
print(
    MatmulLinearOperator(L_lo, MatmulLinearOperator(C_lo + M_lo, L_lo.T))
    .to_dense()
    .diagonal()
)
Removing the to_dense() call, however, gives an error.
Pitch
Add a block linear operator class that keeps track of the [T, T] block structure, represented as T^2 lazy tensors of the same shape. Matrix multiplication between block matrices would then be implemented as the appropriate linear operators on the blocks.
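A minimal sketch of the pitched container follows, assuming the blocks are plain torch tensors. The class name `BlockGrid` and its methods are hypothetical, not existing linear_operator API. A real implementation would presumably subclass LinearOperator and hold LinearOperator blocks, so that each block product dispatches to that block's structured matmul.

```python
import torch


class BlockGrid:
    """Hypothetical T1 x T2 grid of blocks; block (i, j) has shape [N_i, M_j]."""

    def __init__(self, blocks):
        # blocks: list of lists of tensors, blocks[i][j] is block-row i, block-column j
        self.blocks = blocks
        self.grid_shape = (len(blocks), len(blocks[0]))

    def __matmul__(self, other: "BlockGrid") -> "BlockGrid":
        # (X @ Y)_{ik} = sum_j X_{ij} @ Y_{jk}; each term is an ordinary block product
        n_rows, n_inner = self.grid_shape
        assert n_inner == other.grid_shape[0], "block grids do not conform"
        n_cols = other.grid_shape[1]
        out = [
            [sum(self.blocks[i][j] @ other.blocks[j][k] for j in range(n_inner)) for k in range(n_cols)]
            for i in range(n_rows)
        ]
        return BlockGrid(out)

    def transpose(self) -> "BlockGrid":
        # Swap the block indices and transpose each block.
        n_rows, n_cols = self.grid_shape
        return BlockGrid(
            [[self.blocks[i][j].transpose(-2, -1) for i in range(n_rows)] for j in range(n_cols)]
        )

    def to_dense(self) -> torch.Tensor:
        # Concatenate each block-row horizontally, then stack the rows vertically.
        return torch.cat([torch.cat(row, dim=-1) for row in self.blocks], dim=-2)
```

For the block-diagonal A in the issue, a real implementation would need to represent the off-diagonal zero blocks without materializing them (for example as a zero operator that is skipped in the sum); this sketch simply stores whatever tensors it is given.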
As a work-around, I have written manual implementations of specific cases, such as the one above.
I'm willing to work on a PR for this.
Additional context
None
Thanks for the suggestion, @hughsalimbeni! @gpleiss, @jacobrgardner and I have talked in the past about expanding linear_operator beyond the current focus on square (really, symmetric PSD) matrices.
The BlockDiagLinearOperatorNonSquare class, which reuses methods from BlockDiagLinearOperator, seems like a nifty way of realizing this without a ton of refactoring. Ideally, though, we'd rethink the inheritance structure so that we have something general like

LinearOperator
  -> BlockLinearOperator
    -> BlockDiagLinearOperator
      -> DiagLinearOperator

where operators are not assumed to be square (there could just be an is_square property computed from the trailing two dimensions), symmetric, or positive definite (those could also be properties).
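A hypothetical sketch of that idea: squareness is derived from the trailing two dimensions instead of being assumed, while symmetry and positive definiteness become flags that subclasses opt into. The class names `Operator` and `DiagOperator` are illustrative only and are not part of linear_operator.

```python
import torch


class Operator:
    """Illustrative base class: no squareness, symmetry or PSD assumptions."""

    # Structural properties are opt-in flags rather than built-in assumptions.
    is_symmetric: bool = False
    is_positive_definite: bool = False

    def __init__(self, shape):
        self._shape = torch.Size(shape)

    @property
    def shape(self) -> torch.Size:
        return self._shape

    @property
    def is_square(self) -> bool:
        # Computed from the trailing two dimensions (batch dimensions are ignored).
        return self._shape[-2] == self._shape[-1]


class DiagOperator(Operator):
    """A diagonal operator: square and symmetric by construction."""

    is_symmetric = True

    def __init__(self, diag: torch.Tensor):
        # [..., n] diagonal -> [..., n, n] operator
        super().__init__(tuple(diag.shape) + (diag.shape[-1],))
        self.diag = diag
```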
This would of course be a major redesign of the whole library, and so is presumably out of scope for what you're trying to achieve here. But adding your suggestion could be a step toward a more general setup, and could inform or be absorbed into a larger rewrite down the road. So I'm happy to help review a PR for this.
Looks like a great addition. The key question is which functions need to be implemented to make this a reality. From the library description, we must implement:
_matmul
_transpose_nonbatch
I'm not sure what else makes sense. It seems like we might want:
_diagonal
_root_decomposition?
_root_inv_decomposition?
_solve?
inv_quad_logdet?
_svd?
_symeig?
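Of these, _diagonal is what the motivating example needs. For a product of two block grids, it can be computed one diagonal block at a time without forming any product block, which generalizes the `(As[t1].T * (...)).sum(0)` trick from the issue. This is a sketch with plain tensor blocks; the function name `diag_of_product` is hypothetical.

```python
import torch


def diag_of_product(X_blocks, Y_blocks) -> torch.Tensor:
    """diag(X @ Y) for block grids X (T x K blocks of [N_t, M_k]) and Y (K x T blocks of [M_k, N_t]).

    Only the diagonal blocks of the product matter:
    (X @ Y)_{tt} = sum_k X_{tk} @ Y_{kt}, and diag(U @ V) = (U * V^T).sum(-1).
    """
    T, K = len(X_blocks), len(Y_blocks)
    diags = []
    for t in range(T):
        # Each term costs O(N_t * M_k) instead of the O(N_t^2 * M_k) of a full block product.
        d = sum((X_blocks[t][k] * Y_blocks[k][t].transpose(-2, -1)).sum(-1) for k in range(K))
        diags.append(d)
    return torch.cat(diags, dim=-1)
```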