google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/



Adding `tree_util.stack_leaves()` and `tree_util.unstack_leaves()`

ayaka14732 opened this issue

To be clear, are these the semantics you have in mind?

```python
import jax
import jax.numpy as jnp
def stack_leaves(pytrees, axis):
  return jax.tree.map(lambda *xs: jnp.stack(xs, axis), *pytrees)  # zip leaves across trees
```
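The issue title also asks for `unstack_leaves()`. A minimal sketch of both functions follows, assuming `pytrees` is a list of pytrees that share one tree structure and have matching leaf shapes (the `*pytrees` unpacking makes `jax.tree.map` receive one tree per input, so corresponding leaves are zipped together). The `unstack_leaves` body is a hypothetical inverse written for illustration, not an existing JAX API, and it assumes a JAX version that provides the `jax.tree` module.

```python
import jax
import jax.numpy as jnp


def stack_leaves(pytrees, axis=0):
    # Zip the leaves of all input pytrees (which must share one structure)
    # and stack each group of corresponding leaves along `axis`.
    return jax.tree.map(lambda *xs: jnp.stack(xs, axis=axis), *pytrees)


def unstack_leaves(pytree, axis=0):
    # Hypothetical inverse: slice every leaf along `axis` and rebuild one
    # pytree per slice. Assumes all leaves have the same size along `axis`.
    leaves, treedef = jax.tree.flatten(pytree)
    n = leaves[0].shape[axis]
    return [
        jax.tree.unflatten(treedef, [jnp.take(leaf, i, axis=axis) for leaf in leaves])
        for i in range(n)
    ]


trees = [{"a": jnp.ones(2), "b": jnp.zeros(())},
         {"a": 2 * jnp.ones(2), "b": jnp.ones(())}]
stacked = stack_leaves(trees)        # {"a": shape (2, 2), "b": shape (2,)}
restored = unstack_leaves(stacked)   # list of 2 dicts matching `trees`
```

Indexing with an integer through `jnp.take` drops the stacked axis, so each restored leaf has its original shape.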

> To be clear, are these the semantics you have in mind?

Yes

For something like this, I'd lean toward recommending that users implement what they need by composing existing APIs, rather than adding a new API for something that can already be expressed this succinctly. What do you think?
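As an illustration of composing existing APIs, here is a sketch of a common workflow: stacking the parameters of several small models with `jax.tree.map`, evaluating them all at once with `jax.vmap`, and unstacking by indexing each leaf. The `init`/`apply` model is hypothetical and exists only for this example.

```python
import jax
import jax.numpy as jnp


def init(key):
    # Hypothetical toy model: a (3, 2) weight matrix and a (2,) bias.
    k1, k2 = jax.random.split(key)
    return {"w": jax.random.normal(k1, (3, 2)), "b": jax.random.normal(k2, (2,))}


def apply(params, x):
    return x @ params["w"] + params["b"]


keys = jax.random.split(jax.random.PRNGKey(0), 4)
models = [init(k) for k in keys]

# Stack: one tree.map call zips the four parameter pytrees leaf by leaf.
stacked = jax.tree.map(lambda *xs: jnp.stack(xs), *models)

# Map over the leading (model) axis of the params; share the same input batch.
x = jnp.ones((5, 3))
ys = jax.vmap(apply, in_axes=(0, None))(stacked, x)  # shape (4, 5, 2)

# Unstack: index every leaf along the leading axis.
unstacked = [jax.tree.map(lambda leaf: leaf[i], stacked) for i in range(len(models))]
```

In this pattern, `jax.vmap(init)(keys)` would also produce the stacked parameter pytree directly, without building the list first.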

Maybe adding a `tree_util` cookbook would be useful? @jakevdp

A pytree cookbook would be an interesting idea! This idea also came up in #20594. @ayaka14732, is that something you'd be interested in thinking about?