lucidrains / alphafold2

To eventually become an unofficial Pytorch implementation / replication of Alphafold2, as details of the architecture get released


The definition of row-wise attention and col-wise attention

CiaoHe opened this issue · comments

Hi Luci,
Sorry, it's me again.
I'm confused by the definitions of row-wise and col-wise attention here:

```python
if self.row_attn:
    w_x = rearrange(x, 'b h w d -> (b w) h d')

if self.col_attn:
    h_x = rearrange(x, 'b h w d -> (b h) w d')
```

Based on my understanding, the row-wise tensor `w_x` should be rearranged as `(b h) w d`: once you fetch one row, that row has w (width) units for attention to run over.

So maybe these two definitions need to be swapped?
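To make the question concrete, here is a minimal sketch (using plain numpy reshapes instead of einops, with made-up shapes) of what each rearrange does: folding `h` into the batch leaves sequences of length `w`, so attention over that tensor mixes positions along a row; folding `w` into the batch leaves sequences of length `h`, so attention mixes positions along a column.

```python
import numpy as np

# Hypothetical shapes: batch b=2, rows h=3, cols w=5, feature dim d=4.
b, h, w, d = 2, 3, 5, 4
x = np.random.rand(b, h, w, d)

# 'b h w d -> (b h) w d': fold the row index into the batch.
# Each of the b*h sequences has length w, so attention on this tensor
# runs ALONG a row (across the width axis).
row_folded = x.reshape(b * h, w, d)

# 'b h w d -> (b w) h d': fold the column index into the batch.
# After swapping h and w, each of the b*w sequences has length h,
# so attention on this tensor runs ALONG a column.
col_folded = x.transpose(0, 2, 1, 3).reshape(b * w, h, d)

print(row_folded.shape)  # (6, 5, 4)
print(col_folded.shape)  # (10, 3, 4)

# Sanity check: sequence 1 of row_folded is exactly row 1 of sample 0,
# and sequence 1 of col_folded is exactly column 1 of sample 0.
assert np.array_equal(row_folded[1], x[0, 1])
assert np.array_equal(col_folded[1], x[0, :, 1])
```

So under this reading, `(b h) w d` is the fold that attends within a row, which is why the snippet above looks inverted.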