Crash when using a generic module class
NeilGirdhar opened this issue.
Neil Girdhar commented
from typing import Any, Generic, TypeVar
import flax.linen as nn
# Type parameter for the generic module ``C``.
T = TypeVar('T')


class C(nn.Module, Generic[T]):
    """Minimal generic Flax module used to reproduce the crash.

    Subscripting this class (``C[Any]``) produces a ``typing`` generic alias.
    Calling that alias constructs the module and then assigns
    ``__orig_class__`` on the new instance, which is the step the
    traceback points at.
    """

    def f(self, t: T) -> T:
        """Return ``t`` unchanged."""
        return t
class D(nn.Module):
    """Parent module whose ``setup`` instantiates ``C[Any]``, triggering the crash."""

    def setup(self):
        # Constructing the module through the subscripted alias ``C[Any]`` is
        # the failing line. Per the traceback, typing's alias ``__call__`` sets
        # ``__orig_class__`` on the new module, and Flax then raises
        # "Trying to register submodules on unbound scope."
        # ``c`` is intentionally unused: constructing it is enough to reproduce.
        c = C[Any]()

    def __call__(self) -> None:
        pass
# Initialize ``D`` with an empty RNG mapping. Initialization runs ``D.setup``,
# which is where the reported AssertionError is raised.
rngs = dict()
D().init(rngs)
Running this gives:
Traceback (most recent call last):
File "/home/neil/src/cmm/a.py", line 21, in <module>
D().init(rngs)
File "/home/neil/src/cmm/a.py", line 14, in setup
c = C[Any]()
File "/home/neil/.pyenv/versions/3.10.1/lib/python3.10/typing.py", line 946, in __call__
result.__orig_class__ = self
AssertionError: Trying to register submodules on unbound scope.
This occurs in CPython 3.10.1.