ToLinen is not hashable (Linen modules are)
PhilipVinc opened this issue
Linen modules are hashable, so I would expect nnx.bridge.ToLinen
to be hashable as well.
In [1]: from flax import linen as nn, nnx
In [2]: import jax
In [3]: model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))
In [4]: hash(model)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File ~/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/flax/linen/module.py:726, in _wrap_hash.<locals>.wrapped(self)
725 try:
--> 726 hash_value = hash_fn(self)
727 except TypeError as exc:
File <string>:3, in __hash__(self)
TypeError: unhashable type: 'dict'
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
Cell In[4], line 1
----> 1 hash(model)
File ~/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/flax/linen/module.py:728, in _wrap_hash.<locals>.wrapped(self)
726 hash_value = hash_fn(self)
727 except TypeError as exc:
--> 728 raise TypeError(
729 'Failed to hash Flax Module. '
730 'The module probably contains unhashable attributes. '
731 f'Module={self}'
732 ) from exc
733 return hash_value
TypeError: Failed to hash Flax Module. The module probably contains unhashable attributes. Module=ToLinen(
# attributes
nnx_class = Linear
args = (32, 64)
kwargs = {}
skip_rng = False
)
This is problematic because I cannot use ToLinen
in place of the standard Linen modules that I pass as static arguments to jax.jit-compiled functions.
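As a rough illustration of why hashability matters here: JAX requires static arguments of jax.jit to be hashable, because they are used to look up previously compiled code. The sketch below uses functools.lru_cache as a stand-in for that cache (an assumption for illustration only; this is not how JAX implements its cache). The hypothetical compile_for function fails on a dict argument with the same "unhashable type: 'dict'" error seen in the traceback.

```python
import functools


# Stand-in for jax.jit's compilation cache: like jit's static arguments,
# lru_cache keys its cache on the hash of the arguments.
@functools.lru_cache(maxsize=None)
def compile_for(config):
    # Pretend to "compile" something specialised to this static config.
    return f"compiled for {config!r}"


print(compile_for((32, 64)))  # tuples are hashable, so this is cached

try:
    compile_for({"features": 64})  # dicts are not hashable
except TypeError as exc:
    print("TypeError:", exc)  # TypeError: unhashable type: 'dict'
```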
I just realised that the unhashable element is the kwargs
dictionary, since hash(dict())
fails.
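A minimal sketch that reproduces the failure without Flax, assuming only that ToLinen behaves like a frozen dataclass whose generated __hash__ hashes all of its fields together as one tuple (FakeToLinen is a hypothetical name, not a Flax class):

```python
import dataclasses
from typing import Any


# Hypothetical stand-in for ToLinen's fields. A frozen dataclass with eq=True
# gets a generated __hash__ that hashes (nnx_class, args, kwargs) as a tuple.
@dataclasses.dataclass(frozen=True)
class FakeToLinen:
    nnx_class: Any
    args: tuple = ()
    kwargs: dict = dataclasses.field(default_factory=dict)


m = FakeToLinen(object, args=(32, 64))
try:
    hash(m)
except TypeError as exc:
    print("TypeError:", exc)  # TypeError: unhashable type: 'dict'

# The other fields hash fine on their own; only the dict in `kwargs` fails.
print(isinstance(hash((m.nnx_class, m.args)), int))  # True
```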
Using flax.core.FrozenDict instead, for example,
works correctly:
>>> from flax.core import FrozenDict
>>> from flax import linen as nn, nnx
>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64), kwargs=FrozenDict())
>>> hash(model)
5814856164823000827
which makes sense, since a FrozenDict is immutable and hashable.
Right now one could do something like
>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))
>>> model.kwargs['hello'] = 1
and potentially break some invariants in JAX's caching (an object's hash and equality must not change after it has been used as a cache key).
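To make the caching concern concrete, here is a plain-Python sketch with a hypothetical MutableKey class whose hash depends on a mutable dict. Once the key is mutated after being stored, the cached entry can no longer be found, either through the mutated key or through a fresh key equal to the original:

```python
# Hypothetical class whose hash depends on mutable state, to illustrate why
# mutating a cache key after insertion breaks lookups.
class MutableKey:
    def __init__(self, kwargs):
        self.kwargs = kwargs

    def __eq__(self, other):
        return isinstance(other, MutableKey) and self.kwargs == other.kwargs

    def __hash__(self):
        # Hash a snapshot of the current contents.
        return hash(tuple(sorted(self.kwargs.items())))


cache = {}
key = MutableKey({})
cache[key] = "compiled for kwargs={}"

key.kwargs["hello"] = 1  # mutate the key after it was stored

print(key in cache)             # False: the key's hash has changed
print(MutableKey({}) in cache)  # False: old hash matches, but equality fails
print(len(cache))               # 1: the entry is still there, just unreachable
```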
I think the default here should be some kind of frozen dictionary, and it would also be reasonable to freeze any dictionary passed in to the ToLinen
module. However, I'm not sure how to achieve the latter.
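One possible way to do the freezing, sketched with a plain frozen dataclass rather than an actual Linen module: convert whatever mapping the caller passes into an immutable, hashable mapping inside __post_init__. FrozenKwargs and ToLinenSketch are hypothetical names; in Flax itself, flax.core.FrozenDict would be the natural frozen type. This is not necessarily what #4159 does.

```python
import dataclasses
from collections.abc import Mapping
from typing import Any


class FrozenKwargs(Mapping):
    """Hypothetical minimal immutable, hashable mapping (a stand-in for
    something like flax.core.FrozenDict)."""

    def __init__(self, data=()):
        self._data = dict(data)  # private copy; no __setitem__ is exposed

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __hash__(self):
        # Order-independent; requires the values themselves to be hashable.
        return hash(frozenset(self._data.items()))

    def __repr__(self):
        return f"FrozenKwargs({self._data!r})"


@dataclasses.dataclass(frozen=True)
class ToLinenSketch:
    """Hypothetical stand-in for ToLinen's fields, not Flax's implementation."""

    nnx_class: Any
    args: tuple = ()
    kwargs: Mapping[str, Any] = dataclasses.field(default_factory=FrozenKwargs)

    def __post_init__(self):
        # Frozen dataclasses forbid normal assignment, so bypass it once here
        # to replace the caller's containers with immutable copies.
        object.__setattr__(self, "args", tuple(self.args))
        object.__setattr__(self, "kwargs", FrozenKwargs(self.kwargs))


m1 = ToLinenSketch(object, args=(32, 64), kwargs={"use_bias": False})
m2 = ToLinenSketch(object, args=(32, 64), kwargs={"use_bias": False})
print(hash(m1) == hash(m2), m1 == m2)  # True True
```

Because the frozen copy has no __setitem__, a mutation like m1.kwargs['hello'] = 1 raises TypeError instead of silently changing the hash. Note that Linen modules already define their own __post_init__, so an override in a real Linen module would likely need to call super().__post_init__(); the sketch above does not exercise that.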
Thanks @PhilipVinc for reporting this! I've sent #4159 to try to address this.