HIPS / autograd

Efficiently computes derivatives of numpy code.


support for Jax-like custom forward pass definition?

tylerflex opened this issue

Is there a way to define a custom forward pass, as in JAX, where one can output residuals to be used by the backward pass?

For example, is the following code (from the JAX docs) implementable in autograd?

import jax.numpy as jnp
from jax import custom_vjp

@custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  # Returns primal output and residuals to be used in backward pass by f_bwd.
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res  # Gets residuals computed in f_fwd
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

In PyTorch, you can define a custom forward pass by subclassing torch.autograd.Function. This lets you specify both the forward pass and the backward pass of your custom function, including which values the forward pass saves for reuse when computing gradients.

For example, you could implement the JAX f function as follows in PyTorch:

import torch

class f(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        # Save the inputs so the backward pass can use them (the analogue of JAX's residuals).
        ctx.save_for_backward(x, y)
        return torch.sin(x) * y

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the saved inputs; cos(x) and sin(x) are recomputed here.
        x, y = ctx.saved_tensors
        cos_x = torch.cos(x)
        sin_x = torch.sin(x)
        grad_x = grad_output * cos_x * y
        grad_y = grad_output * sin_x
        return grad_x, grad_y

x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)

output = f.apply(x, y)
output.backward()

print(x.grad) # tensor(1.0806)
print(y.grad) # tensor(0.8415)

Here, ctx.save_for_backward saves x and y for use in the backward pass, playing the role of the residuals in the JAX example (though this version recomputes cos(x) and sin(x) in backward rather than saving them). The backward method then computes the gradients with respect to x and y from the saved values and the incoming gradient. Finally, f.apply applies the custom function to the inputs x and y.
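
For reference, autograd itself exposes a comparable extension mechanism: wrap the function with @primitive and register its vector-Jacobian products with defvjp, both from autograd.extend. Each VJP maker is called with the primal output and the original inputs, so anything it computes or closes over at that point serves the same purpose as JAX's residuals. Below is a minimal sketch of the same f along those lines; the helper names (f_vjp_x, f_vjp_y) are just illustrative.

import autograd.numpy as np
from autograd import grad
from autograd.extend import primitive, defvjp

# Wrap f as a primitive so autograd does not trace through its internals.
@primitive
def f(x, y):
    return np.sin(x) * y

# Each VJP maker receives the primal output (ans) and the inputs, and returns
# a function of the incoming cotangent g. Values computed here (cos(x), sin(x))
# are captured in the returned closure, which is the analogue of JAX's residuals.
def f_vjp_x(ans, x, y):
    cos_x = np.cos(x)
    return lambda g: g * cos_x * y

def f_vjp_y(ans, x, y):
    sin_x = np.sin(x)
    return lambda g: g * sin_x

defvjp(f, f_vjp_x, f_vjp_y)

print(grad(f, 0)(1.0, 2.0))  # ~1.0806 (cos(1) * 2)
print(grad(f, 1)(1.0, 2.0))  # ~0.8415 (sin(1))

Unlike JAX's explicit f_fwd/f_bwd pair, there is no separate forward function here: the closures returned by the VJP makers are what carry any residuals through to the backward pass.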