Nested custom vjps do not correctly propagate cotangents
lengstrom opened this issue
Description
The use case is nesting two `custom_vjp` operators to perform reverse-mode AD over reverse-mode AD (to calculate an HVP over a function that uses `custom_vjp` to call into external, untraceable code). I have not used `custom_vjp` before and the documentation didn't seem to cover anything related to this problem, so please forgive me if I'm missing something obvious here!

When nesting two `custom_vjp` operations, the cotangents computed by the second `custom_vjp` do not seem to be propagated properly through the compute graph. A minimal example is below, using the function `y = x**3`; the computed gradient is 0, which implies that the `ddy` computed in the second `custom_vjp` (in `op_bck_bck`) is never propagated further.
```python
import jax

@jax.custom_vjp
def op(x):
    return x**3

def op_fwd(x):
    # return the output and the saved values
    return op(x), (x,)

@jax.custom_vjp
def op_bck(saved, dy):
    x, = saved
    return 3 * x**2 * dy,  # gradient of y=x^3

def op_bck_fwd(saved, dy):
    return op_bck(saved, dy), saved

def op_bck_bck(saved, ddx):
    x, = saved
    ddx, = ddx
    ddy = 3 * x**2 * ddx
    # return:
    # - dSaved (None as we don't want to diff through saved values)
    # - ddy (use chain rule)
    return (None,), ddy

op_bck.defvjp(op_bck_fwd, op_bck_bck)
op.defvjp(op_fwd, op_bck)

# test it
def jop(x):
    return x**3

xx = jax.random.normal(jax.random.PRNGKey(0), (5,))
ww = jax.random.normal(jax.random.PRNGKey(1), (5,))
ww2 = jax.random.normal(jax.random.PRNGKey(2), (5,))

def fn(op, x):
    def l(x):
        # first vjp
        return jax.grad(lambda z: op(z) @ ww)(x) @ ww2
    # should call second vjp
    return jax.grad(l)(x)

gt_jac = fn(jop, xx)
my_jac = fn(op, xx)
print(gt_jac)  # [-0.16732852 11.739656 -0.06793058 0.4734605 -0.15350063]
print(my_jac)  # [0. 0. 0. 0. 0.]
```
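As an aside (my own addition, not part of the original report): the reference value `gt_jac` is a Hessian-vector-type quantity, and the same number can be obtained with forward-over-reverse AD via `jax.jvp` over `jax.grad`, which avoids nesting two reverse passes entirely. A minimal sketch, reusing the same keys and weights as above:

```python
import jax
import jax.numpy as jnp

xx = jax.random.normal(jax.random.PRNGKey(0), (5,))
ww = jax.random.normal(jax.random.PRNGKey(1), (5,))
ww2 = jax.random.normal(jax.random.PRNGKey(2), (5,))

def f(x):
    return (x ** 3) @ ww  # scalar loss

grad_f = jax.grad(f)

# reverse-over-reverse, as in fn() above:
rr = jax.grad(lambda x: grad_f(x) @ ww2)(xx)

# forward-over-reverse: JVP of grad_f along direction ww2.
# Because the Hessian is symmetric, this equals H @ ww2 as well.
_, fr = jax.jvp(grad_f, (xx,), (ww2,))

assert jnp.allclose(rr, fr)
```

This only sidesteps the bug for the traceable reference; the nested `custom_vjp` path is still needed when the primal calls untraceable code.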
System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0]
```
Here is an edit to your nested `custom_vjp` snippet that causes the second-order derivative values to match the reference values:
```python
import jax

@jax.custom_vjp
def op(x):
    return x ** 3

def op_fwd(x):
    return op(x), (x,)

@jax.custom_vjp
def op_bck(xs, dy):
    x, = xs
    return 3 * x ** 2 * dy,

def op_bck_fwd(xs, dy):
    return op_bck(xs, dy), (xs, dy)

def op_bck_bck(xs_dy, dzs):
    xs, dy = xs_dy
    x, = xs
    dz, = dzs
    dx = 6 * x * dy * dz
    ddy = 3 * x ** 2 * dz
    return (dx,), ddy

op_bck.defvjp(op_bck_fwd, op_bck_bck)
op.defvjp(op_fwd, op_bck)
```
A couple of observations behind this edit:

- Returning `None` from the second backward function you had doesn't seem quite right. Even though this argument was invoked with a saved residual, we still want to differentiate with respect to it. You can see the VJP with respect to this argument as the expression `dx = 6 * ...` in my `op_bck_bck`.
- In `op_bck`'s custom derivative, it's useful to save the input value `dy` as a residual, since we need it to compute the VJP expression from the previous bullet (`dx = 6 * ...`).
What do you think? Take a look and see if this looks correct to you. If something remains confusing, we can think about how to improve the docs, or consider adding an example somewhere (maybe in a future version of the AD cookbook).
This makes much more sense now, thank you for the help! I really appreciate the worked edit.