google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

with_partitioning has surprising behavior with MultiHeadAttention and DenseGeneral

jameslyon opened this issue

Problem you have encountered:

from pprint import pprint

from flax import linen as nn
import jax
import jax.numpy as jnp

# Intent: shard the feature axis of each kernel/bias along the 'data' mesh axis.
m = nn.MultiHeadDotProductAttention(
    name='mha',
    num_heads=2,
    qkv_features=8,
    deterministic=True,
    kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), (None, 'data')),
    bias_init=nn.with_partitioning(nn.initializers.zeros, ('data',)),
)

# Abstract init: inspect parameter shapes and partitioning without allocating anything.
pprint(jax.eval_shape(m.init, jax.random.key(0), jax.ShapeDtypeStruct((1, 2, 8), jnp.float32)))

This gives the result:

{'params': {'key': {'bias': Partitioned(value=ShapeDtypeStruct(shape=(2, 4), dtype=float32),
                                        names=('data',),
                                        mesh=None),
                    'kernel': Partitioned(value=ShapeDtypeStruct(shape=(8, 2, 4), dtype=float32),
                                          names=(None, 'data'),
                                          mesh=None)},
            'out': {'bias': Partitioned(value=ShapeDtypeStruct(shape=(8,), dtype=float32),
                                        names=('data',),
                                        mesh=None),
                    'kernel': Partitioned(value=ShapeDtypeStruct(shape=(2, 4, 8), dtype=float32),
                                          names=(None, 'data'),
                                          mesh=None)},
            'query': {'bias': Partitioned(value=ShapeDtypeStruct(shape=(2, 4), dtype=float32),
                                          names=('data',),
                                          mesh=None),
                      'kernel': Partitioned(value=ShapeDtypeStruct(shape=(8, 2, 4), dtype=float32),
                                            names=(None, 'data'),
                                            mesh=None)},
            'value': {'bias': Partitioned(value=ShapeDtypeStruct(shape=(2, 4), dtype=float32),
                                          names=('data',),
                                          mesh=None),
                      'kernel': Partitioned(value=ShapeDtypeStruct(shape=(8, 2, 4), dtype=float32),
                                            names=(None, 'data'),
                                            mesh=None)}}}

The result means that the QKV kernels and biases are partitioned along the num_heads axis rather than the feature axis: each QKV kernel has shape (in_features, num_heads, head_dim) = (8, 2, 4), and the names (None, 'data') are applied to the leading axes, so 'data' lands on num_heads. That is a problem because num_heads is typically small and may not be a multiple of the mesh axis size. The 'out' variables, however, are partitioned as expected.
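To see the mismatch directly, the boxed shapes can be converted to PartitionSpecs with nn.get_partition_spec (a quick check against the output above, assuming that helper is available as in recent Flax releases; the "desired" spec in the comment is my reading of the intent, not something the current API produces):

shapes = jax.eval_shape(
    m.init, jax.random.key(0), jax.ShapeDtypeStruct((1, 2, 8), jnp.float32))
specs = nn.get_partition_spec(shapes)

# Actual:  PartitionSpec(None, 'data') on a rank-3 kernel shards axis 1 (num_heads).
# Desired: PartitionSpec(None, None, 'data'), sharding the trailing head_dim axis.
print(specs['params']['query']['kernel'])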

What you expected to happen:

The current API doesn't really have a clean solution to this problem, because nn.with_partitioning is under-specified when tensors of different ranks are created with the same initializer. I would like to be able to write a custom initializer that accounts for the rank of the tensor being initialized, but that doesn't actually work: DenseGeneral (which MultiHeadDotProductAttention uses internally) creates its variables at rank 2 (kernels) or rank 1 (biases) and reshapes them afterwards, so the initializer never sees the final shape. A sketch of such an initializer is below.
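For concreteness, this is the kind of rank-aware wrapper I mean (hypothetical, not a Flax API; it mirrors what nn.with_partitioning does but derives the names from the shape the initializer is called with):

import jax.numpy as jnp
from flax import linen as nn

def with_rank_aware_partitioning(init_fn, last_axis_name='data'):
    # Hypothetical wrapper: put `last_axis_name` on the final axis of
    # whatever shape the initializer receives, replicating the rest.
    def init(key, shape, dtype=jnp.float32):
        names = (None,) * (len(shape) - 1) + (last_axis_name,)
        return nn.Partitioned(init_fn(key, shape, dtype), names=names)
    return init

Inside DenseGeneral this never sees the rank-3 kernel shape: the initializer is always called with the flattened 2-D kernel shape (rank 1 for biases) and the result is reshaped afterwards, so names derived from `shape` cannot target the head_dim axis.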

So I don't really have a firm expectation here: I would like some way for nn.with_partitioning to specify MHSA partitioning accurately, but I'm not sure how to get there.
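One stopgap (a sketch only, assuming the left-aligned names convention shown above and the .replace() helper that flax.struct dataclasses provide) is to rewrite the boxed names after init, left-padding with None so 'data' ends up on the trailing axis of any box whose names tuple is shorter than its value's rank:

import jax
from flax import linen as nn

def repad_partition_names(params):
    # Left-pad under-specified Partitioned names with None, e.g.
    # (None, 'data') on a rank-3 QKV kernel becomes (None, None, 'data'),
    # moving 'data' from num_heads onto head_dim.
    def fix(box):
        if isinstance(box, nn.Partitioned) and len(box.names) < box.value.ndim:
            pad = (None,) * (box.value.ndim - len(box.names))
            return box.replace(names=pad + tuple(box.names))
        return box

    return jax.tree_util.tree_map(
        fix, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))

Note this also rewrites the 'out' kernel's annotation, so in practice you would probably restrict it to the query/key/value subtrees; it is a workaround, not a fix for the underlying under-specification.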

I'm working on this in #3893. Sadly, it might take a while to merge because it's a breaking change.