google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

[hcb] request for custom batching

oliverdutton opened this issue

Please:
Can the implementation of host_callback.call be modified to allow the jax.custom_batching rules to be used with it? Ideally supporting both the simplest case (sequential_vmap) and fully custom rules.

I have a section of code that I'd like to run on the CPU in the middle of a large set of GPU-suited operations. I vmap the whole computation to fill the GPU, but that means I can no longer use host_callback.

Would you please consider this modification? A minimal reproduction:

import jax
import jax.experimental.host_callback as hcb
from jax import vmap, numpy as jnp
import numpy as np

# This function runs on the host
def host_eig(m: np.ndarray) -> np.ndarray:
  return np.linalg.eigvals(m)

# This function is used in JAX
@jax.custom_batching.sequential_vmap
def device_fun(m):
  # We send "m" to the host, asking it to call "host_eig" and return the result.
  # We have to specify the result shape and dtype, either in the form of an
  # example return value or any object that has `shape` and `dtype` attributes,
  # e.g., a NumPy array or a `jax.ShapeDtypeStruct`.
  return hcb.call(host_eig, m,
                  # Given an input of shape (..., d, d), eig output has shape (..., d)
                  result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))

print(device_fun(jnp.eye(4)))

print(vmap(device_fun)(jnp.eye(4)[None])) # Errors.

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-1-ea43230de86b> in <module>
     25 print(device_fun(jnp.eye(4)))
     26 
---> 27 print(vmap(device_fun)(jnp.eye(4)[None]))

    [... skipping hidden 27 frame]

/opt/conda/lib/python3.8/site-packages/jax/experimental/host_callback.py in <lambda>(j)
   1742 xla.register_translation(id_p, lambda ctx, avals_in, avals_out, *args: args)
   1743 
-> 1744 dispatch.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)
   1745 
   1746 

/opt/conda/lib/python3.8/site-packages/jax/experimental/host_callback.py in _rewrite_jaxpr(jaxpr, has_input_token, has_output_token)
   1432       output_token_var = mk_new_var(last_token_var.aval)
   1433       output_itoken_var = mk_new_var(last_itoken_var.aval)
-> 1434       _rewrite_eqn(eqn, eqns, last_token_var, output_token_var,
   1435                    last_itoken_var, output_itoken_var, mk_new_var)
   1436       last_token_var = output_token_var

/opt/conda/lib/python3.8/site-packages/jax/experimental/host_callback.py in _rewrite_eqn(eqn, eqns, input_token_var, output_token_var, input_itoken_var, output_itoken_var, mk_new_var)
   1503     new_invars = eqn.invars[0:nr_const_and_carry] + [
   1504         input_token_var, input_itoken_var] + eqn.invars[nr_const_and_carry:]
-> 1505     new_jaxpr = _rewrite_closed_jaxpr(carry_jaxpr, True, True)
   1506     # The rewrite has put the token at end, it has to be at end of carry
   1507     new_jaxpr_invars = new_jaxpr.jaxpr.invars

/opt/conda/lib/python3.8/site-packages/jax/experimental/host_callback.py in _rewrite_closed_jaxpr(cjaxpr, has_input_token, has_output_token)
   1396                           has_output_token: bool) -> core.ClosedJaxpr:
   1397   """Rewrites a ClosedJaxpr to thread the token, if needed."""
-> 1398   new_jaxpr = _rewrite_jaxpr(cjaxpr.jaxpr, has_input_token, has_output_token)
   1399   return core.ClosedJaxpr(new_jaxpr, cjaxpr.consts)
   1400 

/opt/conda/lib/python3.8/site-packages/jax/experimental/host_callback.py in _rewrite_jaxpr(jaxpr, has_input_token, has_output_token)
   1432       output_token_var = mk_new_var(last_token_var.aval)
   1433       output_itoken_var = mk_new_var(last_itoken_var.aval)
-> 1434       _rewrite_eqn(eqn, eqns, last_token_var, output_token_var,
   1435                    last_itoken_var, output_itoken_var, mk_new_var)
   1436       last_token_var = output_token_var

/opt/conda/lib/python3.8/site-packages/jax/experimental/host_callback.py in _rewrite_eqn(eqn, eqns, input_token_var, output_token_var, input_itoken_var, output_itoken_var, mk_new_var)
   1631             ), eqn.source_info))
   1632   else:
-> 1633     raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
   1634 
   1635 

NotImplementedError: outfeed rewrite custom_vmap_call
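
For the fully custom case, here is a sketch of what I would like to be able to write (continuing from the example above; the decorated function and rule names are just illustrative). Since np.linalg.eigvals already accepts stacked matrices, the custom rule could forward the whole batch to the host in a single round-trip instead of looping. Presumably this currently fails with the same outfeed-rewrite error:

# Sketch of the fully custom case: batch the host call in one round-trip
# rather than element by element, by forwarding the batch-leading input.
@jax.custom_batching.custom_vmap
def device_fun_custom(m):
  return hcb.call(host_eig, m,
                  result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))

@device_fun_custom.def_vmap
def device_fun_vmap_rule(axis_size, in_batched, m):
  # np.linalg.eigvals handles leading batch dimensions, so the batched call
  # is the same host call applied to the batch-leading array.
  m_batched, = in_batched
  out = hcb.call(host_eig, m,
                 result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))
  return out, m_batched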

This came up recently over in #8853 as well.

It is actually possible to add custom batching to a host_callback.call; you just have to monkey-patch how batching works for it.

See, for example, the following:

from typing import Any, NamedTuple
import jax
import jax.experimental.host_callback as hcb
import jax.interpreters.batching as batching
import jax.numpy as jnp

# Wrap the argument in a unique pytree node so the batching rule below can
# recognise calls to our function.
class _MyFuncArg(NamedTuple):
    value: Any

def _hcb_func(arg: _MyFuncArg):
    value = arg.value
    print("Not batching")
    return value + 1

def _hcb_func_batched(arg: _MyFuncArg):
    value = arg.value
    print("Batching")
    return value + 1

# Keep a reference to the existing rule so every other host_callback.call
# still batches as before.
old_hcb_batching_rule = batching.primitive_batchers[hcb.outside_call_p]

def hcb_batching_rule(arg_flat, batch_axes, *, arg_treedef, **params):
    # Unflatten with dummy leaves purely to discover the container type.
    leaves = [None] * arg_treedef.num_leaves
    call_type = type(jax.tree_unflatten(arg_treedef, leaves))
    if call_type is _MyFuncArg:
        arg = jax.tree_unflatten(arg_treedef, arg_flat)
        result_shape = jax.ShapeDtypeStruct(arg.value.shape, arg.value.dtype)
        out = hcb.call(_hcb_func_batched, arg, result_shape=result_shape)
        out = jax.tree_leaves(out)
        # (Make sure to update batch_axes if you need to; see the sketch
        # after this example.)
        return out, batch_axes
    return old_hcb_batching_rule(arg_flat, batch_axes, arg_treedef=arg_treedef, **params)

batching.primitive_batchers[hcb.outside_call_p] = hcb_batching_rule

def myfunc(x):
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return hcb.call(_hcb_func, _MyFuncArg(x), result_shape=result_shape)

myfunc(jnp.array(1))                  # prints "Not batching"
jax.vmap(myfunc)(jnp.array([1, 1]))   # prints "Batching"
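
As written, the rule returns batch_axes unchanged, which is only correct while the host computation keeps each mapped axis in place (elementwise value + 1 does). If the host function should always see batch-leading arrays, a helper along these lines could normalise the axes first (an illustrative sketch, not part of the patch above; _batch_axes_to_front is a hypothetical name):

import jax.numpy as jnp
import jax.interpreters.batching as batching

# Illustrative helper: move every mapped axis to the front so the host
# function only ever sees arrays whose leading dimension is the batch.
def _batch_axes_to_front(arg_flat, batch_axes):
    moved = []
    for x, ax in zip(arg_flat, batch_axes):
        if ax is batching.not_mapped:
            moved.append(x)  # unbatched leaf: pass through unchanged
        else:
            moved.append(jnp.moveaxis(x, ax, 0))
    # All mapped leaves now carry their batch dimension on axis 0.
    new_axes = [ax if ax is batching.not_mapped else 0 for ax in batch_axes]
    return moved, new_axes

Inside hcb_batching_rule you would call this on arg_flat before unflattening, and return new_axes in place of batch_axes.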

I think working with primitives still counts as using internal APIs, so I don't think there are any claims about forward compatibility.
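
In that spirit, it may be worth making the patch fail loudly rather than obscurely if a future JAX release moves these internals; an illustrative guard (not from the original code):

# Illustrative: look the existing rule up defensively so the patch fails with
# a clear message, rather than an opaque KeyError, if host_callback's
# internals change in a future JAX release.
try:
    old_hcb_batching_rule = batching.primitive_batchers[hcb.outside_call_p]
except KeyError as e:
    raise RuntimeError(
        "jax.experimental.host_callback internals have changed; "
        "this monkey-patch needs updating."
    ) from e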

For a more complete example, Equinox does this for some of its (experimental!) operations; see here.

Amazing, thank you so much