flax.linen.Module without __call__ can't initialize params
kho opened this issue.
Problem you have encountered:
The Flax documentation for flax.linen.Module says: "You can define arbitrary “forward pass” methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice."
However, the simple module below doesn't work. It seems I can only define "forward pass" methods in addition to __call__, not instead of it.
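```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class NoCall(nn.Module):
  def setup(self):
    self.b = self.param('b', nn.initializers.ones, ())
  def work(self, x):
    return x + self.b

# Raises AttributeError: "NoCall" object has no attribute "__call__"
NoCall().init(jax.random.PRNGKey(0), jnp.zeros([]))
```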
What you expected to happen:
No exception.
Logs, error messages, etc:
```text
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-3-21bb9eb06d37> in <module>()
12
---> 13 NoCall().init(jax.random.PRNGKey(0), jnp.zeros([]))
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in init(self, rngs, method, mutable, *args, **kwargs)
1123 rngs, *args,
-> 1124 method=method, mutable=mutable, **kwargs)
1125 return v_out
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in init_with_output(self, rngs, method, mutable, *args, **kwargs)
1091 return self.apply(
-> 1092 {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
1093
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in apply(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)
1055 if method is None:
-> 1056 method = self.__call__
1057 method = _get_unbound_fn(method)
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in __getattr__(self, name)
658 raise AttributeError(
--> 659 f'"{self.__class__.__name__}" object has no attribute "{name}"')
660
UnfilteredStackTrace: AttributeError: "NoCall" object has no attribute "__call__"
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
AttributeError Traceback (most recent call last)
<ipython-input-3-21bb9eb06d37> in <module>()
11 return x + self.b
12
---> 13 NoCall().init(jax.random.PRNGKey(0), jnp.zeros([]))
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in __getattr__(self, name)
657 else:
658 raise AttributeError(
--> 659 f'"{self.__class__.__name__}" object has no attribute "{name}"')
660
661 def __dir__(self) -> List[str]:
AttributeError: "NoCall" object has no attribute "__call__"
```
Steps to reproduce:
See code snippet above.
Hey! You have to use the method argument if you want the forward pass to be something other than __call__,
e.g.:
```python
module = NoCall()
variables = module.init(jax.random.PRNGKey(0), jnp.zeros([]), method=module.work)
```
Thanks for the tip! Sorry for missing that in the docs.