Transposed Cross Attention (page 4 in the paper)
MarioBgle opened this issue
Dear Piccinelli,
Thank you so much for your great work! iDisc is the best-performing depth estimator I have ever used.
I need help understanding how the transposed cross attention from page 4 of your paper is implemented. Let us assume queries q and keys k as in the snippet below. A naive implementation would compute K Q^T, take the softmax along the last dimension, and transpose the result, whereas the code computes Q K^T, takes the softmax along dim=-2 ("sink competition"), and then normalizes along dim=-1. As the snippet shows, the two resulting attention matrices differ.
Thanks in advance.
import torch
import torch.nn.functional as F


def softmax(x, sink_competition=False):
    if sink_competition:
        # Softmax along dim=-2 (queries compete for each key), then
        # re-normalize each row along dim=-1, as done in the code.
        attn = F.softmax(x, dim=-2)
        attn = attn / torch.sum(attn, dim=(-1,), keepdim=True)
    else:
        # Standard softmax along the last dimension.
        attn = F.softmax(x, dim=-1)
    return attn


q = torch.tensor([[[1, 2, 3], [4, 5, 6]],      # batch 1
                  [[7, 8, 9], [10, 11, 12]]],  # batch 2
                 dtype=torch.float32)
k = torch.tensor([[[1, 0, 1], [0, 1, 1]],      # batch 1
                  [[1, 1, 0], [0, 1, 0]]],     # batch 2
                 dtype=torch.float32)

# Naive implementation: K Q^T (transposed similarity matrix), softmax along
# the last dimension, then transpose the attention matrix back.
similarity_matrix_transposed_naive_implementation = torch.einsum("bid, bjd -> bij", k, q)
attention_transposed = softmax(similarity_matrix_transposed_naive_implementation, False).transpose(1, 2)

# Implementation as in the code: Q K^T with softmax along dim=-2
# ("sink competition"), followed by the normalization along dim=-1.
similarity_matrix_code = torch.einsum("bid, bjd -> bij", q, k)
attention_code = softmax(similarity_matrix_code, True)

print("naive implementation similarity matrix (KQ^T):", similarity_matrix_transposed_naive_implementation)
print("naive implementation of transposed cross attention:", attention_transposed)
print("code implementation of transposed cross attention:", attention_code)
print("Difference between attention matrices:", torch.sum(torch.abs(attention_transposed - attention_code)))
Edit: I just missed that "attn = attn / torch.sum(attn, dim=(-1,), keepdim=True)" belongs to the normalization that avoids vanishing gradients, which is explained on page 4: "The weights [are normalized] along the i dimension to avoid vanishing or exploding quantities due to the summation of un-normalized distribution."
If I remove this line, the naive implementation and the code implementation give the same result (see the check below).
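For completeness, here is a minimal self-contained check of this (tensor shapes and values are arbitrary and only illustrative): without the extra normalization, the softmax of Q K^T along dim=-2 is exactly the transpose of the softmax of K Q^T along the last dimension.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 4, 8)  # (batch, num_queries, dim)
k = torch.randn(2, 6, 8)  # (batch, num_keys, dim)

# Softmax over the query dimension of Q K^T ("sink competition" without the re-normalization) ...
sink_attn = F.softmax(torch.einsum("bid, bjd -> bij", q, k), dim=-2)
# ... equals the transposed softmax over the last dimension of K Q^T (the naive implementation).
naive_attn = F.softmax(torch.einsum("bid, bjd -> bij", k, q), dim=-1).transpose(1, 2)

print(torch.allclose(sink_attn, naive_attn, atol=1e-6))  # prints True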
Sorry for the inconvenience.
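In case it helps others, below is a minimal sketch of how the full transposed cross attention could be applied to values under the same convention (softmax along the query dimension followed by re-normalization along the key dimension). The function name, the 1/sqrt(d) scaling, and the tensor shapes are my own assumptions for illustration, not taken from the iDisc code.

import torch
import torch.nn.functional as F

def transposed_cross_attention(q, k, v):
    # q: (batch, num_queries, dim); k, v: (batch, num_keys, dim)
    sim = torch.einsum("bid, bjd -> bij", q, k) / q.shape[-1] ** 0.5  # scaled Q K^T (the scaling is an assumption)
    attn = F.softmax(sim, dim=-2)                     # queries compete for each key ("sink competition")
    attn = attn / attn.sum(dim=-1, keepdim=True)      # re-normalize along the key dimension
    return torch.einsum("bij, bjd -> bid", attn, v)   # aggregate the values for each query

q = torch.randn(2, 3, 16)
k = torch.randn(2, 10, 16)
v = torch.randn(2, 10, 16)
print(transposed_cross_attention(q, k, v).shape)  # torch.Size([2, 3, 16])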