google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

DenseGeneral with more than 2 dimensions cannot be partitioned

kvablack opened this issue

DenseGeneral first computes a 2-dimensional "flat shape" to initialize its kernel (see this line) and only later reshapes the kernel to its full shape. However, the partitioning API (e.g., nn.with_partitioning) works by wrapping the kernel initializer, so the wrapper sees the flat 2D shape rather than the final one. As a result, if you have a DenseGeneral layer with a 3D kernel and you try to partition that kernel, e.g.,

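from functools import partial

import flax.linen as nn

# Inside a module that defines dtype, num_heads and kernel_init (head_dim is a local).
dense = partial(
    nn.DenseGeneral,
    axis=-1,
    dtype=self.dtype,
    features=(self.num_heads, head_dim),
    kernel_init=nn.with_logical_partitioning(
        self.kernel_init, ("embed", "heads", "head_dim")
    ),
)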

JAX throws an error complaining about a 3D sharding constraint on a 2D array, because Flax tries to apply the constraint before the reshape (see this line).
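The ordering problem can be reproduced without Flax. The sketch below is a plain-JAX mimic under stated assumptions: `annotate`, `wrap_init`, `kernel_annotate_then_reshape` and `kernel_reshape_then_annotate` are hypothetical helpers, and the rank check in `annotate` stands in for the sharding-constraint error JAX raises. It is not Flax's implementation. The second function shows one possible reordering (initialize flat, reshape, then attach axis names); whether Flax adopted this exact approach is not confirmed here.

```python
import math

import jax
import jax.numpy as jnp


# Hypothetical stand-in for partitioning metadata: pairs an array with logical
# axis names and rejects a rank mismatch (mimics the error in this issue).
def annotate(value, names):
    if value.ndim != len(names):
        raise ValueError(f"{len(names)} axis names for a rank-{value.ndim} array")
    return value, names


# Hypothetical wrapper that annotates whatever the wrapped initializer returns,
# in the spirit of nn.with_logical_partitioning wrapping kernel_init.
def wrap_init(init_fn, names):
    def wrapped(key, shape, dtype=jnp.float32):
        return annotate(init_fn(key, shape, dtype), names)
    return wrapped


def kernel_annotate_then_reshape(key, init_fn, names, in_features, features):
    # Order described in the issue: the wrapped initializer sees the 2D flat
    # shape, so three axis names are checked against a rank-2 array.
    flat_shape = (in_features, math.prod(features))
    value, names = wrap_init(init_fn, names)(key, flat_shape)
    return jnp.reshape(value, (in_features, *features)), names


def kernel_reshape_then_annotate(key, init_fn, names, in_features, features):
    # Alternative order: initialize with the flat shape, reshape to the full
    # kernel shape, and only then attach the axis names.
    flat_shape = (in_features, math.prod(features))
    value = jnp.reshape(init_fn(key, flat_shape), (in_features, *features))
    return annotate(value, names)


key = jax.random.PRNGKey(0)
init = jax.nn.initializers.lecun_normal()
names = ("embed", "heads", "head_dim")

try:
    kernel_annotate_then_reshape(key, init, names, 16, (4, 8))
except ValueError as err:
    print("annotate-then-reshape failed:", err)

kernel, axes = kernel_reshape_then_annotate(key, init, names, 16, (4, 8))
print(kernel.shape, axes)
```

The reordered version keeps the flat-shape initialization intact and only moves the point at which axis names are checked, so the annotation always matches the kernel's final rank.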

@kvablack it's a very good point. I don't know what the original reason for this design was; I'll ask the team and get back to you.
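One plausible motivation for the flat shape, offered as an assumption rather than the team's answer: JAX's variance-scaling initializers such as `jax.nn.initializers.lecun_normal` compute fan-in with `in_axis=-2` and `out_axis=-1` by default. On a 3D `(in_features, num_heads, head_dim)` kernel that treats `num_heads` as the input axis and `in_features` as a receptive field, so the fan-in becomes `num_heads * in_features` instead of `in_features`. The sketch below compares empirical variances; the sizes are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

in_features, num_heads, head_dim = 256, 8, 64
key = jax.random.PRNGKey(0)
init = jax.nn.initializers.lecun_normal()  # defaults: in_axis=-2, out_axis=-1

# Flat 2D kernel, as DenseGeneral builds it: fan_in = in_features = 256,
# so the variance should be roughly 1 / 256.
flat = init(key, (in_features, num_heads * head_dim))

# Same parameter count initialized directly in 3D: fan_in is computed as
# num_heads * in_features = 2048, so the variance is about num_heads times smaller.
full = init(key, (in_features, num_heads, head_dim))

print(float(jnp.var(flat)) * in_features)  # approximately 1.0
print(float(jnp.var(flat) / jnp.var(full)))  # approximately num_heads (8)
```

If this is the reason, it argues for keeping the flat-shape initialization and changing only when the partitioning metadata is applied, not for initializing the 3D kernel directly.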