How is `MultiHeadDotProductAttention` implemented?
ayaka14732 opened this issue
Ayaka commented
Problem you have encountered:
The output of flax.linen.MultiHeadDotProductAttention does not match what I compute by hand, and I could not find a formula in the documentation that spells out exactly what it computes. I expect standard scaled dot-product attention, softmax(q @ k^T / sqrt(d_k)) @ v per head followed by an output projection, as in the manual calculation below.
Test script:
import flax.linen as nn
import jax.numpy as np
batch_size = 1
max_sent_len = 2
n_heads = 1
d_model = 3
d_k = 4
d_v = 4
d_ff = 4
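# source sequence: all ones except src[0, 0, 0] = 2, so the two positions differ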
src = np.ones((batch_size, max_sent_len, d_model))
src = src.at[0, 0, 0].set(2.)
dst = np.ones((batch_size, max_sent_len, d_model))
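# per-head projection kernels of all ones and zero biases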
q_a = np.ones((n_heads, d_model, d_k))
k_a = np.ones((n_heads, d_model, d_k))
v_a = np.ones((n_heads, d_model, d_v))
q_b = np.zeros((n_heads, d_k))
k_b = np.zeros((n_heads, d_k))
v_b = np.zeros((n_heads, d_v))
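# output projection: identity kernel and zero bias, i.e. a pass-through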
ff_a = np.eye(n_heads * d_v, d_ff)
ff_b = np.zeros((d_ff))
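# all-ones padding mask, broadcast to shape (batch, 1, q_len, kv_len)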
mask_dec_1d = np.ones((batch_size, max_sent_len), dtype=np.bool_)
mask = np.einsum('bi,bj->bij', mask_dec_1d, mask_dec_1d)[:, None]
model = nn.MultiHeadDotProductAttention(num_heads=n_heads, qkv_features=d_k * n_heads, out_features=d_ff, broadcast_dropout=False)
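# load the hand-constructed parameters; Flax stores the q/k/v kernels with shape
# (features, n_heads, head_dim) and the out kernel with shape (n_heads, head_dim, out_features)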
output = model.apply({'params': {
'query': {'kernel': q_a.transpose(1, 0, 2), 'bias': q_b},
'key': {'kernel': k_a.transpose(1, 0, 2), 'bias': k_b},
'value': {'kernel': v_a.transpose(1, 0, 2), 'bias': v_b},
'out': {'kernel': ff_a.reshape(n_heads, d_v, d_ff), 'bias': ff_b},
}}, src, dst, mask=mask)
print(output)
Output:
$ python main.py
[[[3. 3. 3. 3.]
[3. 3. 3. 3.]]]
What you expected to happen:
Manual calculation:
src
[[[2. 1. 1.]
[1. 1. 1.]]]
dst
[[[1. 1. 1.]
[1. 1. 1.]]]
q_a
[[[1. 1. 1. 1.]
[1. 1. 1. 1.]
[1. 1. 1. 1.]]]
k_a
[[[1. 1. 1. 1.]
[1. 1. 1. 1.]
[1. 1. 1. 1.]]]
v_a
[[[1. 1. 1. 1.]
[1. 1. 1. 1.]
[1. 1. 1. 1.]]]
ff_a
[[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]]
ff_b
[0. 0. 0. 0.]
q = np.dot(dst, q_a)
[[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]]
k = np.dot(src, k_a)
[[[[4. 4. 4. 4.]]
[[3. 3. 3. 3.]]]]
v = np.dot(src, v_a)
[[[[4. 4. 4. 4.]]
[[3. 3. 3. 3.]]]]
qk = np.einsum('bkhm,bvhm->bhkv', q, k)
[[[[48. 36.]
[48. 36.]]]]
qk = qk / np.sqrt(d_k)
[[[[24. 18.]
[24. 18.]]]]
qk = nn.softmax(qk)
[[[[0.9975274 0.00247262]
[0.9975274 0.00247262]]]]
t = np.einsum('bhkv,bvhm->bkhm', qk, v)
[[[[3.9917908 3.9917908 3.9917908 3.9917908]]
[[3.9917908 3.9917908 3.9917908 3.9917908]]]]
d0, d1, d2, d3 = t.shape; t = t.reshape(d0, d1, d2 * d3)
[[[3.9917908 3.9917908 3.9917908 3.9917908]
[3.9917908 3.9917908 3.9917908 3.9917908]]]
The output projection is the identity (ff_a is the identity matrix, ff_b is zero), so the expected output equals t. Equivalent script for the manual calculation:
def fwd_attention(q_kernel: np.ndarray, q_bias: np.ndarray,
                  k_kernel: np.ndarray, k_bias: np.ndarray,
                  v_kernel: np.ndarray, v_bias: np.ndarray,
                  fc1_kernel: np.ndarray, fc1_bias: np.ndarray,
                  src: np.ndarray, dst: np.ndarray, mask: np.ndarray) -> np.ndarray:
    _, _, d_k = q_kernel.shape
    # project: queries from dst, keys and values from src
    q = np.dot(dst, q_kernel) + q_bias
    k = np.dot(src, k_kernel) + k_bias
    v = np.dot(src, v_kernel) + v_bias
    # scaled dot-product attention scores, masked before and after the softmax
    qk = np.einsum('bkhm,bvhm->bhkv', q, k)
    qk = qk / np.sqrt(d_k)
    qk = np.where(mask, qk, np.NINF)
    qk = nn.softmax(qk)
    qk = np.where(mask, qk, 0)
    # weighted sum of values, then concatenate the heads
    t = np.einsum('bhkv,bvhm->bkhm', qk, v)
    d0, d1, d2, d3 = t.shape
    t = t.reshape(d0, d1, d2 * d3)
    # final output projection
    t = fwd_linear(fc1_kernel, fc1_bias, t)
    return t
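(fwd_linear is not shown above; it is just a dense layer. A minimal version consistent with how it is called here would be:)
def fwd_linear(kernel: np.ndarray, bias: np.ndarray, x: np.ndarray) -> np.ndarray:
    # plain affine transform: contract the feature axis of x with kernel, then add bias
    return np.dot(x, kernel) + bias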
Expected output
[[[3.9917908 3.9917908 3.9917908 3.9917908]
[3.9917908 3.9917908 3.9917908 3.9917908]]]
Actual output
[[[3. 3. 3. 3.]
[3. 3. 3. 3.]]]
System information
- OS Platform and Distribution: Ubuntu 20.04.4 LTS x86_64
- Flax version: 0.4.0
- JAX version: 0.3.4
- jaxlib version: 0.3.2
- Python version: 3.10.2
- GPU/TPU model and memory: TPU
- libtpu version: libtpu-nightly 0.1.dev20220315
jheek commented
I think you have to reverse the src and dst arguments in model.apply: inputs_q should be dst and inputs_kv should be src.
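Untested, but something along these lines should then match your manual calculation (same parameters, only the two positional inputs swapped):
output = model.apply({'params': {
    'query': {'kernel': q_a.transpose(1, 0, 2), 'bias': q_b},
    'key': {'kernel': k_a.transpose(1, 0, 2), 'bias': k_b},
    'value': {'kernel': v_a.transpose(1, 0, 2), 'bias': v_b},
    'out': {'kernel': ff_a.reshape(n_heads, d_v, d_ff), 'bias': ff_b},
}}, dst, src, mask=mask)  # inputs_q = dst, inputs_kv = src
print(output)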