Cannot assign a dict whose keys are not strings as a module attribute
wztdream opened this issue
wztdream commented
Hi,
It seems that the current flax.linen does not allow assigning a dict with non-string keys to a module attribute.
See the simple example below:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class Foo(nn.Module):
    def setup(self):
        # here the dict uses a tuple as its key; this assignment raises the error
        self.a = {(1, 2): 3}

    @nn.compact
    def __call__(self, x):
        return x


foo = Foo()
rng = jax.random.PRNGKey(0)
x = jnp.ones(shape=(3, 3))
variables = foo.init({"params": rng}, x)
out = foo.apply(variables, x)
print(out)
```

It triggers the following error:

```
AssertionError: A state dict must only have string keys.
```

Questions:
- Is this the intended behavior? If so, why?
- If it is intended, is there a workaround? We quite often need to assign the information contained in a dict to a module, and the dict's keys may not be strings.
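For the workaround question, here is a minimal sketch that assumes the only obstacle is the string-key check in the error above: store the mapping with string-encoded keys and decode them when the original keys are needed. `encode_keys` and `decode_keys` are hypothetical helpers, not Flax APIs, and the `repr`/`ast.literal_eval` round trip assumes the keys are plain Python literals such as tuples of ints.

```python
import ast
from typing import Any, Dict, Hashable


def encode_keys(d: Dict[Hashable, Any]) -> Dict[str, Any]:
    """Turn literal keys such as (1, 2) into strings such as '(1, 2)'."""
    return {repr(k): v for k, v in d.items()}


def decode_keys(d: Dict[str, Any]) -> Dict[Hashable, Any]:
    """Invert encode_keys; assumes every key was produced by repr() of a literal."""
    return {ast.literal_eval(k): v for k, v in d.items()}


original = {(1, 2): 3, (4, 5): 6}
encoded = encode_keys(original)
print(encoded)  # {'(1, 2)': 3, '(4, 5)': 6}
assert decode_keys(encoded) == original
```

Inside the module this would mean assigning `self.a = encode_keys({(1, 2): 3})` in `setup` and reading `decode_keys(self.a)` in `__call__`. This has not been checked against a specific Flax release; Flax may hand the attribute back as a frozen mapping, which still supports `.items()`.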
jheek commented
This should be fixed now.
Marc van Zee commented
@jheek I just tried this in a public Colab with flax installed from main, but the problem still seems to be there.
Anselm Levskaya commented
It's because we traverse every assignment looking for Module leaves, and we impose overly strict requirements on the structure of that tree (e.g. string keys) so that Module leaves can be handled. Those requirements are spilling over as a constraint on trees with any leaf type.
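To make the explanation above concrete, here is a hypothetical, heavily simplified model of such a traversal in plain Python. It is not Flax's implementation: `Module` is a stand-in class and `find_module_leaves` is an illustrative name. It walks an assigned value looking for Module leaves, but checks the keys of every dict it visits, so the string-key rule applies even when no Module is present.

```python
from typing import Any, List, Optional


class Module:
    """Stand-in for a Flax Module; used only as a leaf marker in this sketch."""


def find_module_leaves(tree: Any, path: str = "") -> List[str]:
    """Return the paths of all Module leaves found in a nested dict/list/tuple."""
    if isinstance(tree, Module):
        return [path]
    if isinstance(tree, dict):
        # The key check runs on every dict, whether or not it holds a Module;
        # this is how the constraint spills over to arbitrary leaf types.
        if not all(isinstance(k, str) for k in tree):
            raise AssertionError("A state dict must only have string keys.")
        children = [(k, v) for k, v in tree.items()]
    elif isinstance(tree, (list, tuple)):
        children = [(str(i), v) for i, v in enumerate(tree)]
    else:
        return []  # any other leaf (int, array, ...) is simply ignored
    found: List[str] = []
    for key, value in children:
        found.extend(find_module_leaves(value, f"{path}/{key}"))
    return found


print(find_module_leaves({"dense": Module(), "scale": 2.0}))  # ['/dense']

error: Optional[str] = None
try:
    find_module_leaves({(1, 2): 3})  # contains no Module, yet is still rejected
except AssertionError as e:
    error = str(e)
print(error)  # A state dict must only have string keys.
```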