google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Nested custom vjps do not correctly propagate cotangents

lengstrom opened this issue

Description

The use case is nesting two custom_vjp operators to perform reverse-mode AD over reverse-mode AD, in order to calculate a Hessian-vector product (HVP) of a function that uses custom_vjp to call into external, untraceable code. I have not used custom_vjp before, and the documentation didn't seem to cover anything related to this problem, so please forgive me if I'm missing something obvious here!
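For reference, a reverse-over-reverse HVP on an ordinary, fully traceable function can be sketched as follows. The names `hvp` and `f` are illustrative only (they are not JAX APIs), and the comparison against `jax.hessian` is just a sanity check that is practical for small inputs.

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Reverse-over-reverse: the gradient of <grad f(y), v> with respect to y is H v.
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)

def f(x):
    # Scalar test function; its Hessian is diag(6 * x).
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 1.0, 1.0])

print(hvp(f, x, v))            # should equal 6 * x * v
print(jax.hessian(f)(x) @ v)   # dense-Hessian cross-check
```

The nested `jax.grad` calls in the report below follow the same structure, with `ww2` playing the role of `v`.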

When nesting two custom_vjp operations, the cotangents computed by the second custom_vjp do not seem to propagate properly through the compute graph. A minimal example using the function y = x**3 is below: the computed second-order gradient is all zeros, which seems to imply that the ddy computed in the second custom_vjp (in op_bck_bck) is never propagated further.

import jax

@jax.custom_vjp
def op(x):
    return x**3

def op_fwd(x):
    # return the output and the saved values
    return op(x), (x,)

@jax.custom_vjp
def op_bck(saved, dy):
    x, = saved
    return 3 * x**2 * dy, # gradient of y=x^3

def op_bck_fwd(saved, dy):
    return op_bck(saved, dy), saved

def op_bck_bck(saved, ddx):
    x, = saved
    ddx, = ddx
    ddy = 3 * x**2 * ddx
    # return:
    # - dSaved (None as we don't want to diff through saved values)
    # - ddy (use chain rule)
    return (None,), ddy

op_bck.defvjp(op_bck_fwd, op_bck_bck)
op.defvjp(op_fwd, op_bck)

# test it
def jop(x):
    return x**3

xx = jax.random.normal(jax.random.PRNGKey(0), (5,))
ww = jax.random.normal(jax.random.PRNGKey(1), (5,))
ww2 = jax.random.normal(jax.random.PRNGKey(2), (5,))

def fn(op, x):
    def l(x):
        # first vjp
        return jax.grad(lambda z: op(z) @ ww)(x) @ ww2

    # should call second vjp
    return jax.grad(l)(x)

gt_jac = fn(jop, xx)
my_jac = fn(op, xx)

print(gt_jac) # [-0.16732852 11.739656   -0.06793058  0.4734605  -0.15350063]
print(my_jac) # [0. 0. 0. 0. 0.]
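Because the test function is just x**3, the expected result can also be written in closed form: the inner gradient is 3·x²·ww, so l(x) = Σ 3·x²·ww·ww2 and its gradient is 6·x·ww·ww2. The sketch below reuses the report's keys and shapes to check the plain `jop` reference against that formula; the name `closed_form` and the `jnp.allclose` tolerances are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def jop(x):
    return x ** 3

xx = jax.random.normal(jax.random.PRNGKey(0), (5,))
ww = jax.random.normal(jax.random.PRNGKey(1), (5,))
ww2 = jax.random.normal(jax.random.PRNGKey(2), (5,))

def fn(op, x):
    def l(x):
        # Inner reverse pass: gradient of op(z) @ ww, contracted with ww2.
        return jax.grad(lambda z: op(z) @ ww)(x) @ ww2
    # Outer reverse pass over the inner gradient.
    return jax.grad(l)(x)

# d/dx of sum(3 * x**2 * ww * ww2)
closed_form = 6 * xx * ww * ww2
print(jnp.allclose(fn(jop, xx), closed_form, rtol=1e-5, atol=1e-6))  # True
```

A correct nested custom_vjp for `op` should therefore produce 6·x·ww·ww2, not zeros.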

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0]

Here is an edit to your nested custom_vjp snippet that makes the second-order derivative values match the reference values (the test code stays the same):

import jax

@jax.custom_vjp
def op(x):
  return x ** 3

def op_fwd(x):
  return op(x), (x,)

@jax.custom_vjp
def op_bck(xs, dy):
  x, = xs
  return 3 * x ** 2 * dy,

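def op_bck_fwd(xs, dy):
  # Save dy alongside xs: it is needed for the VJP with respect to x.
  return op_bck(xs, dy), (xs, dy)

def op_bck_bck(xs_dy, dzs):
  xs, dy = xs_dy
  x, = xs
  dz, = dzs                  # cotangent for op_bck's 1-tuple output
  dx = 6 * x * dy * dz       # d(3 x^2 dy)/dx, contracted with dz
  ddy = 3 * x ** 2 * dz      # d(3 x^2 dy)/d(dy), contracted with dz
  return (dx,), ddy          # cotangents for (xs, dy)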

op_bck.defvjp(op_bck_fwd, op_bck_bck)
op.defvjp(op_fwd, op_bck)
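To check the edit end to end, the corrected definitions can be combined with the test harness from the original report. The sketch below repeats both so that it runs standalone; the `jnp.allclose` tolerances are an assumption, chosen loosely to absorb float32 rounding differences between the two differentiation paths.

```python
import jax
import jax.numpy as jnp

# Corrected nested custom_vjp definitions, as in the edit above.
@jax.custom_vjp
def op(x):
    return x ** 3

def op_fwd(x):
    return op(x), (x,)

@jax.custom_vjp
def op_bck(xs, dy):
    x, = xs
    return 3 * x ** 2 * dy,

def op_bck_fwd(xs, dy):
    # Keep dy as a residual: the VJP with respect to x needs it.
    return op_bck(xs, dy), (xs, dy)

def op_bck_bck(xs_dy, dzs):
    xs, dy = xs_dy
    x, = xs
    dz, = dzs
    dx = 6 * x * dy * dz      # cotangent for the saved x
    ddy = 3 * x ** 2 * dz     # cotangent for dy
    return (dx,), ddy

op_bck.defvjp(op_bck_fwd, op_bck_bck)
op.defvjp(op_fwd, op_bck)

# Test harness from the original report.
def jop(x):
    return x ** 3

xx = jax.random.normal(jax.random.PRNGKey(0), (5,))
ww = jax.random.normal(jax.random.PRNGKey(1), (5,))
ww2 = jax.random.normal(jax.random.PRNGKey(2), (5,))

def fn(op, x):
    def l(x):
        return jax.grad(lambda z: op(z) @ ww)(x) @ ww2
    return jax.grad(l)(x)

gt_jac = fn(jop, xx)
my_jac = fn(op, xx)
print(jnp.allclose(gt_jac, my_jac, rtol=1e-5, atol=1e-6))  # True
```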

A couple of observations behind this edit:

  • Returning None from your second backward function doesn't seem quite right. Even though that argument holds a saved residual, we still need to differentiate with respect to it, because the saved x is the very input the outer grad differentiates with respect to. The VJP with respect to this argument is the expression dx = 6 * ... in my op_bck_bck.
  • In op_bck's custom derivative, it's useful to save the input value dy as a residual, since we need it to compute the VJP expression from the previous bullet (dx = 6 * ...).
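Both expressions in the corrected op_bck_bck are the partial derivatives of b(x, dy) = 3·x²·dy, each contracted with the incoming cotangent dz: ∂b/∂x = 6·x·dy and ∂b/∂dy = 3·x². Since this toy b is traceable, `jax.vjp` can confirm the hand-written rules. The function name `b` and the test values are illustrative; in the real use case the external code would not be traceable, which is exactly why the rules have to be written by hand.

```python
import jax
import jax.numpy as jnp

def b(x, dy):
    # Traceable version of op_bck's math: the VJP of y = x**3.
    return 3 * x ** 2 * dy

x = jnp.array([1.0, -2.0, 0.5])
dy = jnp.array([0.3, 1.0, -1.5])
dz = jnp.array([2.0, 0.1, 1.0])

# Let JAX differentiate b and pull dz back to both inputs.
_, b_vjp = jax.vjp(b, x, dy)
dx_auto, ddy_auto = b_vjp(dz)

# Hand-written expressions from op_bck_bck.
dx_hand = 6 * x * dy * dz
ddy_hand = 3 * x ** 2 * dz

print(jnp.allclose(dx_auto, dx_hand), jnp.allclose(ddy_auto, ddy_hand))  # True True
```

Returning None for the residual amounts to replacing dx_hand with zeros, which is why the original snippet reported an all-zero result.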

What do you think? Take a look and see whether this looks correct to you. If something remains confusing, we can think about how to improve the docs, or consider adding an example somewhere (maybe in a future version of the AD cookbook).
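As a sketch of how the same two-level pattern could apply to the original use case, the example below routes every computation through `jax.pure_callback`, which stands in for the external, untraceable code. The NumPy functions `ext_forward`, `ext_backward` and `ext_backward_vjp` and the helper `_like` are hypothetical placeholders, not part of the thread. Because `pure_callback` has no derivative rule of its own, every derivative needed here has to come from the custom_vjp rules, which is why `op_bck_bck` also calls out to external code.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical external routines (plain NumPy, invisible to JAX's tracer).
def ext_forward(x):
    return np.asarray(x) ** 3

def ext_backward(x, dy):
    return 3 * np.asarray(x) ** 2 * np.asarray(dy)

def ext_backward_vjp(x, dy, dz):
    x, dy, dz = np.asarray(x), np.asarray(dy), np.asarray(dz)
    return 6 * x * dy * dz, 3 * x ** 2 * dz

def _like(a):
    # Result shape/dtype declaration required by pure_callback.
    return jax.ShapeDtypeStruct(a.shape, a.dtype)

@jax.custom_vjp
def op(x):
    return jax.pure_callback(ext_forward, _like(x), x)

def op_fwd(x):
    return op(x), (x,)

@jax.custom_vjp
def op_bck(xs, dy):
    x, = xs
    return jax.pure_callback(ext_backward, _like(x), x, dy),

def op_bck_fwd(xs, dy):
    return op_bck(xs, dy), (xs, dy)

def op_bck_bck(xs_dy, dzs):
    xs, dy = xs_dy
    x, = xs
    dz, = dzs
    dx, ddy = jax.pure_callback(
        ext_backward_vjp, (_like(x), _like(dy)), x, dy, dz)
    return (dx,), ddy

op_bck.defvjp(op_bck_fwd, op_bck_bck)
op.defvjp(op_fwd, op_bck)

xx = jax.random.normal(jax.random.PRNGKey(0), (5,))
ww = jax.random.normal(jax.random.PRNGKey(1), (5,))
ww2 = jax.random.normal(jax.random.PRNGKey(2), (5,))

def l(x):
    return jax.grad(lambda z: op(z) @ ww)(x) @ ww2

hvp_ext = jax.grad(l)(xx)
print(jnp.allclose(hvp_ext, 6 * xx * ww * ww2, rtol=1e-5, atol=1e-6))
```

Only second-order derivatives are covered by this sketch; a third level of differentiation would need `op_bck_bck` itself to be a custom_vjp.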

This makes much more sense now, thank you for the help! I really appreciate the worked edit.