google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home page: https://flax.readthedocs.io


*Module Parameters* section of docs is outdated.

PaulScemama opened this issue

Hi, first off, thanks for a great library -- flax is awesome.

I wanted to revisit the documentation to gain a better understanding of flax. In the basics guide there is a section on module parameters.

I wanted to point out that the code there does not appear to work at the moment.

Here is a stripped-down version of what is currently in the docs:

import flax.linen as nn
import jax.numpy as jnp
import jax.random as random


class SimpleDense(nn.Module):
  features: int
  kernel_init = nn.initializers.lecun_normal()

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # init_args
    y = jnp.dot(inputs, kernel)
    return y

x = jnp.ones((1, 7))
model = SimpleDense(features=3)
key, init_key = random.split(random.key(123))

params = model.init(init_key, x)
# Error: TypeError: Cannot interpret '7' as a data type

It seems to be something to do with how *init_args is unpacked. I tried reproducing similar behaviour with the following:

initializer = nn.initializers.glorot_normal()

def foo(rng_key, args):
    return initializer(rng_key, *args)

foo(random.key(1), (4, 5))
# TypeError: Cannot interpret '5' as a data type

But I had trouble navigating the flax codebase as I am unfamiliar with it. Thanks again!
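For reference, the initializer returned by nn.initializers.glorot_normal() takes (key, shape, dtype) positionally (my reading of jax.nn.initializers), so unpacking the shape tuple pushes its last element into the dtype slot. A quick sketch of the difference:

import flax.linen as nn
import jax.random as random

init = nn.initializers.glorot_normal()
key = random.key(0)

print(init(key, (4, 5)).shape)  # (4, 5) -- shape passed as one tuple
init(key, 4, 5)                 # TypeError: 5 lands in the dtype slot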

You need to add a type annotation to the dataclass field:

from typing import Callable

class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  ...
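With that one line changed, the stripped-down example from the top of the thread should initialize fine; here's a quick end-to-end check (shapes printed with jax.tree_util.tree_map):

from typing import Callable

import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random as random


class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init,  # Initialization function
                        (inputs.shape[-1], self.features))  # init_args
    return jnp.dot(inputs, kernel)


x = jnp.ones((1, 7))
model = SimpleDense(features=3)
params = model.init(random.key(123), x)
print(jax.tree_util.tree_map(jnp.shape, params))
# {'params': {'kernel': (7, 3)}}  (modulo FrozenDict wrapping, depending on Flax version)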

@chiamp thanks!

I also think an error message for a missing type annotation on the dataclass field would be good, since the error that resulted from it was a bit cryptic.

Not adding a type annotation leaves kernel_init as an ordinary class attribute rather than a dataclass field, so it behaves like a method:

import flax.linen as nn
import jax

class SimpleDense(nn.Module):
  features: int
  kernel_init = nn.initializers.lecun_normal()

# Accessed on the class itself, it is just the underlying init function:
SimpleDense.kernel_init(jax.random.key(0), (1, 1)) == nn.initializers.lecun_normal()(jax.random.key(0), (1, 1))

I believe there are use cases for this, but @cgarciae can speak more to it.
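This is standard Python behaviour rather than anything Flax-specific: a function stored as a plain class attribute binds to the instance on attribute access, while a dataclass field with a default is set as an instance attribute in __init__ and stays unbound. A minimal sketch with a made-up stand-in function (fake_init is hypothetical, just to show the binding):

import dataclasses

def fake_init(key, shape):
  # Stand-in for an initializer; just echoes its arguments.
  return (key, shape)

@dataclasses.dataclass
class Annotated:
  f: object = fake_init  # dataclass field: stored on the instance

@dataclasses.dataclass
class Unannotated:
  f = fake_init  # no annotation: ignored by dataclass, stays a class attribute

print(Annotated().f)    # <function fake_init at 0x...>  (not bound)
print(Unannotated().f)  # <bound method fake_init of Unannotated()>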

Ahh I see @chiamp. So when we don't type-annotate kernel_init, it becomes a bound method. E.g.:

import flax.linen as nn


class SimpleDense(nn.Module):
  features: int
  kernel_init = nn.initializers.lecun_normal()


model = SimpleDense(features=3)
print(model.kernel_init)
# <bound method variance_scaling.<locals>.init of SimpleDense(
#    # attributes
#    features = 3
# )>

And when we do type-annotate it, it stays a plain function attribute (not bound):

from typing import Callable

import flax.linen as nn


class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()


model = SimpleDense(features=3)
print(model.kernel_init)
# <function variance_scaling.<locals>.init at 0x7f498fb66200>
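And calling it through the instance now matches the raw initializer (quick check; the shape here is chosen arbitrarily):

import jax.random as random

out = model.kernel_init(random.key(0), (7, 3))
print(out.shape)
# (7, 3)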

In the former case, the binding shifted the order of the arguments passed to the initializer inside self.param (see the top of the thread).
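That also explains the cryptic message: self.param calls init_fn(rng, *init_args), and a bound method inserts the module instance as the first argument, so the rng key lands in the shape slot and the shape tuple (7, 3) lands in the dtype slot. Roughly (a hypothetical reconstruction; the exact call chain inside Flax and JAX may differ):

import numpy as np

# The shape tuple ends up where a dtype is expected, and NumPy tries to
# read its first element, 7, as a data type:
np.dtype((7, 3))
# TypeError: Cannot interpret '7' as a data type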