google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io


flax.linen.module.init still fails under dynamic type checking for nested modules

evangelos-ch opened this issue.

Related issue: #3224

While the snippet posted in that issue now works, there still seems to be a failure mode when nested modules are used and all of them are runtime type checked (each decorated with @jaxtyped(typechecker=beartype)).


```python
from jax import numpy as jnp
import flax.linen as nn
import jax
from beartype import beartype
from jaxtyping import jaxtyped


@jaxtyped(typechecker=beartype)
class MyModuleInternal(nn.Module):
    hidden_size: int = 2

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.hidden_size)(x)


@jaxtyped(typechecker=beartype)
class MyModule(nn.Module):
    hidden_dim: int

    def setup(self) -> None:
        self.internal_module = MyModuleInternal(self.hidden_dim)  # <-- failure here

    def __call__(self, x):
        x = self.internal_module(x)
        return x


model = MyModule(5)

params = model.init(
    rngs={"params": jax.random.PRNGKey(0)},
    x=jnp.ones((1, 1)),
)
```

This snippet fails with the following error:

```
---------------------------------------------------------------------------
BeartypeCallHintParamViolation            Traceback (most recent call last)
    [... skipping hidden 1 frame]

<@beartype(__main__.check_params) at 0x7d4c408e2830> in check_params(__beartype_get_violation, __beartype_conf, __beartype_object_137766497224064, __beartype_object_99821132912832, __beartype_object_99821132891488, __beartype_object_137766477140992, __beartype_func, *args, **kwargs)

BeartypeCallHintParamViolation: Method __main__.check_params() parameter parent="MyModule(
    # attributes
    hidden_dim = 5
)" violates type hint typing.Union[typing.Type[flax.linen.module.Module], flax.core.scope.Scope, typing.Type[flax.linen.module._Sentinel], NoneType]
```

Looking at nn.Module's `_ParentType`, the `parent` argument is indeed annotated as `Type[nn.Module]`, i.e. it expects a class, whereas what is actually passed in is an instance of `nn.Module` (here, the `MyModule` instance that owns the submodule). This appears to be the same root cause as the previously reported instance of this issue in #3224: the PR that fixed it (#3371) changed the annotation from `Type[Scope]` to simply `Scope`, so that an instance rather than a class is expected.