`ein.array` cannot compute the diagonal of non-square tensors (DimensionBindError: Dim 'dim' previously bound to a dimension of size 3 cannot bind to a dimension of size 4)
JasonGross opened this issue.
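```python
import gbmi.utils.ein as ein
import torch

foo = torch.randn(2, 3, 4, 5)
# Intended: bar[q_tok, max_tok] = foo[q_tok, max_tok, max_tok, 0], where max_tok ranges
# over foo.shape[1] == 3 but is also used to index dim 2, which has size 4.
bar = ein.array(lambda q_tok, max_tok: foo[q_tok, max_tok, max_tok, 0], sizes=[foo.shape[0], foo.shape[1]])
# DimensionBindError: Dim 'dim' previously bound to a dimension of size 3 cannot bind to a dimension of size 4
```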
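Until `ein.array` handles this case, the intended result can be computed in plain PyTorch without `ein`. The sketch below assumes the intended semantics is `bar[q, m] = foo[q, m, m, 0]` for `m` in `range(foo.shape[1])`. `diag_q_max` is a hypothetical helper, not part of `gbmi.utils.ein`. It relies on `torch.diagonal`, which accepts non-square slices and returns `min(M, N)` diagonal entries as the last dimension.

```python
import torch


def diag_q_max(foo: torch.Tensor) -> torch.Tensor:
    """Return out[q, m] = foo[q, m, m, 0] for m in range(foo.shape[1]).

    Hypothetical workaround helper (not part of gbmi.utils.ein).
    Assumes foo.shape[1] <= foo.shape[2], so every m is a valid index on dim 2.
    """
    if foo.shape[1] > foo.shape[2]:
        raise ValueError("dim 1 must not be larger than dim 2 for this diagonal")
    x = foo[..., 0]  # shape (B, M, N): fix the last index to 0
    # Non-square diagonal over dims 1 and 2: yields min(M, N) == M entries,
    # appended as the last dimension, so the result has shape (B, M).
    return torch.diagonal(x, dim1=1, dim2=2)


foo = torch.randn(2, 3, 4, 5)
bar = diag_q_max(foo)
print(bar.shape)  # torch.Size([2, 3])
```

An equivalent formulation uses advanced indexing, `idx = torch.arange(foo.shape[1]); foo[..., 0][:, idx, idx]`. `torch.diagonal` is used above because it returns a view rather than a copy.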