Mystery: `einops` works, but `view` does not?
fedebotu opened this issue
We encountered the following problem:
import torch
from einops import rearrange
num_heads = 8
a = torch.randn(512, 20, 128)
# einops
a_einops = rearrange(a, 'b n (h d) -> b h n d', h=num_heads)
# torch: view straight to (b, h, n, d)
batch_size, length, hidden_dim = a.shape
a_torch = a.view(batch_size, num_heads, length, -1)
print(a_einops.shape)  # torch.Size([512, 8, 20, 16])
print(a_torch.shape)   # torch.Size([512, 8, 20, 16])
print(torch.allclose(a_einops, a_torch))
False
Why is this the case? By substituting `view` with `einops`, attention works as it should.
Solution:
import torch
from einops import rearrange
num_heads = 8
b = 2
l = 10
h = num_heads * 3
a = torch.randn(b, l, h) # [B, L, H]
# einops
a_einops = rearrange(a, 'b l (h d) -> b h l d', h=num_heads)
# torch: split the last dim into (heads, d) first, then swap the length and heads dims
batch_size, length, hidden_dim = a.shape
a_torch = a.view(batch_size, length, num_heads, -1)  # (b, l, h, d)
a_torch = a_torch.transpose(1, 2)                    # (b, h, l, d)
print(a_einops.shape)  # torch.Size([2, 8, 10, 3])
print(a_torch.shape)   # torch.Size([2, 8, 10, 3])
print(torch.allclose(a_einops, a_torch))  # True
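For completeness, merging the heads back follows the same pattern in reverse. A minimal sketch reusing the tensors from the solution snippet above (the `merged_*` names are just illustrative):
# einops: fold the heads back into the hidden dimension
merged_einops = rearrange(a_einops, 'b h l d -> b l (h d)')
# torch: undo the transpose, make the data contiguous, then flatten (h, d)
merged_torch = a_torch.transpose(1, 2).contiguous().view(batch_size, length, -1)
print(torch.allclose(merged_einops, merged_torch))  # True
print(torch.allclose(merged_einops, a))             # True: recovers the original tensor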
OK, I was really dumb.
This was simply due to how `view` reinterprets the tensor's flat memory: reshaping straight to (batch, heads, length, d) mixes features from different tokens, whereas splitting into (batch, length, heads, d) and then transposing keeps each token's features together. Thanks @Junyoungpark !
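A tiny sketch (not from the original thread) makes the difference concrete, using one batch, two tokens, and two heads with two features each:
import torch
a = torch.arange(8).reshape(1, 2, 4)  # token 0 = [0, 1, 2, 3], token 1 = [4, 5, 6, 7]
# wrong: viewing straight to (b, h, l, d) just slices the flat memory in order
wrong = a.view(1, 2, 2, 2)
# right: split the hidden dim first, then move heads in front of the token dim
right = a.view(1, 2, 2, 2).transpose(1, 2)
print(wrong[0, 0])  # tensor([[0, 1], [2, 3]])  -> both rows come from token 0
print(right[0, 0])  # tensor([[0, 1], [4, 5]])  -> head 0 sees the first two features of each token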