google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Clarify expectations for `variables` dict in apply()

melissatan opened this issue

Filing this to document, for future reference, some points that came up while working on a fix for #1768.

#1768 is a case where a user unintentionally passed `variables={'params': {'params': ...}}`. This can be addressed by adding a check inside `apply()`.
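A check of that kind could be sketched as follows. This is a minimal illustration in plain Python, not Flax's actual fix: `check_not_double_wrapped` is a hypothetical helper, and ordinary dicts stand in for `FrozenDict`. Because it only uses `Mapping` operations, it should also accept a `FrozenDict`.

```python
from collections.abc import Mapping


def check_not_double_wrapped(variables: Mapping) -> None:
  """Hypothetical check: reject variables={'params': {'params': ...}}.

  This shape can arise, for example, when a complete variables dict is
  wrapped in another {'params': ...} layer before being passed to apply().
  """
  if not isinstance(variables, Mapping):
    raise TypeError(f'variables must be a mapping, got {type(variables)}')
  params = variables.get('params')
  # Caveat: a module with a top-level submodule literally named 'params'
  # would also trip this check, so a real implementation must weigh that.
  if isinstance(params, Mapping) and 'params' in params:
    raise ValueError(
        "variables['params'] contains a nested 'params' key; "
        "did you pass {'params': variables} instead of variables?")


# A correctly structured dict passes without error.
variables = {'params': {'Dense_0': {'kernel': [[1.0]], 'bias': [0.0]}}}
check_not_double_wrapped(variables)

# The double-wrapped dict from #1768 is rejected.
try:
  check_not_double_wrapped({'params': variables})
except ValueError as e:
  print('rejected:', e)
```

Note that this narrow check says nothing about dicts that lack `'params'` entirely; that broader question is the subject of the rest of this issue.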

There's a broader question to consider: whether the library should enforce that the `variables` dict always contains a `'params'` key. This would catch cases where the user unintentionally passes `variables={'kernel': ...}`, though I'd guess that is relatively rare.

Benefits of enforcing:

  • Easier to detect invalid input early on

Costs of enforcing:

  • For real use cases, it may not always be correct to assume that `variables` contains `'params'`.
  • It would require a number of updates across the Flax tests, which in multiple places use a `variables` dict that is intended to be valid but contains no `'params'` key, e.g. in `linen_transforms_test`:

```
(Pdb) variables
FrozenDict({
    test: {
        inner: {
            baz: DeviceArray([1.], dtype=float32),
        },
    },
})
```

If the library enforced `'params'`, these tests, along with the docstring in `variables.py`, would need to be updated.
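To make the trade-off concrete, here is a minimal sketch of what a strict check could look like, in plain Python with ordinary dicts standing in for `FrozenDict`. The names `require_params_collection` and `is_accepted` are hypothetical and not part of Flax; the sample dicts only mimic the shapes discussed in this issue.

```python
from collections.abc import Mapping


def require_params_collection(variables: Mapping) -> None:
  """Hypothetical strict check: variables must have a top-level 'params' key."""
  if 'params' not in variables:
    raise ValueError(
        "variables has no 'params' collection; got top-level keys "
        f"{sorted(variables)}")


def is_accepted(variables: Mapping) -> bool:
  """Returns True if the strict check passes, False if it raises."""
  try:
    require_params_collection(variables)
    return True
  except ValueError:
    return False


# Benefit: an accidentally unwrapped parameter dict is caught early.
mistake = {'kernel': [[1.0]]}
# Cost: a dict holding only a non-'params' collection, shaped like the
# linen_transforms_test example, is rejected even though it is valid.
test_only = {'test': {'inner': {'baz': [1.0]}}}
# An ordinary variables dict with parameters is still accepted.
normal = {'params': {'Dense_0': {'kernel': [[1.0]]}}}

print(is_accepted(mistake), is_accepted(test_only), is_accepted(normal))
# → False False True
```

The strict check cannot distinguish the `{'kernel': ...}` mistake from a legitimate dict that simply has no parameters, which is why enforcing it would mean updating the tests above.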

Closing this for now because no opinions have been raised on either side of this discussion. If any come in, I'm happy to pick it back up.