google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Crash when using generic module class

NeilGirdhar opened this issue

from typing import Any, Generic, TypeVar

import flax.linen as nn

T = TypeVar('T')

class C(nn.Module, Generic[T]):
    def f(self, t: T) -> T:
        return t


class D(nn.Module):
    def setup(self):
        c = C[Any]()

    def __call__(self) -> None:
        pass


rngs = {}
D().init(rngs)

This gives

Traceback (most recent call last):
  File "/home/neil/src/cmm/a.py", line 21, in <module>
    D().init(rngs)
  File "/home/neil/src/cmm/a.py", line 14, in setup
    c = C[Any]()
  File "/home/neil/.pyenv/versions/3.10.1/lib/python3.10/typing.py", line 946, in __call__
    result.__orig_class__ = self
AssertionError: Trying to register submodules on unbound scope.

in CPython 3.10.1.
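
The last frame of the traceback is inside typing.py, at the statement result.__orig_class__ = self. That suggests the failure is triggered by typing's own attribute assignment on the freshly constructed instance, which presumably passes through the module's attribute-setting hook. The sketch below reproduces only that mechanism without Flax. Recorder is a hypothetical stand-in class, not part of Flax, and its __setattr__ just records the attribute names it sees instead of raising.

```python
from typing import Any, Generic, TypeVar

T = TypeVar('T')


class Recorder(Generic[T]):
    """Hypothetical stand-in for a class with a custom __setattr__."""

    def __init__(self) -> None:
        # Bypass our own hook so the bookkeeping list is not recorded.
        object.__setattr__(self, 'assigned', [])

    def __setattr__(self, name: str, value: Any) -> None:
        # Record every attribute assignment that goes through the hook.
        self.assigned.append(name)
        object.__setattr__(self, name, value)


# Calling the class directly: typing is not involved in construction.
plain = Recorder()

# Calling the subscripted alias: typing constructs the instance and then
# assigns __orig_class__ on it, which goes through __setattr__.
aliased = Recorder[Any]()

print(plain.assigned)
print(aliased.assigned)
```

In this sketch only the instance created through Recorder[Any]() sees an extra assignment, made after construction and outside the class's own code. A __setattr__ that asserts on some internal state at that point would fail in the same place as the traceback above.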