google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Error when calling `Module.tabulate` on normalization wrappers like `WeightNorm` and `SpectralNorm`

chiamp opened this issue

Follow-up from #3735. Partial fix in #3772.

Minimal repro:

import jax, jax.numpy as jnp
from flax import linen as nn

# Wrapping Dense in WeightNorm triggers the error below.
model = nn.WeightNorm(nn.Dense(3))
x = jnp.ones((1, 2))
key = jax.random.key(0)

print(model.tabulate(key,
                     x,
                     compute_flops=True,
                     compute_vjp_flops=True,
                     ))

Error message:

Traceback (most recent call last):
  File "/Users/marcuschiam/Desktop/asdf.py", line 20, in <module>
    print(model.tabulate(key,
  File "/Users/marcuschiam/flax/flax/linen/module.py", line 2843, in tabulate
    return tabulate_fn(*args, **kwargs)
  File "/Users/marcuschiam/flax/flax/linen/summary.py", line 315, in _tabulate_fn
    table = table_fn(rngs, *fn_args, **fn_kwargs, **kwargs)
  File "/Users/marcuschiam/flax/flax/linen/summary.py", line 490, in _get_table_fn
    *_get_call_flops(c, compute_flops, compute_vjp_flops),
  File "/Users/marcuschiam/flax/flax/linen/summary.py", line 400, in _get_call_flops
    variables = jax.eval_shape(init, rngs, dynamic_leaves)
  File "/Users/marcuschiam/flax/flax/linen/summary.py", line 392, in init
    return c.module.init(
AttributeError: "Dense" object has no attribute "layer_forward". If "layer_forward" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.

When `tabulate()` runs, the `calls` list contains `_CallInfo` objects whose `method` does not exist on the corresponding module, causing the `AttributeError` seen in #3735.

This can be seen by inspecting the `_CallInfo` objects in the debugger:

for c in calls: 
  print(f'Index: {c.index}\tModule: {type(c.module)}\tPath: {c.path}\tMethod: {c.method}')
Index: 0        Module: <class 'flax.linen.normalization.WeightNorm'>   Path: ()        Method: __call__
Index: 1        Module: <class 'flax.linen.linear.Dense'>       Path: ()        Method: <lambda>
Index: 2        Module: <class 'flax.linen.linear.Dense'>       Path: ('layer_instance',)       Method: layer_forward
Index: 3        Module: <class 'flax.linen.linear.Dense'>       Path: ('layer_instance',)       Method: __call__
Index: 4        Module: <class 'flax.linen.normalization.WeightNorm'>   Path: ()        Method: _l2_normalize
Index: 5        Module: <class 'flax.linen.normalization.WeightNorm'>   Path: ()        Method: _l2_normalize
Index: 6        Module: <class 'flax.linen.linear.Dense'>       Path: ('layer_instance',)       Method: layer_forward
Index: 7        Module: <class 'flax.linen.linear.Dense'>       Path: ('layer_instance',)       Method: __call__

For the `_CallInfo` object at index 2, `Dense` has no `layer_forward` method or attribute, so the `AttributeError` is raised.
For the `_CallInfo` object at index 1, the `<lambda>` method is skipped only because the path `()` was already visited at index 0.
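One possible mitigation (a hypothetical sketch, not Flax's actual fix: `CallInfo`, the module stubs, and `filter_calls` below are invented for illustration) would be to drop any recorded call whose method is not actually defined on its module before attempting the FLOPs computation:

```python
from dataclasses import dataclass

# Illustrative stand-in for Flax's internal _CallInfo record;
# the field names mirror what the debugger output above shows.
@dataclass
class CallInfo:
    index: int
    module: object
    path: tuple
    method: str

class Dense:
    """Stub: defines __call__ but, like flax Dense, no layer_forward."""
    def __call__(self, x):
        return x

class WeightNorm:
    """Stub wrapper with the methods seen in the call table."""
    def __call__(self, x):
        return x
    def _l2_normalize(self, x):
        return x

def filter_calls(calls):
    # Keep only calls whose method genuinely exists on the module,
    # so a later getattr(c.module, c.method) cannot raise AttributeError.
    return [c for c in calls if hasattr(c.module, c.method)]

calls = [
    CallInfo(0, WeightNorm(), (), '__call__'),
    CallInfo(2, Dense(), ('layer_instance',), 'layer_forward'),  # the failing entry
    CallInfo(3, Dense(), ('layer_instance',), '__call__'),
]

print([c.index for c in filter_calls(calls)])  # the layer_forward entry is dropped
```

On a real `flax.linen.Module`, a `hasattr` check like this would return `False` for `layer_forward` outside of `init`/`apply` (Flax's `__getattr__` raises `AttributeError` there, which `hasattr` swallows), so it would skip the entry rather than crash; whether that is the right behavior for FLOPs reporting is exactly what the partial fix in #3772 has to decide.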