Clarify expectations for `variables` dict in apply()
melissatan opened this issue
Filing this to document, for future reference, some points that came up while working on a fix for #1768.
#1768 is a case where a user unintentionally passed `variables={'params': {'params': ...}}`. This can be caught by adding a check inside `apply()`.
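One way to catch the #1768 mistake is to check whether the `'params'` collection itself contains a top-level `'params'` key. The sketch below uses plain nested dicts and a hypothetical helper, `check_no_nested_params`. It is an assumption about what such a check could look like, not the code that landed in Flax's `apply()`.

```python
from typing import Any, Mapping


def check_no_nested_params(variables: Mapping[str, Any]) -> None:
    """Raise if `variables` looks like {'params': {'params': ...}}.

    Hypothetical helper: a sketch of a check that could run at the start
    of apply(), not Flax's actual implementation.
    """
    params = variables.get('params')
    # Only inspect the 'params' collection when it is itself a mapping.
    if isinstance(params, Mapping) and 'params' in params:
        raise ValueError(
            "variables['params'] contains a 'params' key; did you pass "
            "{'params': variables} instead of variables?")


# A correctly structured dict passes silently.
check_no_nested_params({'params': {'Dense_0': {'kernel': [[1.0]]}}})

# The #1768 mistake is caught.
try:
    check_no_nested_params({'params': {'params': {'Dense_0': {}}}})
except ValueError as err:
    print('caught:', err)
```

Note that if a top-level submodule were ever literally named `params`, this check would flag it too, so a real implementation would need to decide how to handle that case.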
There's a broader question of whether the library should enforce that the `variables` dict always contains a `'params'` key. This would catch cases where the user unintentionally passes `variables={'kernel': ...}`, though I'd guess that is rarer.
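The stricter option would amount to rejecting any `variables` dict without a top-level `'params'` collection. A minimal sketch, again on plain dicts and with a hypothetical helper name (`require_params_collection`); Flax does not necessarily enforce this:

```python
from typing import Any, Mapping


def require_params_collection(variables: Mapping[str, Any]) -> None:
    """Raise unless `variables` has a top-level 'params' collection.

    Hypothetical stricter check discussed in this issue, not Flax's API.
    """
    if 'params' not in variables:
        raise ValueError(
            "Expected a 'params' collection in variables, got collections "
            f"{sorted(variables)}")


# Passes: a 'params' collection is present.
require_params_collection({'params': {'kernel': [[1.0]]}})

# Raises: the user passed the parameters without the 'params' wrapper.
try:
    require_params_collection({'kernel': [[1.0]]})
except ValueError as err:
    print('caught:', err)
```

The same check would also reject any dict that intentionally holds only non-`'params'` collections.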
Benefits of enforcing:
- Easier to detect invalid input early on
Costs of enforcing:
- It may not always be correct to assume that `variables` contains `'params'` in real use cases.
- It would require a bunch of updates across the Flax tests: there are multiple instances of a `variables` dict without `'params'` that is intended to be valid, e.g. in `linen_transforms_test`:
```
(Pdb) variables
FrozenDict({
    test: {
        inner: {
            baz: DeviceArray([1.], dtype=float32),
        },
    },
})
```
Those tests, along with the docstring in `variables.py`, would need to be updated if the library enforces `'params'`.
Closing this for now because no opinions have been raised on either side of this discussion. But if any come in I'm happy to pick it back up.