google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Error when calling module tabulate involving WeightNorm

DBraun opened this issue · comments

System information

  • WSL on Windows 11 Pro
  • Flax, jax, jaxlib versions:
Name: flax
Version: 0.8.1
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page:
Author:
Author-email: Flax team <flax-dev@google.com>
License:
Location: /home/admin/.local/lib/python3.10/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: clu
---
Name: jax
Version: 0.4.24
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/admin/.local/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, clu, flax, jaxloudnorm, optax, orbax-checkpoint
---
Name: jaxlib
Version: 0.4.24+cuda12.cudnn89
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/admin/.local/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, clu, optax, orbax-checkpoint
  • Python version: 3.10
  • GPU/TPU model and memory: Nvidia RTX 2080 (8 GB)
  • CUDA version (if applicable): 12.3

Discussed in #3700

Originally posted by DBraun February 19, 2024
I'm unable to call model.tabulate with compute_flops=True or compute_vjp_flops=True if the module involves flax.linen.WeightNorm.

Code:

import jax
import jax.numpy as jnp
from flax import linen as nn

model = nn.WeightNorm(nn.Conv(features=64, kernel_size=3, strides=2))
# model = nn.Conv(features=64, kernel_size=3, strides=2)

x = jnp.ones((8, 44100, 1))

key = jax.random.PRNGKey(0)
params = model.init(key, x)

print(model.tabulate(key, x, depth=4,
                     console_kwargs={'width': 180},
                     column_kwargs={'width': 180},
                     compute_flops=True,
                     compute_vjp_flops=True,
                     ))

y = model.apply(params, x)
print(y.shape)

Output:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/c/Users/admin/AppData/Roaming/JetBrains/PyCharm2023.2/scratches/scratch1.py", line 13, in <module>
    print(model.tabulate(key, x, depth=4,
  File "/home/admin/.local/lib/python3.10/site-packages/flax/linen/module.py", line 2692, in tabulate
    return tabulate_fn(*args, **kwargs)
  File "/home/admin/.local/lib/python3.10/site-packages/flax/linen/summary.py", line 315, in _tabulate_fn
    table = table_fn(rngs, *fn_args, **fn_kwargs, **kwargs)
  File "/home/admin/.local/lib/python3.10/site-packages/flax/linen/summary.py", line 488, in _get_table_fn
    *_get_call_flops(c, compute_flops, compute_vjp_flops),
  File "/home/admin/.local/lib/python3.10/site-packages/flax/linen/summary.py", line 398, in _get_call_flops
    variables = jax.eval_shape(init, rngs, dynamic_leaves)
  File "/home/admin/.local/lib/python3.10/site-packages/flax/linen/summary.py", line 390, in init
    return c.module.init(
AttributeError: "Conv" object has no attribute "layer_forward". If "layer_forward" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.

Maybe the answer involves specifying method=... in the call to tabulate, but I haven't figured it out yet. What's the solution?

For context, the model I want is based on the following PyTorch code:

import torch.nn as nn
from torch.nn.utils import weight_norm
model = weight_norm(nn.Conv1d(1, 64, 3))

It seems like init is called on Conv (c.module) using the method layer_forward (c.method), but I'm not sure why there is a _CallInfo object with these attributes, since layer_forward is defined by WeightNorm, not by Conv. Any thoughts, @cgarciae?
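
For reference, WeightNorm's __call__ invokes the wrapped layer through a local function named layer_forward that it hands to nn.map_variables, which would explain why Conv gets recorded under that method name. Below is a minimal sketch of the same pattern using a hypothetical toy Wrapper (identity variable transform, not the real WeightNorm); on flax 0.8.1 it should hit the same AttributeError:

import jax
import jax.numpy as jnp
from flax import linen as nn

class Wrapper(nn.Module):
  # Hypothetical wrapper mirroring WeightNorm's structure: the wrapped layer
  # is called through a local closure that is handed to nn.map_variables.
  layer_instance: nn.Module

  @nn.compact
  def __call__(self, x):
    def layer_forward(layer_instance):
      return layer_instance(x)

    return nn.map_variables(
      layer_forward,
      trans_in_fn=lambda vs: vs,  # identity; WeightNorm l2-normalizes here
      init=self.is_initializing(),
    )(self.layer_instance)

model = Wrapper(nn.Dense(features=4))
x = jnp.ones((1, 3))
try:
  print(model.tabulate(jax.random.PRNGKey(0), x, compute_flops=True))
except AttributeError as e:
  print(e)  # expected on 0.8.1: "Dense" object has no attribute "layer_forward"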

When computing FLOPs, each Module is run under eval_shape; as @chiamp points out, there is a weird interaction here between WeightNorm and the tabulate context.
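
Until a fix lands, a whole-model FLOP estimate is still obtainable by going through JAX's cost analysis directly instead of tabulate. A sketch, reusing model, params, and x from the original reproduction at the top of the issue; note that the exact structure returned by cost_analysis() has varied across JAX versions and backends:

import jax

# Workaround sketch: lower the full forward pass and ask XLA for its cost
# estimate, bypassing tabulate's per-module accounting entirely.
lowered = jax.jit(model.apply).lower(params, x)
cost = lowered.cost_analysis()
if isinstance(cost, list):  # some JAX versions return a list of dicts
  cost = cost[0]
print(cost.get('flops'))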

Thanks for looking into it. One more thing: hopefully the solution will also fix the analogous situation for ConvTranspose:

model = nn.WeightNorm(nn.ConvTranspose(features=64, kernel_size=(3,), strides=(2,), transpose_kernel=True))

Fixed with #3772.