ai4co / rl4co

A PyTorch library for all things Reinforcement Learning (RL) for Combinatorial Optimization (CO)

Home Page: https://rl4.co

Mystery: `einops` works, but `view` does not?

fedebotu opened this issue

We encountered the following problem:

import torch
from einops import rearrange

num_heads = 8

a = torch.randn(512, 20, 128)

# einops
a_einops = rearrange(a, 'b n (h d) -> b h n d', h=num_heads)

# torch
batch_size, length, hidden_dim = a.shape
a_torch = a.view(batch_size, num_heads, length, -1)

print(a_einops.shape)
print(a_torch.shape)

print(torch.allclose(a_einops, a_torch))

False

Why is this the case? After replacing `view` with einops' `rearrange`, attention works as it should.

Solution:

import torch
from einops import rearrange

num_heads = 8
b = 2
l = 10
h = num_heads * 3

a = torch.randn(b, l, h) # [B, L, H]

# einops
a_einops = rearrange(a, 'b l (h d) -> b h l d', h=num_heads)

# torch
batch_size, length, hidden_dim = a.shape
a_torch = a.view(batch_size, length, num_heads, -1)  # split the hidden dim: [B, L, H, D]
a_torch = a_torch.transpose(1, 2)  # swap the length and head axes: [B, H, L, D]

print(a_einops.shape)
print(a_torch.shape)

print(torch.allclose(a_einops, a_torch))  # True

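OK, I was being dumb.
This comes down to how PyTorch lays out tensor dimensions: `view` only reinterprets the existing row-major memory and never permutes axes, so the hidden dimension has to be split as [B, L, H, D] first and then transposed to [B, H, L, D]. Thanks @Junyoungpark!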
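To make the difference concrete, here is a minimal sketch on a tiny tensor with distinct values. The sizes (1 batch, 2 tokens, 2 heads of size 2) and the names `wrong` and `right` are chosen only for illustration and are not from the original code.

```python
import torch

# Tiny tensor with distinct values so the memory layout is easy to follow.
# Shape [B=1, L=2, H*D=4], to be split into 2 heads of size 2 (illustrative sizes).
num_heads = 2
x = torch.arange(8).reshape(1, 2, 4)
# token 0 -> [0, 1, 2, 3], token 1 -> [4, 5, 6, 7]

# Direct view to [B, H, L, D]: only re-chunks the flat row-major buffer.
wrong = x.view(1, num_heads, 2, -1)

# Split the last dimension into [B, L, H, D], then swap the L and H axes.
right = x.view(1, 2, num_heads, -1).transpose(1, 2)

print(wrong[0, 1].tolist())  # [[4, 5], [6, 7]]: "head 1" is really all of token 1
print(right[0, 1].tolist())  # [[2, 3], [6, 7]]: second half of each token's features
print(torch.equal(wrong, right))  # False
```

Note that `transpose` returns a non-contiguous tensor, so a later `view` on `right` would need `.contiguous()` first (or `reshape` instead).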