google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/



[sparse] allow specifying axis type in bcoo_broadcast_in_dim

jakevdp opened this issue

Currently B[np.newaxis, :] will add a sparse axis when possible, and a batch axis otherwise.
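A BCOO array always orders its axes as batch, then sparse, then dense. That ordering determines which types a new size-1 axis may take at a given position. The pure-Python sketch below models that constraint together with the default described above: sparse when possible, otherwise batch on the leading side or dense on the trailing side. It does not call JAX. The names `legal_axis_types` and `default_axis_type` are hypothetical, and the model assumes the batch/sparse/dense ordering is the only constraint.

```python
from typing import Set


def legal_axis_types(n_batch: int, n_sparse: int, n_dense: int, pos: int) -> Set[str]:
    """Hypothetical helper: axis types a new size-1 axis could take at index
    `pos` while preserving the BCOO ordering (batch, then sparse, then dense)."""
    ndim = n_batch + n_sparse + n_dense
    if not 0 <= pos <= ndim:
        raise IndexError(f"position {pos} out of range for ndim={ndim}")
    kinds = set()
    if pos <= n_batch:                          # at or before the last batch axis
        kinds.add("batch")
    if n_batch <= pos <= n_batch + n_sparse:    # inside or on the edge of the sparse block
        kinds.add("sparse")
    if pos >= n_batch + n_sparse:               # at or after the first dense axis
        kinds.add("dense")
    return kinds


def default_axis_type(n_batch: int, n_sparse: int, n_dense: int, pos: int) -> str:
    """Models the plain np.newaxis behavior described in the issue:
    prefer sparse, else fall back to batch (leading) or dense (trailing)."""
    kinds = legal_axis_types(n_batch, n_sparse, n_dense, pos)
    if "sparse" in kinds:
        return "sparse"
    return "batch" if "batch" in kinds else "dense"


# A plain sparse matrix (n_batch=0, n_sparse=2): a leading new axis can be sparse.
print(default_axis_type(0, 2, 0, 0))  # sparse
# One batch axis plus one sparse axis: index 0 lies before the batch axis, so only batch fits.
print(default_axis_type(1, 1, 0, 0))  # batch
```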

We should add a way to specify the axis type explicitly. Possible API:

B[sparse.newaxis, :]  # add a sparse axis if possible, batch otherwise (same as current np.newaxis behavior)
B[sparse.newaxis.sparse, :]  # add a sparse axis, error if not possible
B[sparse.newaxis.batch, :]  # add a batch axis, error if not possible
B[..., sparse.newaxis]  # add a sparse axis if possible, dense otherwise
B[..., sparse.newaxis.dense]  # add a dense axis, error if not possible
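A minimal sketch of how the proposed variants could resolve, assuming the same batch/sparse/dense ordering model. `sparse.newaxis` was never added to JAX. `insert_axis`, `Layout`, and the `kind` strings below are hypothetical stand-ins: `kind="auto"` plays the role of `sparse.newaxis`, and the other values play the roles of `.sparse`, `.batch`, and `.dense`. The function only computes the resulting `(n_batch, n_sparse, n_dense)` layout and does not touch any array data.

```python
from typing import NamedTuple


class Layout(NamedTuple):
    """Hypothetical stand-in for a BCOO array's axis counts."""
    n_batch: int
    n_sparse: int
    n_dense: int


def insert_axis(layout: Layout, pos: int, kind: str = "auto") -> Layout:
    """Hypothetical: return the layout after inserting a size-1 axis at `pos`.

    kind="auto" falls back like plain np.newaxis; "sparse", "batch" and
    "dense" raise ValueError when that type cannot go at `pos`.
    """
    nb, ns, nd = layout
    if not 0 <= pos <= nb + ns + nd:
        raise IndexError(f"position {pos} out of range for {layout}")
    # Which types keep the batch -> sparse -> dense ordering intact.
    legal = {
        "batch": pos <= nb,
        "sparse": nb <= pos <= nb + ns,
        "dense": pos >= nb + ns,
    }
    if kind == "auto":
        kind = "sparse" if legal["sparse"] else ("batch" if legal["batch"] else "dense")
    elif kind not in legal:
        raise ValueError(f"unknown axis kind {kind!r}")
    elif not legal[kind]:
        raise ValueError(f"cannot insert a {kind} axis at position {pos} for {layout}")
    return Layout(nb + (kind == "batch"), ns + (kind == "sparse"), nd + (kind == "dense"))


print(insert_axis(Layout(0, 2, 0), 0))           # Layout(n_batch=0, n_sparse=3, n_dense=0)
print(insert_axis(Layout(0, 2, 0), 0, "batch"))  # Layout(n_batch=1, n_sparse=2, n_dense=0)
```

Because the fallback lives in a single `auto` branch, the explicit variants differ from the default only in raising an error instead of falling back.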

I'm going to close this because I don't think we'll be adding this feature.