google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home page: https://flax.readthedocs.io

flax.linen.Module without __call__ can't initialize params

kho opened this issue

Problem you have encountered:

The Flax documentation on flax.linen.Module says: "You can define arbitrary 'forward pass' methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice." However, the simple module below doesn't work. It seems I can only define "forward pass" methods in addition to __call__, not instead of it.

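import jax
import jax.numpy as jnp
import flax.linen as nn

# No __call__ is defined; `work` is meant to be the forward pass.
class NoCall(nn.Module):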

  def setup(self):
    self.b = self.param('b', nn.initializers.ones, ())

  def work(self, x):
    return x + self.b

NoCall().init(jax.random.PRNGKey(0), jnp.zeros([]))

What you expected to happen:

No exception.

Logs, error messages, etc:

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-3-21bb9eb06d37> in <module>()
     12 
---> 13 NoCall().init(jax.random.PRNGKey(0), jnp.zeros([]))

/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in init(self, rngs, method, mutable, *args, **kwargs)
   1123         rngs, *args,
-> 1124         method=method, mutable=mutable, **kwargs)
   1125     return v_out

/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in init_with_output(self, rngs, method, mutable, *args, **kwargs)
   1091     return self.apply(
-> 1092         {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
   1093 

/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in apply(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)
   1055     if method is None:
-> 1056       method = self.__call__
   1057     method = _get_unbound_fn(method)

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in __getattr__(self, name)
    658       raise AttributeError(
--> 659           f'"{self.__class__.__name__}" object has no attribute "{name}"')
    660 

UnfilteredStackTrace: AttributeError: "NoCall" object has no attribute "__call__"

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
<ipython-input-3-21bb9eb06d37> in <module>()
     11     return x + self.b
     12 
---> 13 NoCall().init(jax.random.PRNGKey(0), jnp.zeros([]))

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in __getattr__(self, name)
    657     else:
    658       raise AttributeError(
--> 659           f'"{self.__class__.__name__}" object has no attribute "{name}"')
    660 
    661   def __dir__(self) -> List[str]:

AttributeError: "NoCall" object has no attribute "__call__"

Steps to reproduce:

See code snippet above.

Hey! You have to pass the method argument if you want the forward pass to be something other than __call__, e.g.:

module = NoCall()
variables = module.init(jax.random.PRNGKey(0), jnp.zeros([]), method=module.work)

Thanks for the tip! Sorry for missing that in the docs.