cgarciae / simple-pytree

A dead simple Python package for creating custom JAX pytree objects

Inheritance of dataclasses

daniel-dodd opened this issue

Hi @cgarciae,

Many thanks for swiftly addressing the generic issue #4.

I have possibly found another issue. Let's say I create two classes A and B where B inherits from A.

# version == 0.1.5
import jax.tree_util as jtu
from simple_pytree import Pytree, static_field

class A(Pytree):
    a: int
    b: int = static_field()

    def __init__(self, a=1, b=2):
        self.a = a
        self.b = b

class B(A):
    c: int

    def __init__(self, a=1, b=2, c=3):
        super().__init__(a=a, b=b)
        self.c = c

If we flatten this, we get the following:

print(jtu.tree_flatten(B()))
([1, 3], PyTreeDef(CustomNode(B[(('a', 'c'), {'b': 2, '_pytree__initialized': True})], [*, *])))

As expected, since b is static, it does not appear among the leaves of the flattened pytree; it is stored in the treedef metadata instead.

However, if we do the same using dataclasses:

import jax.tree_util as jtu
from dataclasses import dataclass
from simple_pytree import Pytree, static_field

@dataclass
class A(Pytree):
    a: int = 1
    b: int = static_field(2)

@dataclass
class B(A):
    c: int = 3

We get a different result.

print(jtu.tree_flatten(B()))
([1, 2, 3], PyTreeDef(CustomNode(B[(('a', 'b', 'c'), {'_pytree__initialized': True})], [*, *, *])))

Here the field b is not treated as static: it shows up as a leaf.
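
One difference between the two setups that may be relevant (this is a guess, not a confirmed diagnosis of simple_pytree): the @dataclass decorator replaces a class attribute that holds a Field object with the field's plain default value. The sketch below uses only the standard library and assumes static_field returns something like dataclasses.field(..., metadata={'pytree_node': False}); the class names Plain, Data, PlainChild and DataChild are hypothetical.

```python
from dataclasses import Field, dataclass, field, fields


class Plain:
    # No @dataclass: the Field object stays on the class untouched.
    b: int = field(default=2, metadata={"pytree_node": False})


@dataclass
class Data:
    # With @dataclass: the class attribute is replaced by the default (2).
    b: int = field(default=2, metadata={"pytree_node": False})


class PlainChild(Plain):
    pass


class DataChild(Data):
    pass


# A subclass of the plain class still sees the Field object by attribute lookup.
print(isinstance(PlainChild.b, Field))  # True

# A subclass of the dataclass only sees the default value.
print(isinstance(DataChild.b, Field))  # False
print(DataChild.b)  # 2

# The metadata is still reachable, but only through dataclasses.fields().
print(fields(DataChild)[0].metadata["pytree_node"])  # False
```

If a library detected static fields by looking for Field objects among class attributes, a subclass of a dataclass would no longer find them this way, which would be consistent with the output above.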

The metadata seems fine.

from dataclasses import fields

# We're good on the class.
print([f.metadata for f in fields(B)])

# And upon instantiation.
print([f.metadata for f in fields(B())])
[mappingproxy({}), mappingproxy({'pytree_node': False}), mappingproxy({})]
[mappingproxy({}), mappingproxy({'pytree_node': False}), mappingproxy({})]
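
Since the metadata survives inheritance, the static/node split could in principle be recovered from dataclasses.fields alone. Below is a minimal standard-library sketch of that idea; static_field_sketch and split_field_names are hypothetical helpers for illustration, not part of simple_pytree, and the classes here do not inherit from Pytree.

```python
from dataclasses import dataclass, field, fields


def static_field_sketch(default):
    # Stand-in for simple_pytree.static_field (assumption): a dataclass field
    # tagged with the same metadata key that appears in the output above.
    return field(default=default, metadata={"pytree_node": False})


@dataclass
class A:
    a: int = 1
    b: int = static_field_sketch(2)


@dataclass
class B(A):
    c: int = 3


def split_field_names(obj):
    """Partition dataclass field names into (node, static) using metadata."""
    node, static = [], []
    for f in fields(obj):
        # Fields without the key are treated as pytree nodes by default.
        target = node if f.metadata.get("pytree_node", True) else static
        target.append(f.name)
    return tuple(node), tuple(static)


node_names, static_names = split_field_names(B())
print(node_names)    # ('a', 'c')
print(static_names)  # ('b',)
```

This matches the split produced by the non-dataclass version, where a and c are leaves and b is static.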

If you agree this behaviour is not expected, I would be happy to look into it and open a PR.

Cheers,
Dan :)

Hey @daniel-dodd! Thanks again for reporting the issue. Looking into it.

Fixed in 0.1.6