google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

`tree_flatten` gets `None` when used in `defvjp`

yuanqing-wang opened this issue · comments

I wanted to define a custom_vjp for a function that takes a string indicating which operation to perform. Since str is not a valid JAX type, whereas any object that can be flattened and unflattened as a pytree is, I came up with this hacky way to define my own string object:

import jax
import jax.numpy as jnp
from functools import partial

OP_TO_IDX, IDX_TO_OP = {"add": 0, "mul": 1}, {0: "add", 1: "mul"}

@jax.tree_util.register_pytree_node_class
class OpStr(object):
    def __init__(self, op):
        assert op in OP_TO_IDX, "can only be one of the ops"
        self.op = op

    def __eq__(self, other):
        return self.op == other

    def __hash__(self):
        return hash(self.op)

    def tree_flatten(self):
        return [OP_TO_IDX[self.op]], None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        idx = int(children[0])
        return cls(op=IDX_TO_OP[idx])

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def op(x, y, op):
    if op == "add": z = x + y
    if op == "mul": z = x * y
    return z

def op_fwd(x, y, op):
    cache = (x, y, op)
    if op == "add": z = x + y
    if op == "mul": z = x * y
    return z, cache

def op_bwd(cache, dz):
    x, y, op = cache
    if op == "add": dz_dx = dz_dy = 1.0
    if op == "mul": dz_dx, dz_dy = y, x
    return dz_dx, dz_dy

op.defvjp(op_fwd, op_bwd)
grad_op = jax.grad(op)
grad_op(1.0, 1.0, OpStr("add"))

But this gives me

Traceback (most recent call last):
  File "/Users/wangy1/Documents/GitHub/dgl/tests/jax/test_op_str.py", line 47, in <module>
    grad_op(1.0, 1.0, OpStr("add"))
  File "/Users/wangy1/Documents/GitHub/dgl/tests/jax/test_op_str.py", line 24, in tree_unflatten
    idx = int(children[0])
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'

This is not expected, since the flattened tree shouldn't be None.

It sounds like you're running into the issues discussed here: https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization

Do the recommendations there answer your question?
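
For reference, one robust pattern consistent with that doc is to carry static data in aux_data rather than in the children, so tree_unflatten never has to reinterpret placeholder leaf values. A hypothetical variant of OpStr (a sketch under that assumption, not the exact fix from the doc):

import jax

@jax.tree_util.register_pytree_node_class
class OpStrStatic:
    def __init__(self, op):
        self.op = op

    def __eq__(self, other):
        return self.op == other

    def __hash__(self):
        return hash(self.op)

    def tree_flatten(self):
        # No children: the opcode is static metadata, not a traced leaf.
        return (), self.op

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # aux_data comes back verbatim, so no conversion can fail here.
        return cls(aux_data)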

Indeed, to underscore Jake's point, pytrees are technically required to be able to contain any Python object, just like Python tuples can. That is, pytrees are defined in terms of isomorphisms to tuples. Some bits of JAX internals code rely on that property. In this case, we're tree-mapping an assertion function that returns None (Python's way of representing 'no return value'); a result pytree is still built, with None substituted in place of what were previously leaves. But since this custom pytree node class can't handle containing Nones, it breaks.

That said, while this bit of JAX internals code is relying on something that technically is part of the pytree contract, usually it's possible for us to rewrite the code not to demand this property hold true (i.e. not to require a custom pytree type to contain Nones). I think we can do that in this case too...
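
To make that concrete, here is a minimal sketch (reusing the OpStr class from the snippet above) of what such internals code effectively does:

import jax

# Mapping a function that returns None over a custom pytree replaces
# every leaf with None, then rebuilds the tree, so tree_unflatten
# receives None children:
jax.tree_util.tree_map(lambda leaf: None, OpStr("add"))
# tree_unflatten gets children == [None]; int(None) then raises the
# TypeError shown in the traceback above.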

#10367 will fix this issue, though there's actually another bug here that had to be fixed too. See the line of code below marked with a comment:

import jax
import jax.numpy as jnp
from functools import partial

OP_TO_IDX, IDX_TO_OP = {"add": 0, "mul": 1}, {0: "add", 1: "mul"}

@jax.tree_util.register_pytree_node_class
class OpStr(object):
    def __init__(self, op):
        assert op in OP_TO_IDX, "can only be one of the ops"
        self.op = op

    def __eq__(self, other):
        return self.op == other

    def __hash__(self):
        return hash(self.op)

    def tree_flatten(self):
        return [OP_TO_IDX[self.op]], None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        idx = int(children[0])
        return cls(op=IDX_TO_OP[idx])


@partial(jax.custom_vjp, nondiff_argnums=(2,))
def op(x, y, op):
    if op == "add": z = x + y
    if op == "mul": z = x * y
    return z

def op_fwd(x, y, op):
    cache = (x, y, op)
    if op == "add": z = x + y
    if op == "mul": z = x * y
    return z, cache

def op_bwd(op, cache, dz):  # NOTE added arg for nondiff_argnums
    x, y, op = cache
    if op == "add": dz_dx = dz_dy = 1.0
    if op == "mul": dz_dx, dz_dy = y, x
    return dz_dx, dz_dy

op.defvjp(op_fwd, op_bwd)
grad_op = jax.grad(op)
grad_op(1.0, 1.0, OpStr("add"))

Notice that when using nondiff_argnums, the bwd function already receives that input as an argument, so the opcode need not be passed via the residuals.

Thanks so much, @jakevdp and @mattjj! So sorry I overlooked that bit in the documentation.

Sorry, I've followed your instructions and this guide: https://jax.readthedocs.io/en/latest/custom_vjp_update.html#what-to-update

import jax
import jax.numpy as jnp
from functools import partial

OP_TO_IDX, IDX_TO_OP = {"add": 0, "mul": 1}, {0: "add", 1: "mul"}

@jax.tree_util.register_pytree_node_class
class OpStr:
    def __init__(self, op):
        if isinstance(op, str):
            assert op in OP_TO_IDX, "can only be one of the ops"
        self.op = op

    def __eq__(self, other):
        return self.op == other

    def __hash__(self):
        return hash(self.op)

    def tree_flatten(self):
        return [OP_TO_IDX[self.op]], None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        idx = children[0]
        if isinstance(idx, int):
            return cls(op=IDX_TO_OP[idx])
        else:
            return None


@jax.custom_vjp
def op(x, y, op):
    if op == "add": z = x + y
    if op == "mul": z = x * y
    return z

def op_fwd(x, y, op):
    cache = (x, y, op)
    if op == "add": z = x + y
    if op == "mul": z = x * y
    return z, cache

def op_bwd(cache, dz):
    x, y, op = cache
    if op == "add": dz_dx = dz_dy = 1.0
    if op == "mul": dz_dx, dz_dy = y, x
    return (dz_dx, dz_dy, None)

op.defvjp(op_fwd, op_bwd)
grad_op = jax.grad(op)
grad_op(1.0, 1.0, OpStr("add"))

But I still got this rather puzzling error that I don't quite understand:

Traceback (most recent call last):
  File "/Users/wangy1/Documents/GitHub/dgl/tests/jax/test_op_str.py", line 52, in <module>
    grad_op(1.0, 1.0, OpStr("add"))
AssertionError: length mismatch: [3, 2]

I would appreciate your help! Thanks!

Looking more closely at this, I'm a bit confused about the goal. Fundamentally, are you hoping to create a custom vjp for a function that accepts a pytree, or was the pytree just the mechanism you came up with to allow passing string arguments to your function?

If the latter, I think the better mechanism to use would be nondiff_argnums, and just pass a string.

Thanks, @jakevdp! I just wanted to pass a string argument. But if I didn't use the pytree hack, JAX wouldn't let me:

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def op(x, y, op):
    if op == "add": z = x + y
    if op == "mul": z = x * y
    return z

def op_fwd(x, y, op):
    cache = (x, y, op)
    if op == "add": z = x + y
    if op == "mul": z = x * y
    return z, cache

def op_bwd(cache, dz):
    x, y, op = cache
    if op == "add": dz_dx = dz_dy = 1.0
    if op == "mul": dz_dx, dz_dy = y, x
    return (dz_dx, dz_dy, None)

op.defvjp(op_fwd, op_bwd)
grad_op = jax.grad(op)
grad_op(1.0, 1.0, "add")

results in:

Traceback (most recent call last):
  File "/Users/wangy1/Documents/GitHub/dgl/tests/jax/test_str.py", line 25, in <module>
    grad_op(1.0, 1.0, "add")
TypeError: Value 'add' with type <class 'str'> is not a valid JAX type

The issue is that you cannot include static values in the cache; they will be passed to the function via other means:

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def op(x, y, op):
    if op == "add": z = x + y
    if op == "mul": z = x * y
    return z

def op_fwd(x, y, op):
    cache = (x, y)
    if op == "add": z = x + y
    if op == "mul": z = x * y
    return z, cache

def op_bwd(op, cache, dz):
    x, y = cache
    if op == "add": dz_dx = dz_dy = dz
    if op == "mul": dz_dx, dz_dy = dz * y, dz * x
    return (dz_dx, dz_dy)

op.defvjp(op_fwd, op_bwd)
grad_op = jax.grad(op)
grad_op(1.0, 1.0, "add")
# DeviceArray(1., dtype=float32, weak_type=True)

Notice that I also changed the backward pass to make use of the cotangent dz it receives.
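
To see why propagating dz matters, here's a quick check (a sketch reusing the op defined just above) in which op is composed with another function, so the incoming cotangent is not 1.0:

import jax

# Scale op's output by 3, so op_bwd receives dz == 3.0 rather than 1.0.
f = lambda x, y: 3.0 * op(x, y, "mul")
print(jax.grad(f)(2.0, 5.0))
# 15.0 == 3.0 * y; a backward pass that ignored dz would wrongly give 5.0.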

OK I see! Thanks so much!