google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home page: http://jax.readthedocs.io/



sum(keepdims=True) should work in Pallas

voznesenskym opened this issue.

Description

I am writing a Pallas kernel where one of the lines:

frobenius_sq_norm = square_norm(w_tl).sum(keepdims=True)

...includes a sum() with keepdims. I don't expect sum() to work without keepdims, as that would produce a scalar. However, with keepdims=True the reduction maps a vector to a vector (vec -> vec), so it should work. Instead, lowering fails with:

Cannot lower reductions to scalar. Reduce to one element vector instead, using keepdims=True.
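A minimal, self-contained version of the failing pattern is sketched below. It assumes `square_norm` is an elementwise square (the issue does not show its definition, so the helper here is hypothetical) and uses an (8, 128) float32 block as a stand-in for `w_tl`. With `interpret=True` the kernel runs through the Pallas interpreter on CPU, so it does not go through the Mosaic TPU lowering and should simply return the (1, 1) result; compiling the same kernel for TPU is where the error above is reported.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def square_norm(w):
    # Hypothetical helper: the issue does not show its definition.
    return jnp.square(w)


def frobenius_kernel(w_ref, o_ref):
    w_tl = w_ref[...]  # load the whole (8, 128) block (no grid, so one block)
    # Reduce over all axes but keep them as size-1 dims: (8, 128) -> (1, 1).
    frobenius_sq_norm = square_norm(w_tl).sum(keepdims=True)
    o_ref[...] = frobenius_sq_norm


def frobenius_sq_norm_pallas(w, interpret=True):
    return pl.pallas_call(
        frobenius_kernel,
        out_shape=jax.ShapeDtypeStruct((1, 1), w.dtype),
        # interpret=True runs on CPU; on a TPU backend, False lowers via Mosaic.
        interpret=interpret,
    )(w)


w = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128) / 1024.0
out = frobenius_sq_norm_pallas(w)
print(out.shape)  # (1, 1)
```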

This happens because the implementation of sum() first reduces the dimensions away (producing a scalar for a full reduction) and then adds them back in as size-1 dimensions.

In reductions.py in jax.numpy, we can find:

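if keepdims:
  # `result` has already had the reduced axes removed (a scalar for a full
  # reduction); re-insert them as size-1 dimensions.
  result = lax.expand_dims(result, pos_dims)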
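To see this two-step pattern outside of Pallas, one can print the jaxpr that `jnp.sum(..., keepdims=True)` traces to. The sketch below assumes a recent JAX (exact jaxpr formatting varies between versions); the `keepdims_sum` wrapper is just a hypothetical name for illustration. The printed jaxpr should contain something like `b:f32[] = reduce_sum[axes=(0, 1)] a`, a rank-0 intermediate, followed by an op that re-inserts the two size-1 axes.

```python
import jax
import jax.numpy as jnp


def keepdims_sum(x):
    # Reduces all axes but keeps them: (8, 128) -> (1, 1).
    return jnp.sum(x, keepdims=True)


x = jnp.ones((8, 128), dtype=jnp.float32)
jaxpr_text = str(jax.make_jaxpr(keepdims_sum)(x))
print(jaxpr_text)
```

The final shape is (1, 1), but the value produced by `reduce_sum` itself is a scalar, which matches the error message's complaint.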

The real fix would probably be to pass keepdims down to every leaf site where the reduction op is actually invoked, and to ensure that the op respects it and preserves the vector, rather than repackaging a scalar into a vector.
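Until keepdims reaches the leaf reduction, one user-level approximation of "preserve the vec" is to reduce one axis at a time with keepdims=True, so every intermediate stays at least rank 1. The helper `sum_keepdims_per_axis` below is hypothetical and not part of JAX, and this is a swapped-in workaround, not the fix proposed above. It only demonstrates, via the traced jaxpr, that no rank-0 value is produced; whether Mosaic accepts the resulting rank-1 intermediates for a given block shape is not verified here.

```python
import re

import jax
import jax.numpy as jnp


def sum_keepdims_per_axis(x):
    # Hypothetical workaround: reduce the last axis first, keeping it, then the
    # remaining axes one by one. Each jnp.sum call still reduces and then
    # re-expands, but it removes only one axis, so with ndim >= 2 no
    # intermediate ever becomes a scalar: (8, 128) -> (8,) -> (8, 1) -> (1,) -> (1, 1).
    for axis in reversed(range(x.ndim)):
        x = jnp.sum(x, axis=axis, keepdims=True)
    return x


x = jnp.ones((8, 128), dtype=jnp.float32)
text = str(jax.make_jaxpr(sum_keepdims_per_axis)(x))
# A rank-0 float32 value would be printed with the type "f32[]".
has_scalar = re.search(r":f32\[\]", text) is not None
print(sum_keepdims_per_axis(x).shape, has_scalar)
```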