nnx static fields not part of static tree structure
NeilGirdhar opened this issue
Neil Girdhar commented
from flax.experimental import nnx
from jax import jit
from jax.tree import flatten


class C(nnx.Module, experimental_pytree=True):
    def __init__(self, x):
        self.x = x


c = C(1)
d = C(2)
values, tree = flatten(c)
valuesd, treed = flatten(d)


@jit
def f(x):
    # Runs only while f is being traced, so each print marks a retrace.
    print(x.x)


print(hash(tree), hash(treed), tree, treed, values, valuesd)
f(c)
f(d)
f(c)
Despite the static fields being different:

static_fields={
    'x': 1
}
# vs
static_fields={
    'x': 2
}

the hashes are the same, so the jitted function is traced only once (print(x.x) runs a single time across the three calls).
Have I misunderstood how the pytree flattener is supposed to work?
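
For comparison, here is a minimal sketch of the behavior I expected, written against plain jax.tree_util rather than nnx. The Box class, its tag field, and trace_log are hypothetical names made up for this illustration; the assumption is that static data returned as aux_data from tree_flatten becomes part of the treedef, so two objects that differ only in that data should flatten to unequal treedefs and trigger separate traces under jit.

```python
import jax
from jax.tree_util import register_pytree_node_class, tree_flatten


@register_pytree_node_class
class Box:
    """Hypothetical container: `value` is a dynamic leaf, `tag` is static."""

    def __init__(self, value, tag):
        self.value = value
        self.tag = tag

    def tree_flatten(self):
        # Children are traced by jit; aux_data is stored in the treedef.
        return (self.value,), self.tag

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (value,) = children
        return cls(value, aux_data)


trace_log = []  # records the static tag each time f is traced


@jax.jit
def f(box):
    # This Python side effect runs only during tracing, not on cache hits.
    trace_log.append(box.tag)
    return box.value + 1


a = Box(1.0, tag=1)
b = Box(1.0, tag=2)
_, tree_a = tree_flatten(a)
_, tree_b = tree_flatten(b)

f(a)
f(b)
f(a)  # same structure as the first call, so no new trace is expected
```

In this sketch tree_a and tree_b compare unequal even though the leaves are identical, and trace_log should end up with one entry per distinct tag. That is the behavior I would have expected from the nnx module above as well.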
Neil Girdhar commented
@cgarciae Thanks! This works on master (prints "1 2"), but not on 0.8.2 (prints "1").