google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

nnx.initializers.uniform should support custom lower and upper bounds

SamPruden opened this issue · comments

commented

This is a tiny, petty usability thing, but I just had to write this as an nnx.Conv argument:

kernel_init = lambda key, shape, dtype: jax.random.uniform(
  key, shape, dtype,
  minval=-args.conv_init,
  maxval=args.conv_init
) if args.conv_init else None,

This felt quite silly, because it seemed like I should be able to write:

kernel_init = nnx.initializers.uniform(-args.conv_init, args.conv_init) if args.conv_init else None,

However, initializers.uniform only takes a single scale parameter and samples from [0, scale) for some reason. I would say that, where applicable, the initializers should act as thin wrappers around the equivalent jax.random functions and offer the same options.
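
To make the ask concrete, here's a minimal sketch of the kind of factory I mean; uniform_init is hypothetical and not part of Flax or JAX:

import jax
import jax.numpy as jnp

def uniform_init(minval=0.0, maxval=1.0, dtype=jnp.float32):
    # Hypothetical factory that forwards the bounds to jax.random.uniform.
    def init(key, shape, dtype=dtype):
        # Sample uniformly from [minval, maxval), matching jax.random.uniform.
        return jax.random.uniform(key, shape, dtype, minval=minval, maxval=maxval)
    return init

# Usage, mirroring the nnx.Conv snippet above:
# kernel_init = uniform_init(-args.conv_init, args.conv_init)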

I would say the same for initializers.normal; however, I've just noticed that jax.random.normal doesn't let you choose the mean or stddev. That's quite surprising. I would expect those options to be available in both places.
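
In the meantime, the mean/stddev case can be recovered by shifting and scaling the standard normal sample; normal_init below is a made-up helper for illustration, not an existing API:

import jax
import jax.numpy as jnp

def normal_init(mean=0.0, stddev=1.0, dtype=jnp.float32):
    def init(key, shape, dtype=dtype):
        # Shift and scale a standard normal sample to get N(mean, stddev**2).
        return mean + stddev * jax.random.normal(key, shape, dtype)
    return init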

commented

I'm being a little unfair by not golfing the manual case, so I should point out that it can be written a bit more compactly:

kernel_init = lambda *a: random.uniform(*a, -args.conv_init, args.conv_init) if args.conv_init else None,

It's not too bad, but it's not quite as clean as it could be.

Whilst I'm on the topic of tiny, petty things about initializers and randomness, it would also be nice if jax.random.uniform could accept an int as a shape instead of requiring (n,).
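
Purely as illustration, the int-shape nicety is just a small shim over jax.random.uniform (uniform_any_shape is made up for this example):

import jax

def uniform_any_shape(key, shape, dtype=float, minval=0.0, maxval=1.0):
    # Promote a bare int to a 1-tuple before delegating to jax.random.uniform.
    if isinstance(shape, int):
        shape = (shape,)
    return jax.random.uniform(key, shape, dtype, minval=minval, maxval=maxval)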

commented

Hey, we mostly just re-export jax's initializers from jax.nn.initializers for convenience.

commented

> Hey, we mostly just re-export jax's initializers from jax.nn.initializers for convenience.

Ah sorry, of course! I'll refile there if I can be bothered to annoy them with something so trivial.
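
For completeness, the existing jax.nn.initializers.uniform factory (which, per the comment above, nnx.initializers re-exports) only takes a scale today:

import jax
from jax.nn import initializers

init = initializers.uniform(scale=0.01)   # samples from [0, 0.01)
w = init(jax.random.PRNGKey(0), (4, 4))   # initializers have the (key, shape, dtype) signature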