google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Unhelpful error messages related to pmap leading dimension requirements

markpaskin opened this issue

pmap splits work based on the leading dimension of its arguments: the leading dimensions of all mapped arguments must be equal, and that size cannot be larger than the number of devices. As a JAX n00b, I wasn't initially aware of this; I assumed pmap would internally split an arbitrarily long leading dimension into per-device chunks. This led me to some unhelpful error messages, one of which was triggered by passing arguments by keyword. Repro details below:

import jax
import jax.numpy as jnp
import numpy as np

Here's a very simple computation that scales a (5,) array by a (1,) array's value:

x = np.arange(5)
w = np.array([2.])

def scale(x, w):
  return x * w

scale(x, w)
--> array([0., 2., 4., 6., 8.])

Note that I'm doing this on a single CPU:

jax.local_device_count() 
--> 1

Running this computation using pmap is inherently broken, because the first dimension of x is 5, which is greater than the number of devices (and which also doesn't match the first dimension of w, which is 1).
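For contrast, here is a minimal sketch of a call that does satisfy these requirements, assuming the goal is simply to spread the same computation across whatever local devices exist. It reshapes x so that its leading axis equals jax.local_device_count(), and it uses in_axes=(0, None) so that w is broadcast to every device instead of being mapped. The choice of 5 * n elements is an assumption made so the reshape divides evenly.

```python
import jax
import numpy as np

def scale(x, w):
    return x * w

# Number of devices pmap can map over in this process (1 on a single CPU).
n = jax.local_device_count()

# Assumption: the data size is a multiple of n so it splits evenly.
x = np.arange(5 * n, dtype=np.float32)
w = np.array([2.0], dtype=np.float32)

# Give x a leading axis of exactly n (one row per device), and pass w
# unmapped (in_axes=None) so every device receives the whole (1,) array.
x_sharded = x.reshape(n, -1)
out = jax.pmap(scale, in_axes=(0, None))(x_sharded, w)

# out has shape (n, 5); flatten it to recover the unsharded result.
result = np.asarray(out).reshape(-1)
print(result)
```

On a single CPU this produces the same values as scale(x, w) above. Because w is unmapped, its shape no longer takes part in the leading-axis check.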

So, we'd expect a helpful error message, but I've found that if arguments are supplied by keyword, we don't get one:

jax.pmap(scale)(x=x, w=w)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
jax/_src/api.py in _mapped_axis_size(tree, vals, dims, name, kws)
   1499   try:
-> 1500     size, = mapped_axis_sizes
   1501     return size

ValueError: too many values to unpack (expected 1)

During handling of the above exception, another exception occurred:

UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-7-997e0e638d4f> in <module>()
----> 1 jax.pmap(scale)(x=x, w=w)

jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

jax/_src/api.py in cache_miss(*args, **kwargs)
   2011 
-> 2012     out_tree, out_flat = f_pmapped_(*args, **kwargs)
   2013     out_pytree_def = out_tree()

jax/_src/api.py in pmap_f(*args, **kwargs)
   1884   def pmap_f(*args, **kwargs):
-> 1885     p = _prepare_pmap(
   1886         fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,

jax/_src/api.py in _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, global_arg_shapes, args, kwargs)
   1828       kws=True))
-> 1829   local_axis_size = _mapped_axis_size(
   1830       in_tree, args, in_axes_flat, "pmap", kws=True)

jax/_src/api.py in _mapped_axis_size(tree, vals, dims, name, kws)
   1512       tree, leaf = treedef_children(tree)
-> 1513       assert treedef_is_leaf(leaf)
   1514     # TODO(mattjj,phawkins): add a way to inspect pytree kind more directly

UnfilteredStackTrace: AssertionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

AssertionError                            Traceback (most recent call last)
<ipython-input-7-997e0e638d4f> in <module>()
----> 1 jax.pmap(scale)(x=x, w=w)

AssertionError:

Hi, thanks for the report. Can I ask what version of JAX you're using? On the most recent version I see this error:

import jax.numpy as jnp
import jax
x = jnp.arange(5)
w = jnp.array([2.])

def scale(x, w):
  return x * w

print(jax.__version__)
# 0.3.7
jax.pmap(scale)(x, w)
...
ValueError: pmap got inconsistent sizes for array axes to be mapped:
arg 0 has shape (5,) and axis 0 is to be mapped
arg 1 has shape (1,) and axis 0 is to be mapped
so
arg 0 has an axis to be mapped of size 5
arg 1 has an axis to be mapped of size 1

If you're using a recent JAX version, perhaps I'm misinterpreting what code you're running. Could you include the exact function call that's leading to the unhelpful error?

Hi Jake. My JAX version is 0.3.8. I've shared an internal Colab with you that reproduces the issue.

I see, thanks. It looks like the helpful error message is somehow missed when the arguments are passed by keyword rather than by position.
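Until the keyword path reports the mismatch itself, one workaround is to validate shapes before calling the pmapped function. The helper below, check_pmap_leading_axes, is hypothetical (it is not part of JAX). It is a sketch that names positional and keyword arguments in its message, roughly mirroring the positional-argument error quoted above. The n_devices parameter is an assumption; in practice you would pass jax.local_device_count().

```python
import numpy as np

def check_pmap_leading_axes(n_devices, *args, **kwargs):
    """Hypothetical helper (not a JAX API): check that all positional and
    keyword arguments share one leading-axis size no larger than n_devices."""
    named = [(f"arg {i}", a) for i, a in enumerate(args)]
    named += [(f"kwarg {k!r}", v) for k, v in kwargs.items()]

    sizes = {}
    for name, value in named:
        shape = np.shape(value)
        if not shape:
            raise ValueError(f"{name} is a scalar and has no leading axis to map over")
        sizes[name] = shape[0]

    # Every mapped argument must agree on the leading-axis size.
    if len(set(sizes.values())) > 1:
        details = "\n".join(f"  {name} has leading-axis size {s}" for name, s in sizes.items())
        raise ValueError("inconsistent leading-axis sizes for pmap:\n" + details)

    # The shared size must fit on the available devices.
    size = next(iter(sizes.values()), None)
    if size is not None and size > n_devices:
        raise ValueError(f"leading-axis size {size} exceeds the {n_devices} available device(s)")
    return size

# Usage with the repro's arrays, passed by keyword on a single device.
x = np.arange(5)
w = np.array([2.0])
try:
    check_pmap_leading_axes(1, x=x, w=w)
except ValueError as err:
    print(err)
```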

It looks like the problem is that when multiple keyword arguments are passed to the pmapped function, the shape-checking logic fails this assertion:

assert treedef_is_leaf(leaf)

This is definitely a bug; it looks like the code was added by @mattjj in #5387. I'm going to assign the issue to him. Matt - feel free to bounce it back to me if you don't have the cycles to look at this.

It looks like removing that assertion is enough to make this raise a helpful message. I don't understand the intent of that code well enough to know whether it's the right fix.

The line tree, leaf = treedef_children(tree) decomposes the args treedef into a positional-args treedef and a kwargs treedef. Doesn't removing only this assertion result in an error with an empty message?
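To see why the assertion fires, the sketch below builds the two treedefs directly with jax.tree_util, assuming (as described in the comment above) that pmap flattens a call as the pair (positional_args, kwargs). For the keyword call, the kwargs child is a dict with two entries, so treedef_is_leaf returns False and the assert fails; for the positional call, the kwargs child is an empty dict, which treedef_is_leaf treats as a leaf.

```python
import numpy as np
from jax.tree_util import tree_structure, treedef_children, treedef_is_leaf

x = np.arange(5)
w = np.array([2.0])

# Assumption: the in_tree pmap sees is the structure of (args, kwargs).
kw_tree = tree_structure(((), {"x": x, "w": w}))   # like jax.pmap(scale)(x=x, w=w)
pos_tree = tree_structure(((x, w), {}))            # like jax.pmap(scale)(x, w)

results = {}
for label, in_tree in [("keyword call", kw_tree), ("positional call", pos_tree)]:
    # Same decomposition as the line in _mapped_axis_size.
    args_tree, kwargs_tree = treedef_children(in_tree)
    results[label] = treedef_is_leaf(kwargs_tree)
    print(label, kwargs_tree, "is leaf:", results[label])
```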

Any updates to this?