google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

ToLinen is not hashable (Linen modules are)

PhilipVinc opened this issue

Linen modules are hashable, so I would expect nnx.bridge.ToLinen to be hashable as well.

In [1]: from flax import linen as nn, nnx

In [2]: import jax

In [3]: model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))

In [4]: hash(model)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/flax/linen/module.py:726, in _wrap_hash.<locals>.wrapped(self)
    725 try:
--> 726   hash_value = hash_fn(self)
    727 except TypeError as exc:

File <string>:3, in __hash__(self)

TypeError: unhashable type: 'dict'

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[4], line 1
----> 1 hash(model)

File ~/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/flax/linen/module.py:728, in _wrap_hash.<locals>.wrapped(self)
    726   hash_value = hash_fn(self)
    727 except TypeError as exc:
--> 728   raise TypeError(
    729     'Failed to hash Flax Module.  '
    730     'The module probably contains unhashable attributes.  '
    731     f'Module={self}'
    732   ) from exc
    733 return hash_value

TypeError: Failed to hash Flax Module.  The module probably contains unhashable attributes.  Module=ToLinen(
    # attributes
    nnx_class = Linear
    args = (32, 64)
    kwargs = {}
    skip_rng = False
)

This is problematic because I cannot use ToLinen in lieu of standard Linen modules, which I pass as static arguments to jax.jit functions.

I just realised that the non-hashable element is the kwargs dictionary, as hash(dict()) fails.
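
The failure can be reproduced without Flax. The following is a minimal sketch, assuming that a module's hash is derived from its attribute values in the way a frozen dataclass derives one (the `File <string>:3, in __hash__` frame in the traceback above is consistent with a generated `__hash__`). `DictHolder` and `try_hash` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class DictHolder:
    """Hypothetical stand-in for a module that has a dict-valued attribute."""

    args: tuple = ()
    kwargs: dict = field(default_factory=dict)


def try_hash(obj):
    """Return (True, hash) on success, or (False, error message) on failure."""
    try:
        return True, hash(obj)
    except TypeError as exc:
        return False, str(exc)


# A bare dict cannot be hashed.
dict_ok, dict_message = try_hash(dict())

# The generated __hash__ hashes a tuple of the fields, so one unhashable
# field makes the whole instance unhashable.
holder_ok, holder_message = try_hash(DictHolder(args=(32, 64)))

print(dict_ok, dict_message)
print(holder_ok, holder_message)
```

Both calls fail with a `TypeError`, which matches the inner exception shown in the traceback.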

Using flax.core.FrozenDict, for example, works correctly:

>>> from flax.core import FrozenDict
>>> from flax import linen as nn, nnx
>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64), kwargs=FrozenDict())
>>> hash(model)
5814856164823000827

This makes sense.
Right now, one could do something like

>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))
>>> model.kwargs['hello'] = 1

and potentially break some invariants in JAX's caching.
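
As an analogy for the kind of invariant at stake, here is a sketch that uses a plain Python dict as the cache. It is not JAX's actual caching code. `MutableConfig` is a hypothetical class whose hash depends on a mutable dict, and the cached string stands in for a compiled function.

```python
class MutableConfig:
    """Hypothetical cache key whose hash depends on a mutable dict."""

    def __init__(self, **kwargs):
        self.kwargs = dict(kwargs)

    def __eq__(self, other):
        return isinstance(other, MutableConfig) and self.kwargs == other.kwargs

    def __hash__(self):
        # Hash the contents, so mutating self.kwargs changes the hash.
        return hash(tuple(sorted(self.kwargs.items())))


cache = {}
config = MutableConfig(features=64)
cache[config] = "result computed for features=64"

found_before = config in cache

# Mutate the key after it has been stored in the cache.
config.kwargs["hello"] = 1

found_after = config in cache

print(found_before, found_after, len(cache))
```

After the mutation, the very same object can no longer find its own cache entry, yet the entry is still stored. Any cache that assumes its keys are immutable is exposed to this sort of inconsistency.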

I think the default here should be some sort of frozen dictionary, and it would be reasonable to freeze the dictionary passed to the ToLinen module. However, I'm not sure how to achieve the latter.
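
One possible way to freeze the argument on construction, sketched with a plain frozen dataclass rather than a real Linen module: normalise whatever mapping is passed in into an immutable, hashable form inside `__post_init__`. `FrozenKwargsWrapper` is a hypothetical class, the keyword names are illustrative only, and whether the same pattern fits Linen's own dataclass handling (or what #4159 does) is not something this sketch establishes.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class FrozenKwargsWrapper:
    """Hypothetical wrapper, not the real ToLinen."""

    args: tuple = ()
    kwargs: Any = ()  # accepts a mapping or an iterable of (key, value) pairs

    def __post_init__(self):
        # Copy the input into a sorted tuple of (key, value) pairs.
        # Sorting makes the result independent of insertion order.
        frozen = tuple(sorted(dict(self.kwargs).items()))
        # Frozen dataclasses block normal assignment, so bypass it once here.
        object.__setattr__(self, "kwargs", frozen)


first = FrozenKwargsWrapper(args=(32, 64), kwargs={"use_bias": False, "name": "x"})
second = FrozenKwargsWrapper(args=(32, 64), kwargs={"name": "x", "use_bias": False})

print(first == second)
print(hash(first) == hash(second))
print(dict(first.kwargs))
```

The wrapper keeps its own immutable copy, so later changes to the caller's dict do not affect it. The values inside the mapping still have to be hashable themselves, and keys are assumed to be strings so that they can be sorted.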

Thanks @PhilipVinc for reporting this! I've sent #4159 to try to address this.