flax.linen.module.init still fails under dynamic type checking for nested modules
evangelos-ch opened this issue
Related issue: #3224
The snippet posted in that issue works now, but there still seems to be a failure mode when nested modules (all of which are runtime type-checked) are used:
from jax import numpy as jnp
import flax.linen as nn
import jax
from beartype import beartype
from jaxtyping import jaxtyped


@jaxtyped(typechecker=beartype)
class MyModuleInternal(nn.Module):
    hidden_size: int = 2

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.hidden_size)(x)


@jaxtyped(typechecker=beartype)
class MyModule(nn.Module):
    hidden_dim: int

    def setup(self) -> None:
        self.internal_module = MyModuleInternal(self.hidden_dim)  # <-- failure here

    def __call__(self, x):
        x = self.internal_module(x)
        return x


model = MyModule(5)
params = model.init(
    rngs={"params": jax.random.PRNGKey(0)},
    x=jnp.ones((1, 1)),
)
This snippet fails with the following error:
---------------------------------------------------------------------------
BeartypeCallHintParamViolation Traceback (most recent call last)
[... skipping hidden 1 frame]
<@beartype(__main__.check_params) at 0x7d4c408e2830> in check_params(__beartype_get_violation, __beartype_conf, __beartype_object_137766497224064, __beartype_object_99821132912832, __beartype_object_99821132891488, __beartype_object_137766477140992, __beartype_func, *args, **kwargs)
BeartypeCallHintParamViolation: Method __main__.check_params() parameter parent="MyModule(
# attributes
hidden_dim = 5
)" violates type hint typing.Union[typing.Type[flax.linen.module.Module], flax.core.scope.Scope, typing.Type[flax.linen.module._Sentinel], NoneType]
Looking at nn.Module's _ParentType, the parent argument is indeed annotated as Type[nn.Module], i.e. it expects a class rather than an instance of nn.Module, whereas an instance is what actually gets passed in. This seems to have been the cause of the previously reported instance of this issue in #3224: the PR that fixed it (#3371) changed the annotation from Type[Scope] to simply Scope, so that an instance rather than a class is expected.
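To illustrate the mismatch without depending on Flax or beartype, here is a minimal sketch. The classes Module, Scope and _Sentinel below are hypothetical stand-ins (not the real Flax classes), and matches() is a hypothetical helper that only handles Union, Type[...], None and plain classes; it is not how beartype works internally. It shows that a Type[Module] hint is satisfied by the class object itself but not by an instance, which is what a nested module's parent is, and that swapping Type[Module] for Module (analogous to the Type[Scope] -> Scope change) accepts the instance.

```python
import typing
from typing import Type, Union


# Hypothetical stand-ins for flax.linen.Module, flax.core.scope.Scope and
# flax.linen.module._Sentinel; these are NOT the real Flax classes.
class Module: ...
class Scope: ...
class _Sentinel: ...


# Mirrors the hint shown in the traceback: the Module member is Type[Module].
ParentTypeAsReported = Union[Type[Module], Scope, Type[_Sentinel], None]

# Hypothetical adjusted hint expecting a Module *instance*, analogous to the
# Type[Scope] -> Scope change described for #3371.
ParentTypeInstance = Union[Module, Scope, Type[_Sentinel], None]


def matches(value, hint) -> bool:
    """Tiny runtime check covering only Union, Type[...], None and plain classes."""
    if hint is None or hint is type(None):
        return value is None
    origin = typing.get_origin(hint)
    if origin is Union:
        # A Union matches if any of its members matches.
        return any(matches(value, arg) for arg in typing.get_args(hint))
    if origin is type:
        # Type[C] matches the class C or a subclass, never an instance.
        (cls,) = typing.get_args(hint)
        return isinstance(value, type) and issubclass(value, cls)
    # Plain class: match instances.
    return isinstance(value, hint)


class MyModule(Module): ...


parent = MyModule()  # a nested module receives its parent as an instance
print(matches(parent, ParentTypeAsReported))    # False: instance vs Type[Module]
print(matches(MyModule, ParentTypeAsReported))  # True: the class itself matches
print(matches(parent, ParentTypeInstance))      # True: instance vs Module
```

A real runtime checker such as beartype performs the same class-versus-instance distinction, which is why the parent instance is rejected by the reported hint.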