DenseGeneral with more than 2 dimensions cannot be partitioned
kvablack opened this issue · comments
DenseGeneral first computes a 2-dimensional "flat shape" to initialize its kernel (see this line) and then later reshapes the kernel to the correct shape. However, the partitioning API (e.g., nn.with_partitioning) works by wrapping the kernel initializer. So if you have a DenseGeneral layer with a 3D kernel, and you try to partition this kernel, e.g.,
dense = partial(
    nn.DenseGeneral,
    axis=-1,
    dtype=self.dtype,
    features=(self.num_heads, head_dim),
    kernel_init=nn.with_logical_partitioning(
        self.kernel_init, ("embed", "heads", "head_dim")
    ),
)
JAX throws an error complaining about a 3D sharding constraint on a 2D array, because Flax tries to apply the constraint before the reshape (see this line).
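To make the mismatch concrete, here is a toy model in pure Python. It is not Flax's actual implementation: `flat_kernel_shape`, `with_names`, and `shape_only_init` are hypothetical stand-ins, and the sizes (512, 8, 64) are made up. It only assumes what is described above, namely that the initializer is called with the flattened 2D shape and that the partitioning wrapper attaches its axis names at that point.

```python
import math


def flat_kernel_shape(input_dims, features):
    """Collapse the contracted input dims and the output feature dims into 2D."""
    return (math.prod(input_dims), math.prod(features))


def with_names(init_fn, names):
    """Hypothetical stand-in for a partitioning wrapper around an initializer.

    It pairs the initialized value with axis names, so the number of names
    must match the rank of the shape the initializer is called with.
    """
    def wrapped(shape):
        if len(shape) != len(names):
            raise ValueError(
                f"{len(names)} axis names {names} given for a "
                f"{len(shape)}D array of shape {shape}"
            )
        return init_fn(shape), names
    return wrapped


def shape_only_init(shape):
    """Placeholder initializer: records the requested shape, allocates nothing."""
    return {"shape": shape}


# Made-up sizes for illustration.
embed, num_heads, head_dim = 512, 8, 64
init = with_names(shape_only_init, ("embed", "heads", "head_dim"))

# What the layer asks the initializer for: a flattened 2D kernel.
flat = flat_kernel_shape((embed,), (num_heads, head_dim))
# What the axis names describe: the final 3D kernel.
full = (embed, num_heads, head_dim)

try:
    init(flat)
    flat_error = None
except ValueError as err:
    flat_error = str(err)

value, names = init(full)
```

In this model the wrapped initializer fails on the flattened shape and succeeds on the full shape, which suggests that either initializing at the final shape or attaching the names after the reshape would avoid the error.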
@kvablack it's a very good point. I don't know what the original reason for this design was. I'll ask the team and get back to you.