google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io


How is `MultiHeadDotProductAttention` implemented?

ayaka14732 opened this issue · comments

ayaka14732 commented

Problem you have encountered:

The output of flax.linen.MultiHeadDotProductAttention is not what I expect, and I can't work out how it is computed. I checked the documentation, but it does not give a formula for what the module computes.
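
For reference, the computation I am assuming is standard scaled dot-product attention per head, followed by concatenating the heads and an output projection; as a rough sketch (my own notation, not taken from the Flax docs):

import jax.numpy as jnp
from jax.nn import softmax

def scaled_dot_product_attention(q, k, v):
    # softmax(q @ k^T / sqrt(d_k)) @ v, applied per batch and per head
    d_k = q.shape[-1]
    scores = jnp.einsum('...qd,...kd->...qk', q, k) / jnp.sqrt(d_k)
    return jnp.einsum('...qk,...kd->...qd', softmax(scores), v)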

Test script:

import flax.linen as nn
import jax.numpy as np

batch_size = 1
max_sent_len = 2
n_heads = 1
d_model = 3
d_k = 4
d_v = 4
d_ff = 4

src = np.ones((batch_size, max_sent_len, d_model))
src = src.at[0, 0, 0].set(2.)
dst = np.ones((batch_size, max_sent_len, d_model))

q_a = np.ones((n_heads, d_model, d_k))
k_a = np.ones((n_heads, d_model, d_k))
v_a = np.ones((n_heads, d_model, d_v))

q_b = np.zeros((n_heads, d_k))
k_b = np.zeros((n_heads, d_k))
v_b = np.zeros((n_heads, d_v))

ff_a = np.eye(n_heads * d_v, d_ff)
ff_b = np.zeros((d_ff))

mask_dec_1d = np.ones((batch_size, max_sent_len), dtype=np.bool_)
mask = np.einsum('bi,bj->bij', mask_dec_1d, mask_dec_1d)[:, None]

model = nn.MultiHeadDotProductAttention(num_heads=n_heads, qkv_features=d_k * n_heads, out_features=d_ff, broadcast_dropout=False)

output = model.apply({'params': {
    'query': {'kernel': q_a.transpose(1, 0, 2), 'bias': q_b},
    'key': {'kernel': k_a.transpose(1, 0, 2), 'bias': k_b},
    'value': {'kernel': v_a.transpose(1, 0, 2), 'bias': v_b},
    'out': {'kernel': ff_a.reshape(n_heads, d_v, d_ff), 'bias': ff_b},
}}, src, dst, mask=mask)

print(output)

Output:

$ python main.py
[[[3. 3. 3. 3.]
  [3. 3. 3. 3.]]]
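
For reference, the kernel/bias shapes used in the params dict above can be cross-checked against what model.init produces; a quick sanity check (not part of the repro):

import jax

variables = model.init(jax.random.PRNGKey(0), src, dst, mask=mask)
print(jax.tree_map(lambda x: x.shape, variables['params']))
# With n_heads=1, d_model=3, d_k=d_v=4, d_ff=4 this should show shapes like
#   query/key/value kernel: (3, 1, 4), bias: (1, 4)
#   out kernel: (1, 4, 4), bias: (4,)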

What you expected to happen:

Manual calculation:

src
[[[2. 1. 1.]
  [1. 1. 1.]]]

dst
[[[1. 1. 1.]
  [1. 1. 1.]]]

q_a
[[[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]]

k_a
[[[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]]

v_a
[[[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]]

ff_a
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]

ff_b
[0. 0. 0. 0.]

q = np.dot(dst, q_a)
[[[[3. 3. 3. 3.]]
  [[3. 3. 3. 3.]]]]

k = np.dot(src, k_a)
[[[[4. 4. 4. 4.]]
  [[3. 3. 3. 3.]]]]

v = np.dot(src, v_a)
[[[[4. 4. 4. 4.]]
  [[3. 3. 3. 3.]]]]

qk = np.einsum('bkhm,bvhm->bhkv', q, k)
[[[[48. 36.]
   [48. 36.]]]]

qk = qk / np.sqrt(d_k)
[[[[24. 18.]
   [24. 18.]]]]

qk = nn.softmax(qk)
[[[[0.9975274  0.00247262]
   [0.9975274  0.00247262]]]]

t = np.einsum('bhkv,bvhm->bkhm', qk, v)
[[[[3.9917908 3.9917908 3.9917908 3.9917908]]
  [[3.9917908 3.9917908 3.9917908 3.9917908]]]]

d0, d1, d2, d3 = t.shape; t = t.reshape(d0, d1, d2 * d3)
[[[3.9917908 3.9917908 3.9917908 3.9917908]
  [3.9917908 3.9917908 3.9917908 3.9917908]]]

Equivalent script for calculation:

def fwd_attention(
    q_kernel: np.ndarray, q_bias: np.ndarray,
    k_kernel: np.ndarray, k_bias: np.ndarray,
    v_kernel: np.ndarray, v_bias: np.ndarray,
    fc1_kernel: np.ndarray, fc1_bias: np.ndarray,
    src: np.ndarray, dst: np.ndarray, mask: np.ndarray,
) -> np.ndarray:
    _, _, d_k = q_kernel.shape

    # Project queries from dst, keys/values from src
    q = np.dot(dst, q_kernel) + q_bias
    k = np.dot(src, k_kernel) + k_bias
    v = np.dot(src, v_kernel) + v_bias

    # Scaled dot-product attention with masking
    qk = np.einsum('bkhm,bvhm->bhkv', q, k)
    qk = qk / np.sqrt(d_k)
    qk = np.where(mask, qk, np.NINF)
    qk = nn.softmax(qk)
    qk = np.where(mask, qk, 0)

    # Weighted sum of values, then concatenate the heads
    t = np.einsum('bhkv,bvhm->bkhm', qk, v)
    d0, d1, d2, d3 = t.shape
    t = t.reshape(d0, d1, d2 * d3)

    t = fwd_linear(fc1_kernel, fc1_bias, t)
    return t
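
(fwd_linear is not defined in the snippet above; presumably it is just the plain affine output projection, something like the sketch below, with ff_a passed in its flat (n_heads * d_v, d_ff) form:)

def fwd_linear(kernel: np.ndarray, bias: np.ndarray, x: np.ndarray) -> np.ndarray:
    # plain affine map: x @ kernel + bias
    return np.dot(x, kernel) + bias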

Expected output

[[[3.9917908 3.9917908 3.9917908 3.9917908]
  [3.9917908 3.9917908 3.9917908 3.9917908]]]

Actual output

[[[3. 3. 3. 3.]
  [3. 3. 3. 3.]]]

System information

  • OS Platform and Distribution: Ubuntu 20.04.4 LTS x86_64
  • Flax version: 0.4.0
  • JAX version: 0.3.4
  • jaxlib version: 0.3.2
  • Python version: 3.10.2
  • GPU/TPU model and memory: TPU
  • libtpu version: libtpu-nightly 0.1.dev20220315
jheek commented

I think you have to reverse the src and dst arguments in model.apply: inputs_q = dst and inputs_kv = src.
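
That is, keeping everything else in the test script unchanged, something like:

output = model.apply({'params': {
    'query': {'kernel': q_a.transpose(1, 0, 2), 'bias': q_b},
    'key': {'kernel': k_a.transpose(1, 0, 2), 'bias': k_b},
    'value': {'kernel': v_a.transpose(1, 0, 2), 'bias': v_b},
    'out': {'kernel': ff_a.reshape(n_heads, d_v, d_ff), 'bias': ff_b},
}}, dst, src, mask=mask)  # inputs_q=dst, inputs_kv=src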

ayaka14732 commented

@jheek Thank you very much! Never thought I'd spend 12 hours debugging a bug caused by two swapped arguments 😅