Get all unique kinds
thomaspinder opened this issue · comments
Hi,
Is there a way that I can get a list of all the unique kinds within a nested dataclass? For example:
class KindOne:
    """Marker type passed as a field's ``kind=`` tag."""
class KindTwo:
    """Marker type passed as a field's ``kind=`` tag."""
@dataclass
class SubModel(to.Tree):
    """Tree holding a single array node tagged with ``KindOne``."""

    # Annotate with the array *type*; ``jnp.array`` is a constructor
    # function, not a type, so it is not a valid annotation.
    parameter: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=KindOne
    )
@dataclass
class Model(to.Tree):
    """Tree holding a single array node tagged with ``KindTwo``.

    NOTE: as written, this Model has no ``SubModel`` field, so ``KindOne``
    is not reachable from ``Model()``. A ``submodel: SubModel`` field would
    be needed for the expected ``[KindOne, KindTwo]`` result.
    """

    # ``jnp.ndarray`` is the array type; ``jnp.array`` is a constructor function.
    parameter: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=KindTwo
    )
m = Model()
# Desired API from the question: this snippet does not define `unique_kinds`,
# so the call below only illustrates the requested behaviour.
m.unique_kinds() # [KindOne, KindTwo]
Interesting! There is a way to do this using a combination of `to.apply` and the `.field_metadata` property:
from dataclasses import dataclass
from typing import Set
import jax
import jax.numpy as jnp
import treeo as to
from treeo.utils import field
class KindOne:
    """Marker type passed as a field's ``kind=`` tag."""
class KindTwo:
    """Marker type passed as a field's ``kind=`` tag."""
@dataclass
class SubModel(to.Tree):
    """Tree holding a single array node tagged with ``KindOne``."""

    parameter: jnp.ndarray = to.field(
        default=jnp.array([1.0]),
        node=True,
        kind=KindOne,
    )
@dataclass
class Model(to.Tree):
    """Tree that nests a ``SubModel`` and holds a ``KindTwo``-tagged array."""

    # No default, so it must come before the defaulted field below
    # (dataclass ordering rule); it is also the first positional argument.
    submodel: SubModel
    parameter: jnp.ndarray = to.field(
        default=jnp.array([1.0]),
        node=True,
        kind=KindTwo,
    )
def unique_kinds(tree: to.Tree) -> Set[type]:
    """Return the distinct field ``kind`` types found in every Tree of *tree*.

    ``to.apply`` calls the collector on each ``Tree`` within the pytree, and
    each Tree's ``field_metadata`` is a dict describing its fields; every
    value in it carries a ``kind`` attribute.

    Args:
        tree: Root of the pytree to inspect.

    Returns:
        The set of kinds seen, excluding ``type(None)``. That value is
        presumably treeo's default for fields declared without ``kind=``;
        confirm against treeo.
    """
    kinds: Set[type] = set()

    def add_subtree_kinds(subtree: to.Tree) -> None:
        # Named ``metadata`` rather than ``field`` so the loop variable does
        # not shadow the ``field`` helper imported from ``treeo.utils``.
        for metadata in subtree.field_metadata.values():
            if metadata.kind is not type(None):
                kinds.add(metadata.kind)

    # Called only for its side effect on ``kinds``; any return value is ignored.
    to.apply(add_subtree_kinds, tree)
    return kinds
m = Model(submodel=SubModel())
# Prints a set containing KindOne and KindTwo (set order is not guaranteed).
print(unique_kinds(m))
`apply` traverses all the `Tree`s within a Pytree, and `field_metadata` is a dictionary describing each field.