nnx.initializers.uniform should support custom lower and upper bounds
SamPruden opened this issue
This is a tiny petty usability thing, but I just had to write this as an nnx.Conv argument:
kernel_init = lambda key, shape, dtype: jax.random.uniform(
    key, shape, dtype,
    minval=-args.conv_init,
    maxval=args.conv_init,
) if args.conv_init else None,
Which felt quite silly, because it seems like I should be able to do
kernel_init = nnx.initializers.uniform(-args.conv_init, args.conv_init) if args.conv_init else None,
However, initializers.uniform only takes a single scale parameter and outputs in [0, scale) for some reason. I would say that, where applicable, the initializers should act like wrappers around the equivalent jax.random functions and offer the same options.
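As a sketch of what that wrapper could look like: a factory that accepts minval/maxval (mirroring jax.random.uniform's keyword arguments) and returns an initializer with the usual (key, shape, dtype) signature. This is a hypothetical illustration, not the actual flax.nnx API; the function name and defaults here are assumptions.

```python
import jax
import jax.numpy as jnp

def uniform(minval=0.0, maxval=1.0, dtype=jnp.float32):
    """Hypothetical initializer factory with custom bounds.

    Returns a callable matching the (key, shape, dtype) initializer
    signature, delegating to jax.random.uniform.
    """
    def init(key, shape, dtype=dtype):
        return jax.random.uniform(key, shape, dtype,
                                  minval=minval, maxval=maxval)
    return init

# Usage, as in the original request:
#   kernel_init = uniform(-args.conv_init, args.conv_init)
init = uniform(-0.1, 0.1)
x = init(jax.random.PRNGKey(0), (4, 4))
```

With this shape, the call site from the issue collapses to a single expression with no lambda.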
I would say the same for initializers.normal; however, I've just noticed that random.normal doesn't let you choose the mean or stddev. That's quite surprising. I would expect those options to be available in both places.
I'm being a little unfair by not golfing the manual case, so I should point out that it can be written a bit more compactly:
kernel_init = lambda *a: random.uniform(*a, -args.conv_init, args.conv_init) if args.conv_init else None,
It's not too bad but it's not quite as clean as it could be.
Whilst I'm on the topic of tiny petty unimportant things about initializers and randomness, it would also be nice if random.uniform could accept an int as a shape instead of having to do (n,).
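That convenience could be sketched as a thin wrapper that normalizes a bare int into a one-element tuple before delegating. This is an illustration of the requested behaviour, not existing jax API; the wrapper name is made up.

```python
import jax

def uniform_flexible(key, shape, **kwargs):
    """Hypothetical wrapper: accept an int shape, e.g. 5 instead of (5,)."""
    if isinstance(shape, int):
        shape = (shape,)
    return jax.random.uniform(key, shape, **kwargs)

x = uniform_flexible(jax.random.PRNGKey(0), 5)
```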
Hey, we mostly just re-export jax's initializers from jax.nn.initializers for convenience.
Ah sorry, of course! I'll refile there if I can be bothered to annoy them with something so trivial.