Adding `tree_util.stack_leaves()` and `tree_util.unstack_leaves()`
ayaka14732 opened this issue
- `stack_leaves`: Stack the leaves of one or more PyTrees along a new axis.
- `unstack_leaves`: Unstack the leaves of a PyTree.
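A minimal sketch of what the two proposed functions might look like if built only from existing `jax.tree` utilities. The names and signatures here are hypothetical, not an existing JAX API. It assumes that all input trees share one structure with matching leaf shapes, that the tree has at least one leaf, and that `unstack_leaves` returns a plain Python list of trees.

```python
import jax
import jax.numpy as jnp


def stack_leaves(pytrees, axis=0):
    """Stack corresponding leaves of same-structured pytrees along a new axis."""
    # Unpacking the list lets jax.tree.map walk all trees in parallel; each call
    # receives the matching leaf from every tree as *xs.
    return jax.tree.map(lambda *xs: jnp.stack(xs, axis=axis), *pytrees)


def unstack_leaves(pytree, axis=0):
    """Split every leaf along `axis`, returning a list of pytrees."""
    leaves, treedef = jax.tree.flatten(pytree)
    # Assumes at least one leaf and that all leaves have the same size on `axis`.
    n = leaves[0].shape[axis]
    return [
        # jnp.take with a scalar index removes `axis` from each leaf.
        jax.tree.unflatten(treedef, [jnp.take(x, i, axis=axis) for x in leaves])
        for i in range(n)
    ]


# Hypothetical example: four small parameter dicts with identical structure.
trees = [{"w": jnp.ones((2, 3)) * i, "b": jnp.full((3,), i)} for i in range(4)]
stacked = stack_leaves(trees)   # stacked["w"].shape == (4, 2, 3)
parts = unstack_leaves(stacked)  # a list of 4 dicts, matching `trees`
```

Here `unstack_leaves` is the inverse of `stack_leaves` for the same `axis`, so a round trip recovers the original list of trees.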
References:
- https://docs.liesel-project.org/en/v0.1.4/_modules/liesel/goose/pytree.html#stack_leaves
- https://gist.github.com/willwhitney/dd89cac6a5b771ccff18b06b33372c75?permalink_comment_id=4634557#gistcomment-4634557
- https://github.com/ayaka14732/llama-2-jax/blob/ab33e1f15489daa8b9040389c77e486cd450e461/lib/tree_utils/__init__.py
To be clear, are these the semantics you have in mind?

```python
import jax
import jax.numpy as jnp
def stack_leaves(pytrees, axis):
    return jax.tree.map(lambda *xs: jnp.stack(xs, axis=axis), *pytrees)  # *pytrees: walk all trees in parallel
```
> To be clear, are these the semantics you have in mind?

Yes
For something like this, I'd probably lean toward recommending that users implement what they need by composing existing APIs, rather than providing a new API for something that can already be expressed pretty succinctly. What do you think?
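To illustrate the point about composing existing APIs, here is a sketch that uses hypothetical example data and a hypothetical `weighted_sum` function. Stacking and unstacking can be written inline with `jax.tree.map`, without a dedicated helper. The stacked result also already has the leading-axis batched layout that `jax.vmap` maps over by default.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example inputs: five pytrees with identical structure.
examples = [{"x": jnp.arange(3.0) + i, "scale": jnp.float32(i)} for i in range(5)]

# Stack inline with jax.tree.map; every leaf gains a leading axis of size 5.
batched = jax.tree.map(lambda *xs: jnp.stack(xs), *examples)


def weighted_sum(tree):
    # Operates on a single, unbatched example.
    return jnp.sum(tree["x"]) * tree["scale"]


# jax.vmap maps over axis 0 of every leaf, so the stacked tree can be passed directly.
out = jax.vmap(weighted_sum)(batched)

# Unstack inline: pick the same position from every leaf.
third = jax.tree.map(lambda x: x[2], batched)
```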
Maybe adding a tree util cookbook would be useful? @jakevdp
A pytree cookbook would be an interesting idea! This idea also came up in #20594. @ayaka14732, is that something you'd be interested in thinking about?