google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home page: https://flax.readthedocs.io

`typing.get_type_hints()` is broken for linen modules

brentyi opened this issue

I have some serialization code that involves a recursive call to `get_type_hints()`, which breaks for Flax modules:

from typing import get_type_hints

from flax import linen as nn


class Network(nn.Module):
    layers: int

# Fails!
# NameError: name 'Module' is not defined
print(get_type_hints(Network))

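The reason seems to be that forward references (string annotations) are used, seemingly unnecessarily, when fields are dynamically added to the module dataclass. `get_type_hints()` then tries to resolve these names in the namespace of the module where the subclass is defined, rather than in `flax.linen.module`, so `Module`, `Scope`, and `_Sentinel` cannot be found: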

flax/linen/module.py, lines 533 to 534 (commit 96c78cd):

parent_annotation = Union[Type["Module"], Type["Scope"],
                          Type["_Sentinel"], None]
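The failure can be reproduced without Flax. The sketch below is a minimal, hypothetical stand-in: `_library_namespace`, `LibModule`, and `Network` are illustrative names, and a function-local class plays the role of `Module`, which exists in the library's namespace but not in the namespace `get_type_hints()` searches. The alias carries a string forward reference, just like `parent_annotation`, so resolving it fails.

```python
from typing import Optional, Type, get_type_hints


def _library_namespace():
    """Hypothetical stand-in for a library module such as flax.linen.module."""

    class Module:  # visible only inside this function, not at module level
        pass

    # A type alias built with a string forward reference, like parent_annotation.
    parent_annotation = Optional[Type["Module"]]
    return Module, parent_annotation


LibModule, parent_annotation = _library_namespace()


class Network:
    layers: int
    parent: parent_annotation  # contains ForwardRef('Module')


# get_type_hints() looks up "Module" in the namespace of the module (and class)
# where Network is defined; the name is not there, so a NameError is raised.
try:
    get_type_hints(Network)
    error = None
except NameError as exc:
    error = exc

print(repr(error))
```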

This can be confirmed by adding a single import, which fixes the error:

from typing import get_type_hints

from flax import linen as nn
from flax.linen.module import Module, Scope, _Sentinel  # New


class Network(nn.Module):
    layers: int

# Works!
# {'layers': <class 'int'>, 'parent': typing.Union[typing.Type[flax.linen.module.Module], typing.Type[flax.core.scope.Scope], typing.Type[flax.linen.module._Sentinel], NoneType], 'name': <class 'str'>}
print(get_type_hints(Network))
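Instead of importing the names into the calling module, `get_type_hints()` also accepts a `localns` mapping that is searched when forward references are evaluated. A minimal sketch on the same hypothetical stand-in (`_library_namespace`, `LibModule`, and `Network` are illustrative, not Flax code); with Flax, the mapping would need to contain `Module`, `Scope`, and `_Sentinel`:

```python
from typing import Optional, Type, get_type_hints


def _library_namespace():
    """Hypothetical stand-in for a library module such as flax.linen.module."""

    class Module:
        pass

    # Alias with a string forward reference, like parent_annotation.
    return Module, Optional[Type["Module"]]


LibModule, parent_annotation = _library_namespace()


class Network:
    layers: int
    parent: parent_annotation


# Supply the missing name explicitly; the caller's module globals stay untouched.
hints = get_type_hints(Network, localns={"Module": LibModule})
print(hints["parent"])
```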

Nice catch! That forward reference is indeed unnecessary: forward references are only needed when a name is used before it has been defined, for example when a class refers to itself in the signature of one of its own methods (see PEP 484).
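The idea behind the fix can be sketched by building the alias from the class objects themselves, assuming they already exist when the alias is constructed. This is a minimal illustration on a hypothetical stand-in (`_library_namespace`, `LibModule`, and `Network` are illustrative names), not the actual Flax patch:

```python
from typing import Optional, Type, get_type_hints


def _library_namespace():
    """Hypothetical stand-in for a library module such as flax.linen.module."""

    class Module:
        pass

    # Module already exists here, so the class object is used directly
    # instead of the string "Module"; nothing is left to resolve later.
    return Module, Optional[Type[Module]]


LibModule, parent_annotation = _library_namespace()


class Network:
    layers: int
    parent: parent_annotation


# Works from any module, with no extra imports and no localns.
hints = get_type_hints(Network)
print(hints)
```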

I wasn't aware that this breaks `get_type_hints()`, but it seems to be a known issue. From the typing documentation:

"Note: get_type_hints() does not work with imported type aliases that include forward references. Enabling postponed evaluation of annotations (PEP 563) may remove the need for most forward references."

I'll take a look at your fix.