Error when calling module tabulate involving WeightNorm
DBraun opened this issue
System information
- WSL on Windows 11 Pro
- Flax, jax, jaxlib versions:
Name: flax
Version: 0.8.1
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page:
Author:
Author-email: Flax team <flax-dev@google.com>
License:
Location: /home/admin/.local/lib/python3.10/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: clu
---
Name: jax
Version: 0.4.24
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/admin/.local/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, clu, flax, jaxloudnorm, optax, orbax-checkpoint
---
Name: jaxlib
Version: 0.4.24+cuda12.cudnn89
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/admin/.local/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, clu, optax, orbax-checkpoint
- Python version: 3.10
- GPU/TPU model and memory: Nvidia RTX 2080 (8 GB)
- CUDA version (if applicable): 12.3
Discussed in #3700
Originally posted by DBraun February 19, 2024
I'm unable to call model.tabulate with compute_flops=True or compute_vjp_flops=True if the module involves flax.linen.WeightNorm.
Code:
import jax
import jax.numpy as jnp
from flax import linen as nn
model = nn.WeightNorm(nn.Conv(features=64, kernel_size=3, strides=2))
# model = nn.Conv(features=64, kernel_size=3, strides=2)
x = jnp.ones((8, 44100, 1))
key = jax.random.PRNGKey(0)
params = model.init(key, x)
print(model.tabulate(key, x, depth=4,
                     console_kwargs={'width': 180},
                     column_kwargs={'width': 180},
                     compute_flops=True,
                     compute_vjp_flops=True,
                     ))
y = model.apply(params, x)
print(y.shape)
Output:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/mnt/c/Users/admin/AppData/Roaming/JetBrains/PyCharm2023.2/scratches/scratch1.py", line 13, in <module>
    print(model.tabulate(key, x, depth=4,
  File "/home/admin/.local/lib/python3.10/site-packages/flax/linen/module.py", line 2692, in tabulate
    return tabulate_fn(*args, **kwargs)
  File "/home/admin/.local/lib/python3.10/site-packages/flax/linen/summary.py", line 315, in _tabulate_fn
    table = table_fn(rngs, *fn_args, **fn_kwargs, **kwargs)
  File "/home/admin/.local/lib/python3.10/site-packages/flax/linen/summary.py", line 488, in _get_table_fn
    *_get_call_flops(c, compute_flops, compute_vjp_flops),
  File "/home/admin/.local/lib/python3.10/site-packages/flax/linen/summary.py", line 398, in _get_call_flops
    variables = jax.eval_shape(init, rngs, dynamic_leaves)
  File "/home/admin/.local/lib/python3.10/site-packages/flax/linen/summary.py", line 390, in init
    return c.module.init(
AttributeError: "Conv" object has no attribute "layer_forward". If "layer_forward" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.
Maybe the answer involves specifying method=... in the call to tabulate, but I haven't figured it out yet. What's the solution?
For context, the model that I want is based on the following PyTorch code:
import torch.nn as nn
from torch.nn.utils import weight_norm
model = weight_norm(nn.Conv1d(1, 64, 3))
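For reference, PyTorch's weight_norm reparameterizes the weight as w = g * v / ||v||, with the norm taken over every axis except the output-channel one. Below is a minimal pure-JAX sketch of that reparameterization for the Conv1d(1, 64, 3) above, assuming Flax's NWC/WIO layout (so output features sit on the kernel's last axis). The helpers weight_norm_kernel and conv1d are hypothetical names, and this is not Flax's nn.WeightNorm implementation; in particular, its default initialization of the scale may differ from PyTorch's g = ||v||.

```python
import jax
import jax.numpy as jnp


def weight_norm_kernel(v, g):
    """Hypothetical helper: w = g * v / ||v||, normalizing each output feature (last axis)."""
    norm = jnp.sqrt(jnp.sum(v ** 2, axis=(0, 1), keepdims=True))
    return g * v / norm


def conv1d(x, w):
    """Hypothetical helper: stride-1, unpadded 1D convolution (torch.nn.Conv1d defaults).

    x: (batch, length, in_features); w: (kernel_size, in_features, out_features).
    """
    return jax.lax.conv_general_dilated(
        x, w,
        window_strides=(1,),
        padding="VALID",
        dimension_numbers=("NWC", "WIO", "NWC"),
    )


key = jax.random.PRNGKey(0)
v = jax.random.normal(key, (3, 1, 64))  # direction parameter, shape (kernel, in, out)
# As in PyTorch, start g at ||v|| so the effective kernel initially equals v.
g = jnp.sqrt(jnp.sum(v ** 2, axis=(0, 1), keepdims=True))  # shape (1, 1, 64)

x = jnp.ones((8, 44100, 1))
y = conv1d(x, weight_norm_kernel(v, g))
print(y.shape)  # (8, 44098, 64)
```

Only v and g would be trained; the effective kernel is recomputed from them on every forward pass, which is what decouples the kernel's direction from its per-feature magnitude.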
When computing FLOPs, each Module is run under eval_shape, and there is a weird interaction between WeightNorm and the tabulate context here, as @chiamp points out.
Thanks for looking into it. One more thing: hopefully the solution will also fix the similar situation for ConvTranspose:
model = nn.WeightNorm(nn.ConvTranspose(features=64, kernel_size=(3,), strides=(2,), transpose_kernel=True))
Fixed with #3772.