Inheritance of dataclasses
daniel-dodd opened this issue
Hi @cgarciae,
Many thanks for swiftly addressing the generic issue #4.
I have possibly found another issue. Let's say I create two classes, `A` and `B`, where `B` inherits from `A`.
```python
# version == 0.1.5
import jax.tree_util as jtu
from simple_pytree import Pytree, static_field


class A(Pytree):
    a: int
    b: int = static_field()

    def __init__(self, a=1, b=2):
        self.a = a
        self.b = b


class B(A):
    c: int

    def __init__(self, a=1, b=2, c=3):
        super().__init__(a=a, b=b)
        self.c = c
```
If we flatten this, we get the following:

```python
print(jtu.tree_flatten(B()))
```

```
([1, 3], PyTreeDef(CustomNode(B[(('a', 'c'), {'b': 2, '_pytree__initialized': True})], [*, *])))
```
As expected, since `b` is static, it does not appear among the leaves of the flattened pytree.
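Conceptually, flattening splits an object's attributes into dynamic leaves and static auxiliary data, which is why `b` shows up in the treedef next to the names `('a', 'c')` rather than in the leaf list. Below is a minimal pure-Python sketch of that split. It assumes a hypothetical `flatten_attrs` helper and an explicitly supplied set of static names; it is not simple_pytree's actual code and does not register anything with JAX.

```python
from typing import Any


def flatten_attrs(obj: Any, static_names: set) -> tuple:
    """Split obj's instance attributes into (leaves, aux_data).

    Hypothetical helper: attributes listed in static_names go into aux_data;
    every other attribute becomes a leaf, in attribute insertion order.
    """
    attrs = vars(obj)
    dynamic_names = tuple(k for k in attrs if k not in static_names)
    leaves = [attrs[k] for k in dynamic_names]
    static = {k: v for k, v in attrs.items() if k in static_names}
    return leaves, (dynamic_names, static)


class B:
    # Plain-class stand-in for the issue's B (no Pytree base, no JAX).
    def __init__(self, a=1, b=2, c=3):
        self.a = a
        self.b = b
        self.c = c


leaves, aux = flatten_attrs(B(), {"b"})
print(leaves, aux)  # [1, 3] (('a', 'c'), {'b': 2})
```

This mirrors the shape of the first output in the issue. In the dataclass case, `b` behaves as if it were missing from the static set.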
However, if we do the same using `dataclasses`:
```python
import jax.tree_util as jtu
from dataclasses import dataclass
from simple_pytree import Pytree, static_field


@dataclass
class A(Pytree):
    a: int = 1
    b: int = static_field(2)


@dataclass
class B(A):
    c: int = 3
```
We get a different result:

```python
print(jtu.tree_flatten(B()))
```

```
([1, 2, 3], PyTreeDef(CustomNode(B[(('a', 'b', 'c'), {'_pytree__initialized': True})], [*, *, *])))
```
The field `b` is not being treated as static.
The metadata seems fine.
```python
from dataclasses import fields

# We're good on the class.
print([f.metadata for f in fields(B)])
# And upon instantiation.
print([f.metadata for f in fields(B())])
```

```
[mappingproxy({}), mappingproxy({'pytree_node': False}), mappingproxy({})]
[mappingproxy({}), mappingproxy({'pytree_node': False}), mappingproxy({})]
```
If you agree this behaviour is not expected, I would be happy to look into this and open a PR.
Cheers,
Dan :)
Hey @daniel-dodd! Thanks again for reporting the issue. Looking into it.
Fixed in 0.1.6