JasonGross / guarantees-based-mechanistic-interpretability

`ein.array` cannot compute the diagonal of non-square tensors (DimensionBindError: Dim 'dim' previously bound to a dimension of size 3 cannot bind to a dimension of size 4)

JasonGross opened this issue

import gbmi.utils.ein as ein
import torch

foo = torch.randn(2, 3, 4, 5)
# Take the "diagonal" over axes 1 and 2, which have different sizes (3 and 4).
bar = ein.array(
    lambda q_tok, max_tok: foo[q_tok, max_tok, max_tok, 0],
    sizes=[foo.shape[0], foo.shape[1]],
)
# DimensionBindError: Dim 'dim' previously bound to a dimension of size 3 cannot bind to a dimension of size 4
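
For reference, here is a minimal sketch of the result the call above appears to be asking for, written in plain PyTorch without `gbmi.utils.ein`. It assumes the intended semantics are `bar[q_tok, max_tok] = foo[q_tok, max_tok, max_tok, 0]` with `max_tok` ranging over `foo.shape[1]`, as the `sizes` argument suggests. The names `idx`, `bar_indexed`, and `bar_diagonal` are illustrative only. This is not a fix for `ein.array`, just a way to state the expected output and a possible workaround.

```python
import torch

foo = torch.randn(2, 3, 4, 5)

# Assumed intent: bar[q_tok, max_tok] = foo[q_tok, max_tok, max_tok, 0],
# with max_tok ranging over foo.shape[1] == 3 (the smaller of axes 1 and 2).
idx = torch.arange(foo.shape[1])

# Advanced indexing: passing the same index tensor for axes 1 and 2
# selects the (truncated) diagonal of those two axes.
bar_indexed = foo[:, idx, idx, 0]

# torch.diagonal also works on non-square slices; the diagonal has
# length min(3, 4) == 3 and is appended as the last axis.
bar_diagonal = torch.diagonal(foo[..., 0], dim1=1, dim2=2)

print(bar_indexed.shape, bar_diagonal.shape)
```

Both expressions yield a tensor of shape `(2, 3)`, so either could serve as a comparison target when checking a future fix.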