SysCV / idisc

iDisc: Internal Discretization for Monocular Depth Estimation [CVPR 2023]

Home Page: https://arxiv.org/abs/2304.06334

Transposed Cross Attention (page 4 in the paper)

MarioBgle opened this issue

Dear Piccinelli,

Thank you so much for your great work! iDisc is the best-performing depth estimator I have ever used.
I need help understanding how the transposed cross attention is implemented.

Let us assume $d_k=1$ and $\mathbf {V}= \mathbf{I}$ for clarity.
On page 4 of your paper, you define transposed cross attention as $\left[\text{softmax}(\mathbf{K} \mathbf{Q}^{T})\right]^{T}$. Compared with regular attention, $\text{softmax}(\mathbf{Q} \mathbf{K}^{T})$, the positions of the query and key tensors are swapped, and the resulting matrix is transposed after the softmax.

I believe transposed cross attention is being computed here. Setting sink_competition=True is the only difference from normal cross attention that I can find in attention.py. With sink_competition=True, however, Q and K are not swapped, whereas they are swapped in your formula for transposed cross attention. Below is a small example showing that my naive implementation of your transposed cross attention does not yield the same result as your code.
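For reference, assuming softmax is applied row-wise (over the last axis) as in standard attention, the two forms differ only in the axis the softmax normalizes over. Writing $\mathbf{S} = \mathbf{Q}\mathbf{K}^{T}$ (notation introduced here, with $i$ indexing queries and $j$ indexing keys), so that $\mathbf{K}\mathbf{Q}^{T} = \mathbf{S}^{T}$:

$$
\left[\text{softmax}(\mathbf{K}\mathbf{Q}^{T})\right]^{T}_{ij}
= \frac{\exp(S_{ij})}{\sum_{i'} \exp(S_{i'j})},
\qquad
\text{softmax}(\mathbf{Q}\mathbf{K}^{T})_{ij}
= \frac{\exp(S_{ij})}{\sum_{j'} \exp(S_{ij'})}.
$$

Swapping Q and K and transposing afterwards is therefore equivalent to keeping $\mathbf{Q}\mathbf{K}^{T}$ and normalizing over the query index $i$, which is what F.softmax(x, dim=-2) does in the sink_competition branch of the snippet below.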

Thanks in advance.

```python
import torch
import torch.nn.functional as F


def softmax(x, sink_competition=False):
    if sink_competition:
        # Softmax over the query axis (dim=-2) of Q K^T ...
        attn = F.softmax(x, dim=-2)
        # ... followed by a renormalization of each row over the key axis (dim=-1)
        attn = attn / torch.sum(attn, dim=(-1,), keepdim=True)
    else:
        # Standard attention: softmax over the last axis
        attn = F.softmax(x, dim=-1)
    return attn


# Toy tensors of shape (batch=2, tokens=2, channels=3)
q = torch.tensor([[[1, 2, 3], [4, 5, 6]],   # batch 1
                  [[7, 8, 9], [10, 11, 12]]], dtype=torch.float32)  # batch 2
k = torch.tensor([[[1, 0, 1], [0, 1, 1]],   # batch 1
                  [[1, 1, 0], [0, 1, 0]]], dtype=torch.float32)   # batch 2

# Naive transposed cross attention from the paper: [softmax(K Q^T)]^T
similarity_matrix_transposed_naive_implementation = torch.einsum("bid, bjd -> bij", k, q)  # K Q^T (transposed similarity matrix)
attention_transposed = softmax(similarity_matrix_transposed_naive_implementation, False).transpose(1, 2)

# Transposed cross attention as computed in the repository code: Q K^T with sink_competition=True
similarity_matrix_code = torch.einsum("bid, bjd -> bij", q, k)  # Q K^T
attention_code = softmax(similarity_matrix_code, True)

print("naive implementation similarity matrix (KQ^T):", similarity_matrix_transposed_naive_implementation)
print("naive implementation of transposed cross attention:", attention_transposed)
print("code implementation of transposed cross attention:", attention_code)
print("Difference between attention matrices:", torch.sum(torch.abs(attention_transposed - attention_code)))
```

Edit: I had overlooked that the line "attn = attn / torch.sum(attn, dim=(-1,), keepdim=True)" implements the normalization described on page 4, which is there to avoid vanishing or exploding values: "The weights $\mathbf{W}_{ij}^{t}$ may be normalized to 1 along the i dimension to avoid vanishing or exploding quantities due to the summation of un-normalized distribution."

If I remove this line, the naive implementation and the code implementation produce the same result.
Sorry for the inconvenience.
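The observation in the edit can be checked without PyTorch. The sketch below is a NumPy port of the snippet above (NumPy is swapped in for PyTorch as an assumption, so it runs without torch); the helper names softmax and sink_softmax and the renormalize flag are hypothetical and not part of the iDisc repository. It compares the naive formula against the sink-competition softmax with and without the extra row normalization.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def sink_softmax(sim: np.ndarray, renormalize: bool = True) -> np.ndarray:
    """Hypothetical NumPy mirror of softmax(x, sink_competition=True) above.

    `sim` has shape (batch, n_q, n_k) and holds Q K^T.
    """
    attn = softmax(sim, axis=-2)  # softmax over the query axis
    if renormalize:
        # Extra step: rescale each row so it sums to 1 over the key axis
        attn = attn / attn.sum(axis=-1, keepdims=True)
    return attn


q = np.array([[[1, 2, 3], [4, 5, 6]],
              [[7, 8, 9], [10, 11, 12]]], dtype=np.float64)
k = np.array([[[1, 0, 1], [0, 1, 1]],
              [[1, 1, 0], [0, 1, 0]]], dtype=np.float64)

qk = np.einsum("bid,bjd->bij", q, k)  # Q K^T, shape (batch, n_q, n_k)
kq = np.einsum("bid,bjd->bij", k, q)  # K Q^T, shape (batch, n_k, n_q)

# Paper's formula with d_k = 1 and V = I: [softmax(K Q^T)]^T
naive = softmax(kq, axis=-1).transpose(0, 2, 1)

without_norm = sink_softmax(qk, renormalize=False)
with_norm = sink_softmax(qk, renormalize=True)

print("max |naive - without_norm|:", np.abs(naive - without_norm).max())
print("max |naive - with_norm|:", np.abs(naive - with_norm).max())
```

With renormalize=False the two results agree up to floating-point error. With renormalize=True, each row of the sink-competition output sums to 1 over the key axis, a constraint that the plain formula $\left[\text{softmax}(\mathbf{K}\mathbf{Q}^{T})\right]^{T}$ does not impose.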