mu parametrization for channel attention
xwjabc opened this issue
Hi, I have another question about the mu parametrization for a special attention mechanism - channel attention.
In standard scaled dot-product attention (also regarded as spatial attention), we have Q, K, V with shape n x d (ignoring heads) and we calculate softmax(scale * Q K^T) V to get an n x d output, where scale = 1/sqrt(d) in SP and scale = 1/d in muP (or 1/sqrt(d_0) / width_mult in muP for backward compatibility).
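For concreteness, here is a minimal PyTorch-style sketch of the spatial attention I am describing (heads ignored; the names spatial_attention, d0, and mup are just for illustration, with width_mult = d / d0):

```python
import torch
import torch.nn.functional as F

def spatial_attention(Q, K, V, d0=None, mup=False):
    """Q, K, V: (n, d). Returns an (n, d) output."""
    n, d = Q.shape
    if not mup:
        scale = 1.0 / d ** 0.5                   # SP: 1/sqrt(d)
    elif d0 is None:
        scale = 1.0 / d                          # muP: 1/d
    else:
        width_mult = d / d0
        scale = 1.0 / d0 ** 0.5 / width_mult     # muP, backward compatible: 1/sqrt(d_0) / width_mult
    return F.softmax(scale * (Q @ K.T), dim=-1) @ V  # (n, n) attention map, (n, d) output
```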
In channel attention, we still have Q, K, V with shape n x d (ignoring heads). The difference is that we calculate (softmax(scale * Q^T K) V^T)^T to get an n x d output, where scale = 1/sqrt(n) in SP. Since the attention map Q^T K now has shape d x d instead of n x n, I am not sure how the scale should be modified in muP accordingly. Should we use 1/sqrt(n) / width_mult?
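For concreteness, here is a minimal sketch of the channel attention computation with the same shapes as above (scale is the quantity in question):

```python
import torch
import torch.nn.functional as F

def channel_attention(Q, K, V, scale):
    """Q, K, V: (n, d). The attention map Q^T K is (d, d) rather than (n, n)."""
    attn = F.softmax(scale * (Q.T @ K), dim=-1)  # (d, d)
    return (attn @ V.T).T                        # attn @ V^T is (d, n); transpose back to (n, d)
```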
In addition, Appendix B - Matrix-Like, Vector-Like, Scalar-Like Parameters offers some interpretation behind the scale: a multiplier of order 1/fan_in should accompany any weight that maps an infinite dimension to a finite one. This interpretation nicely covers both the output logits and the attention logits (i.e. the 1/d attention scale). However, it does not seem to directly give guidance for setting the scale in channel attention.
Thanks!
Since n (which I assume to be the batch size) is finite, the coordinates of Q^T K are \Theta(1) in d. The matmul of Q^T K and V^T, however, involves a summation over d, which needs to be scaled down by 1/d. So you want something like (softmax(scale * Q^T K) V^T / width_mult)^T.
Another way to look at this is that we are mapping a tensor with two inf dimensions (Q^T K) to a tensor with just one inf dimension; hence, we need to scale by 1/fan_in after this mapping.
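As a rough sketch of what I mean (assuming width_mult = d / d_0, the same width multiplier as for the attention logits; the softmax scale can keep its SP value since n is finite):

```python
import torch
import torch.nn.functional as F

def mup_channel_attention(Q, K, V, width_mult):
    """Q, K, V: (n, d), with n finite and d the (infinite) width."""
    n = Q.shape[0]
    scale = 1.0 / n ** 0.5                       # unchanged from SP, since n is finite
    attn = F.softmax(scale * (Q.T @ K), dim=-1)  # (d, d); coordinates are Theta(1) in d
    # attn @ V^T sums over d coordinates, so scale the result down by 1/width_mult (~ 1/d).
    return (attn @ V.T / width_mult).T           # (n, d)
```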
Yes, n is the number of tokens, which is finite, whereas d is the feature dimension, which is infinite. I think I got it; if I understood correctly, what I need should be (softmax(scale * Q^T K) V^T / width_mult)^T = (softmax(1/sqrt(n) * Q^T K) V^T / width_mult)^T.
Thank you for your quick response!
In addition, if we scale n_head instead of d_head in channel attention, does it mean that we can simply use the original (softmax(scale * Q^T K) V^T)^T? Thanks!
Sorry for the delay. I missed this earlier.
For a fixed d_head, the original formulation looks good as long as scale isn't a function of d_head, since we don't have a summation over infinitely many coordinates anymore.
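For concreteness, a sketch of the per-head computation under these assumptions (d_head fixed, width scaled via n_head):

```python
import torch
import torch.nn.functional as F

def channel_attention_heads(Q, K, V, scale):
    """Q, K, V: (n_head, n, d_head), with d_head fixed and width scaled via n_head."""
    attn = F.softmax(scale * (Q.transpose(-2, -1) @ K), dim=-1)  # (n_head, d_head, d_head)
    out = (attn @ V.transpose(-2, -1)).transpose(-2, -1)         # (n_head, n, d_head)
    # Both matmuls sum over finite dimensions (n, then d_head), so no 1/width_mult
    # correction is needed inside the attention itself.
    return out
```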
Gotcha. Thank you, Edward!