google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Cannot assign a dict whose keys are not strings as a module attribute

wztdream opened this issue

Hi,
It seems the current flax.linen does not allow assigning a dict with non-string keys as a module attribute.
See the simple example below; it triggers the error:
AssertionError: A state dict must only have string keys.

Questions:

  1. Is this intended behavior? If so, why?
  2. If it is intended, is there a workaround? It is quite common to need to store the information in a dict on a module, and the dict's keys may not be strings.

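import flax.linen as nn
import jax
import jax.numpy as jnp


class Foo(nn.Module):
    def setup(self):
        # The dict uses a tuple as its key; this assignment raises
        # "AssertionError: A state dict must only have string keys."
        self.a = {(1, 2): 3}

    # @nn.compact is not needed here: __call__ defines no submodules,
    # and setup() still runs lazily before the method body is executed.
    def __call__(self, x):
        return x


foo = Foo()
rng = jax.random.PRNGKey(0)
x = jnp.ones(shape=(3, 3))
# Named `variables` to avoid shadowing the built-in vars().
variables = foo.init({"params": rng}, x)
out = foo.apply(variables, x)
print(out)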
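
Regarding the second question, one possible workaround, assuming the check only rejects non-string keys, is to encode the keys as strings before assigning the dict and decode them when reading it back. `encode_keys` and `decode_keys` below are hypothetical helpers, not part of Flax, and the round trip only works for keys whose `repr()` is a Python literal (ints, strings, tuples of these, and so on).

```python
import ast
from typing import Any, Dict, Hashable, Mapping


def encode_keys(d: Mapping[Hashable, Any]) -> Dict[str, Any]:
    """Hypothetical helper: replace every key with its repr() string."""
    return {repr(k): v for k, v in d.items()}


def decode_keys(d: Mapping[str, Any]) -> Dict[Any, Any]:
    """Hypothetical helper: invert encode_keys for keys that are Python literals."""
    return {ast.literal_eval(k): v for k, v in d.items()}


original = {(1, 2): 3, 5: "x"}
encoded = encode_keys(original)
print(encoded)  # {'(1, 2)': 3, '5': 'x'}

# The round trip restores the original non-string keys.
assert decode_keys(encoded) == original
```

Inside `setup` this would look like `self.a = encode_keys({(1, 2): 3})`, with `decode_keys(self.a)` recovering the tuple keys where they are needed. Because Flax may store an assigned dict as a `FrozenDict`, `decode_keys` relies only on `.items()`.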

This should be fixed.

@jheek I just tried this in a public Colab with Flax installed from main, but the problem still seems to be there.

This happens because we traverse every attribute assignment looking for Module leaves, and the requirements we place on the structure of that tree (e.g. string keys) are overly strict. They are only needed for Module leaves, but they end up being enforced as a constraint on every leaf type.
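
The explanation above can be illustrated with a simplified, pure-Python model of the traversal. This is not Flax's actual implementation: `FakeModule` is a hypothetical stand-in for a Module leaf, and `strict=False` is a hypothetical relaxation that only demands string keys on dicts that actually contain modules.

```python
from typing import Any, List


class FakeModule:
    """Hypothetical stand-in for a Flax Module leaf (not the real class)."""


def has_module(tree: Any) -> bool:
    """True if any leaf of a nested dict/list/tuple tree is a FakeModule."""
    if isinstance(tree, dict):
        return any(has_module(v) for v in tree.values())
    if isinstance(tree, (list, tuple)):
        return any(has_module(v) for v in tree)
    return isinstance(tree, FakeModule)


def find_modules(tree: Any, strict: bool) -> List[FakeModule]:
    """Collect FakeModule leaves from an assigned value.

    strict=True mimics the reported behavior: every dict needs string keys.
    strict=False only requires string keys on dicts that contain modules.
    """
    found: List[FakeModule] = []

    def visit(node: Any) -> None:
        if isinstance(node, dict):
            if strict or has_module(node):
                assert all(isinstance(k, str) for k in node), (
                    "A state dict must only have string keys."
                )
            for v in node.values():
                visit(v)
        elif isinstance(node, (list, tuple)):
            for v in node:
                visit(v)
        elif isinstance(node, FakeModule):
            found.append(node)

    visit(tree)
    return found


value = {(1, 2): 3}
print(find_modules(value, strict=False))  # []
try:
    find_modules(value, strict=True)
except AssertionError as err:
    print(err)  # A state dict must only have string keys.
```

With `strict=True` the plain `{(1, 2): 3}` dict is rejected even though it holds no modules, which is the spill-over described above; with `strict=False` it passes, while a dict mapping a tuple key to a `FakeModule` would still be rejected. `has_module` rescans subtrees, so this sketch can be quadratic and is meant only for illustration.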