cgarciae / treeo

A small library for creating and manipulating custom JAX Pytree classes

Home Page: https://cgarciae.github.io/treeo

Get all unique kinds

thomaspinder opened this issue

Hi,

Is there a way that I can get a list of all the unique kinds within a nested dataclass? For example:

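from dataclasses import dataclass

import jax.numpy as jnp
import treeo as to

class KindOne: pass
class KindTwo: pass

@dataclass
class SubModel(to.Tree):
    parameter: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=KindOne
    )


@dataclass 
class Model(to.Tree):
    submodel: SubModel  # nested Tree, so KindOne should also be found
    parameter: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=KindTwo
    )

m = Model(SubModel())

m.unique_kinds()  # desired (hypothetical method): [KindOne, KindTwo]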

Interesting question! There is a way to do this by combining to.apply with the .field_metadata property:

from dataclasses import dataclass
from typing import Set

import jax.numpy as jnp
import treeo as to


class KindOne: pass
class KindTwo: pass

@dataclass
class SubModel(to.Tree):
    parameter: jnp.ndarray = to.field(default=jnp.array([1.0]), node=True, kind=KindOne)

@dataclass
class Model(to.Tree):
    submodel: SubModel
    parameter: jnp.ndarray = to.field(default=jnp.array([1.0]), node=True, kind=KindTwo)

def unique_kinds(tree: to.Tree) -> Set[type]:
    kinds = set()

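    # to.apply calls this once for every Tree in the Pytree (Model and SubModel).
    def add_subtree_kinds(subtree: to.Tree):
        for metadata in subtree.field_metadata.values():
            # Fields without an explicit kind have kind type(None); skip them.
            if metadata.kind is not type(None):
                kinds.add(metadata.kind)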

    to.apply(add_subtree_kinds, tree)

    return kinds

m = Model(SubModel())

print(unique_kinds(m))  # {KindOne, KindTwo} (set order may vary)

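to.apply traverses every Tree within a Pytree (here both Model and its nested SubModel) and calls the given function on each one, while field_metadata is a dictionary mapping each field name to that field's metadata, including its kind.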
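To make the traversal idea concrete without depending on treeo, here is a minimal sketch using only the standard dataclasses module. It is an approximation, not treeo's implementation: the hypothetical kind_field helper stores the kind in ordinary dataclass field metadata (standing in for to.field), the hypothetical iter_trees generator stands in for the walk that to.apply performs, and plain floats replace the JAX arrays. It also exposes the result as a unique_kinds() method, as asked in the question, returning a deduplicated list instead of a set.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, Iterator, List


class KindOne: pass
class KindTwo: pass


def kind_field(default: Any = None, *, kind: type = type(None)) -> Any:
    # Hypothetical stand-in for to.field: keeps the kind in the
    # standard dataclasses metadata mapping.
    return field(default=default, metadata={"kind": kind})


def iter_trees(obj: Any) -> Iterator[Any]:
    # Hypothetical stand-in for to.apply's traversal: yield every dataclass
    # instance, recursing into its fields and into list/tuple/dict containers.
    if is_dataclass(obj) and not isinstance(obj, type):
        yield obj
        for f in fields(obj):
            yield from iter_trees(getattr(obj, f.name))
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            yield from iter_trees(item)
    elif isinstance(obj, dict):
        for value in obj.values():
            yield from iter_trees(value)


class KindsMixin:
    def unique_kinds(self) -> List[type]:
        # dict.fromkeys deduplicates while keeping first-seen order.
        kinds = dict.fromkeys(
            f.metadata.get("kind", type(None))
            for tree in iter_trees(self)
            for f in fields(tree)
        )
        kinds.pop(type(None), None)  # drop fields that declare no kind
        return list(kinds)


@dataclass
class SubModel(KindsMixin):
    parameter: float = kind_field(1.0, kind=KindOne)


@dataclass
class Model(KindsMixin):
    submodel: SubModel
    parameter: float = kind_field(2.0, kind=KindTwo)


m = Model(SubModel())
print(m.unique_kinds())  # [KindTwo, KindOne]: the outer tree's fields come first
```

Using dict.fromkeys rather than a set keeps the order in which kinds are first encountered, which gives a stable list; the set used in the treeo version above is fine when order doesn't matter.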