support for Jax-like custom forward pass definition?
tylerflex opened this issue
Is there a way to define a custom forward pass, like in JAX, where one can output a residual that may be used by the backward pass?
For instance, is the following snippet (from the JAX docs) implementable in autograd?
import jax.numpy as jnp
from jax import custom_vjp

@custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    # Returns primal output and residuals to be used in backward pass by f_bwd.
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
    cos_x, sin_x, y = res  # Gets residuals computed in f_fwd
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
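For comparison, autograd's own extend module appears to provide a similar hook through primitive and defvjp: you register one VJP maker per positional argument, and anything computed in the maker's body and captured by the returned closure plays the role of JAX's residuals. A rough sketch along those lines, assuming the primitive/defvjp API shown in autograd's tutorial:

import autograd.numpy as anp
from autograd import grad
from autograd.extend import primitive, defvjp

@primitive
def f(x, y):
    return anp.sin(x) * y

def f_vjp_x(ans, x, y):
    # Residual-like values can be computed here and captured by the closure.
    cos_x = anp.cos(x)
    return lambda g: g * cos_x * y

def f_vjp_y(ans, x, y):
    sin_x = anp.sin(x)
    return lambda g: g * sin_x

# One VJP maker per positional argument of f.
defvjp(f, f_vjp_x, f_vjp_y)

print(grad(f, 0)(1.0, 2.0))  # ~ 2 * cos(1) ≈ 1.0806
print(grad(f, 1)(1.0, 2.0))  # ~ sin(1) ≈ 0.8415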
In PyTorch, you can define a custom forward pass by subclassing torch.autograd.Function. This lets you specify the forward computation, decide which values to save as residuals, and define how the gradients are computed in the backward pass.
For example, you could implement the JAX f function as follows in PyTorch:
import torch

class f(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        # Save the inputs as residuals for the backward pass.
        ctx.save_for_backward(x, y)
        return torch.sin(x) * y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors  # Residuals saved in forward
        cos_x = torch.cos(x)
        sin_x = torch.sin(x)
        grad_x = grad_output * cos_x * y
        grad_y = grad_output * sin_x
        return grad_x, grad_y

x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
output = f.apply(x, y)
output.backward()
print(x.grad)  # tensor(1.0806)
print(y.grad)  # tensor(0.8415)
Here, ctx.save_for_backward saves x and y for use in the backward pass; backward then computes the gradients with respect to x and y from those saved values, the incoming grad_output, and the chain rule; and apply invokes the custom function on the inputs x and y. Note that in this version the residuals are simply the raw inputs; you can just as well precompute quantities in forward and save those instead, which matches the JAX residual pattern more closely.
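A minimal sketch of that variant, with the residuals computed once in forward rather than recomputed in backward (the class name is just illustrative):

import torch

class FWithResiduals(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        # Compute the residuals up front, as f_fwd does in the JAX example.
        cos_x, sin_x = torch.cos(x), torch.sin(x)
        ctx.save_for_backward(cos_x, sin_x, y)
        return sin_x * y

    @staticmethod
    def backward(ctx, grad_output):
        cos_x, sin_x, y = ctx.saved_tensors  # Residuals saved in forward
        return grad_output * cos_x * y, grad_output * sin_x

x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
FWithResiduals.apply(x, y).backward()
print(x.grad, y.grad)  # tensor(1.0806) tensor(0.8415)

This trades a little extra saved memory for not recomputing cos(x) and sin(x) in backward, which is the same trade-off the f_fwd/f_bwd split expresses in JAX.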