google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

nnx static fields not part of static tree structure

NeilGirdhar opened this issue

from flax.experimental import nnx
from jax import jit
from jax.tree import flatten


class C(nnx.Module, experimental_pytree=True):
    def __init__(self, x):
        self.x = x


c = C(1)
d = C(2)

values, tree = flatten(c)
valuesd, treed = flatten(d)


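@jit
def f(x):
    # Python side effect: runs only while tracing, so each print marks a new trace.
    print(x.x)


print(hash(tree), hash(treed), tree, treed, values, valuesd)
f(c)  # traces and prints 1
f(d)  # expected to retrace and print 2, since the static field differs
f(c)  # cache hit: prints nothing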

Despite the static fields in the two printed treedefs being different:

  static_fields={
    'x': 1
  }
# vs
  static_fields={
    'x': 2
  }

the hashes are the same, and therefore the jitted function is traced only once: it prints 1 but never 2.
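For comparison, here is a minimal sketch of the behaviour the report expects, written in plain JAX rather than nnx. It assumes the static field is modelled as the aux_data of a hand-registered pytree; the class Static, the function f_static and the counter trace_count are hypothetical names for illustration only. Because aux_data is part of the treedef, Static(1) and Static(2) flatten to unequal treedefs, and jax.jit traces once per distinct static value.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten


# Hypothetical stand-in for C: x is static metadata, not a leaf.
@register_pytree_node_class
class Static:
    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        # No leaves; the static value becomes the treedef's aux_data.
        return (), self.x

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


trace_count = 0


@jax.jit
def f_static(obj):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return jnp.asarray(obj.x)


_, tree_c = tree_flatten(Static(1))
_, tree_d = tree_flatten(Static(2))
print(tree_c == tree_d)  # False: the aux_data differs

f_static(Static(1))  # first call: traces
f_static(Static(2))  # different treedef, so it traces again
f_static(Static(1))  # treedef seen before: cache hit, no trace
print(trace_count)  # 2
```

Since jit uses the input treedef as part of its cache key, whatever is stored as aux_data should be hashable and comparable with ==.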

Have I misunderstood how the pytree flattener is supposed to work?
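On how the flattener is supposed to behave: if x were a leaf rather than a static field, equal treedefs would be the correct result, and one trace would serve both values because x would enter jit as a traced array. A sketch under that assumption, again in plain JAX, with a hypothetical Dynamic class, function f_leaf and counter leaf_traces:

```python
import jax
from jax.tree_util import register_pytree_node_class, tree_flatten


# Hypothetical class where x is a leaf rather than static metadata.
@register_pytree_node_class
class Dynamic:
    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        # x is returned as a child, so it becomes a leaf; aux_data is empty.
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


leaf_traces = 0


@jax.jit
def f_leaf(obj):
    global leaf_traces
    leaf_traces += 1  # Python side effect: runs only while tracing
    return obj.x * 10


leaves_c, tree_c = tree_flatten(Dynamic(1))
leaves_d, tree_d = tree_flatten(Dynamic(2))
print(tree_c == tree_d)  # True: only the leaves differ

print(f_leaf(Dynamic(1)), f_leaf(Dynamic(2)))  # 10 20
print(leaf_traces)  # 1: same treedef, shape and dtype, so the trace is reused
```

In the nnx reproducer, the printed static_fields show x placed in the treedef rather than among the leaves, which is why identical hashes there are surprising.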

@cgarciae Thanks! This works on master (prints "1 2"), but not on 0.8.2 (prints "1").