triton-lang / triton

Development repository for the Triton language and compiler

Home Page: https://triton-lang.org/

Inconsistent Behavior Between Triton and PyTorch When Modifying Input Tensor In-Place for Gradient Storage

ByronHsu opened this issue

Question

Hi Triton team, I am implementing a Triton kernel (version 2.3.0, on an A100) that uses the trick of modifying the input tensor in place to store the gradient, which saves memory. However, I found some unexpected behaviors:

  1. With the Triton kernel backward, the version of the input tensor stays at 0 no matter how many times I run backward.
  2. With the PyTorch native backward, the version of the input tensor increments by 1 after the first backward, and the next backward throws the in-place error.
  3. With the Triton kernel backward plus two views before the forward, the version of the input tensor stays at 0 after the first backward, increments by 1 at the second backward (?!), and the third backward throws the in-place error.

In general, I would like to know how to safely modify an input tensor in place to save memory in backward, or whether doing so is discouraged.

Reproduce

import torch
import triton
import triton.language as tl

# y = x**2
# dx = dy * 2x

def full(func):
    B, T, C = 2, 4, 8
    X = torch.randn(B*T, C, requires_grad=True, device='cuda')
    Y = func(X)
    # Intentionally run backward three times
    Y.backward(torch.ones_like(Y), retain_graph=True)
    print("1st backward input version", X._version)
    Y.backward(torch.ones_like(Y), retain_graph=True)
    print("2nd backward input version", X._version)
    Y.backward(torch.ones_like(Y), retain_graph=True)
    print("3rd backward input version", X._version)

def full_with_view(func):
    B, T, C = 2, 4, 8
    X = torch.randn(B*T, C, requires_grad=True, device='cuda')
    # Same as full, but with two redundant view ops
    Y = func(X.view(B, T, C).view(B*T, C))
    # Intentionally run backward three times
    Y.backward(torch.ones_like(Y), retain_graph=True)
    print("1st backward input version", X._version)
    Y.backward(torch.ones_like(Y), retain_graph=True)
    print("2nd backward input version", X._version)
    # It fails at the 3rd backward
    Y.backward(torch.ones_like(Y), retain_graph=True)
    print("3rd backward input version", X._version)

class SquareFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
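        grad_output *= 2
        # A regular PyTorch in-place op: bumps the saved tensor's version counter
        input *= grad_output
        return input  # one gradient per forward input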

@triton.jit
def element_mul(
    X_ptr,
    X_stride,
    grad_output_ptr,
    grad_output_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
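    # One program per row: X[row, :] <- X[row, :] * grad_output[row, :], in place
    program_id = tl.program_id(axis=0)
    X_ptr += program_id * X_stride
    grad_output_ptr += program_id * grad_output_stride
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols  # BLOCK_SIZE may exceed the row length
    X_row = tl.load(X_ptr + offsets, mask=mask)
    grad_output = tl.load(grad_output_ptr + offsets, mask=mask)
    # Overwrite the input row with the gradient; autograd does not see this write
    tl.store(X_ptr + offsets, X_row * grad_output, mask=mask)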

class TritonSquareFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        grad_output *= 2
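        # Triton kernel: overwrite `input` row by row with input * grad_output
        # (this raw-pointer write does not bump input's version counter)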
        BLOCK_SIZE = 1024
        n_cols = input.size(1)
        n_row = input.size(0)
        element_mul[(n_row, )](
            input,
            input.stride(-2),
            grad_output,
            grad_output.stride(-2),
            n_cols,
            BLOCK_SIZE
        )
        return input  # one gradient per forward input

full(TritonSquareFunction.apply)
print("TritonSquareFunction done")

try:
    full_with_view(TritonSquareFunction.apply)
    print("TritonSquareFunction done")
except RuntimeError as e:
    print(e)

try:
    full(SquareFunction.apply)
    print("SquareFunction done")
except RuntimeError as e:
    print(e)

Output:

1st backward input version 0
2nd backward input version 0
3rd backward input version 0
TritonSquareFunction done
1st backward input version 0
2nd backward input version 1
one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 8]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
1st backward input version 1
one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 8]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Answer from Claude 3 Sonnet (much better than ChatGPT and Gemini in this case):

This behavior is related to how PyTorch handles in-place operations and how it tracks the version of tensors for autograd. In PyTorch, when a tensor is modified in-place, its version is incremented to keep track of the changes and ensure correct gradient computation.
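A minimal CPU sketch of that bookkeeping, using the private `Tensor._version` attribute that the repro also prints. Writing through `tensor.data` is used here as an assumed stand-in for a Triton kernel: both change the tensor's memory without going through a PyTorch in-place op, so the counter that autograd checks is never bumped.

```python
import torch

x = torch.ones(3)
print(x._version)   # 0: freshly created

x.mul_(2)           # a PyTorch in-place op bumps the counter
print(x._version)   # 1

v = x.view(3)       # creating a view does not bump the counter...
print(x._version)   # 1
v.add_(1)           # ...but the view shares it, so in-place ops on v count too
print(x._version)   # 2

# Assumed stand-in for a Triton kernel: `x.data` aliases the same memory but
# has its own version counter, so this write is invisible to autograd's check.
x.data.mul_(10)
print(x._version)   # still 2
print(x)            # tensor([30., 30., 30.])
```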

In your case, when you use the TritonSquareFunction without the redundant views, the input tensor's version stays at 0 because the Triton kernel is directly modifying the input tensor in-place without creating any new views or copies. PyTorch allows this as long as the tensor is not used in any other computation that requires its original value.

However, when you introduce the redundant views (X.view(B, T, C).view(B*T, C)), PyTorch creates new views of the input tensor, and the version of the input tensor is incremented. During the backward pass, when you try to modify the input tensor in-place using the Triton kernel, PyTorch detects that the tensor version has changed, and it throws an error because it expects the tensor version to be the same as when it was saved for the backward pass.
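Both Triton cases can be replayed on CPU with a hypothetical stand-in for `TritonSquareFunction` whose backward writes through `input.data` instead of launching a kernel (`EmulatedTritonSquare` and `run` are illustrative names, not from the issue). Two points are worth checking against the explanation above. Without views, the version really does stay at 0, but autograd is not permitting the write; it simply never sees it, and `X` is silently overwritten with a new "gradient" on every backward. With views, creating a view does not bump the version counter (views share their base's counter, and only in-place ops increment it). One explanation consistent with the reported output, offered as a hypothesis about autograd internals: the gradient reaching `X` is then a fresh view of `X`'s own memory, so autograd may adopt it as `X.grad` without copying. `X.grad` then shares `X`'s version counter, and the in-place accumulation into `X.grad` on the second backward is what moves `X._version` to 1.

```python
import torch


class EmulatedTritonSquare(torch.autograd.Function):
    """Hypothetical CPU stand-in for the repro's TritonSquareFunction.

    Writing through `input.data` plays the role of the Triton kernel: the
    saved input is overwritten without bumping its version counter.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        input.data.mul_(grad_output * 2)  # "kernel": input <- 2 * x * dy
        return input


def run(use_views, B=2, T=4, C=8):
    """Run backward up to three times, mirroring full() / full_with_view()."""
    X = torch.randn(B * T, C, requires_grad=True)
    x0 = X.detach().clone()
    inp = X.view(B, T, C).view(B * T, C) if use_views else X
    Y = EmulatedTritonSquare.apply(inp)
    versions, error, aliased = [], None, None
    for i in range(3):
        try:
            Y.backward(torch.ones_like(Y), retain_graph=True)
        except RuntimeError as e:
            error = e
            break
        versions.append(X._version)
        if i == 0:
            # Does the accumulated gradient share memory with X itself?
            aliased = X.grad.data_ptr() == X.data_ptr()
    return X, x0, versions, error, aliased


X, x0, versions, error, aliased = run(use_views=False)
print(versions, error)                     # the issue reports versions 0, 0, 0
print(torch.allclose(X.detach(), 8 * x0))  # X itself was doubled on each backward

X, x0, versions, error, aliased = run(use_views=True)
print(versions, aliased)                   # the issue reports versions 0, 1 ...
print(repr(error))                         # ... then an error on the 3rd backward
```

If `aliased` comes back `True` for the view case, `X.grad` and `X` are the same memory, which would be a problem of its own for any optimizer step that reads both.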

The PyTorch native backward (SquareFunction) increments the input tensor's version after each backward pass because it creates a new tensor for the gradient and performs an in-place multiplication (input *= grad_output).
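A CPU sketch of that native path (the class name `InPlaceSquare` is only for illustration). `input *= ...` goes through PyTorch, so the counter shared by the saved tensor and `X` moves from 0 to 1. The check that fires is the one autograd performs when `ctx.saved_tensors` is unpacked on the next backward, comparing the current version with the one recorded at `save_for_backward` time.

```python
import torch


class InPlaceSquare(torch.autograd.Function):
    """CPU copy of the repro's SquareFunction: backward overwrites its saved input."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        input *= grad_output * 2  # PyTorch in-place op: bumps the version counter
        return input


X = torch.randn(8, 8, requires_grad=True)
Y = InPlaceSquare.apply(X)

Y.backward(torch.ones_like(Y), retain_graph=True)
version_after_first = X._version  # the issue reports 1 here

second_failed = False
try:
    # Unpacking ctx.saved_tensors compares the version recorded at
    # save_for_backward (0) with the current one (1) and raises.
    Y.backward(torch.ones_like(Y), retain_graph=True)
except RuntimeError as e:
    second_failed = True
    print(e)
```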

To safely modify the input tensor in-place during the backward pass, you should avoid creating any new views or copies of the input tensor before passing it to the forward function. If you need to reshape the input tensor, you should do it inside the forward function and save the reshaped tensor for the backward pass.

Additionally, you can use PyTorch's in-place operations carefully, or use PyTorch's functional APIs that perform out-of-place operations, to avoid modifying the input tensor directly.

In general, modifying input tensors in-place during the backward pass is not recommended because it can lead to unexpected behaviors and potential errors, especially when dealing with complex computations and control flows. It's safer to create new tensors for the gradients and perform out-of-place operations to avoid modifying the input tensors.
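A minimal out-of-place baseline for the same function, in the spirit of this recommendation (`OutOfPlaceSquare` is an illustrative name). Backward returns a new tensor and leaves the saved input alone, so repeated backward passes with `retain_graph=True` neither raise nor corrupt anything. It gives up the memory saving the issue is after, so treat it as the reference the in-place variants can be checked against rather than as a drop-in replacement.

```python
import torch


class OutOfPlaceSquare(torch.autograd.Function):
    """y = x**2 with an out-of-place backward: dx = dy * 2x."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Allocates a fresh gradient tensor; the saved input is never written.
        return grad_output * 2 * input


X = torch.randn(8, 8, requires_grad=True)
Y = OutOfPlaceSquare.apply(X)
for _ in range(3):
    Y.backward(torch.ones_like(Y), retain_graph=True)

print(X._version)                              # 0
print(torch.allclose(X.grad, 6 * X.detach()))  # True: three accumulations of 2x
```

`torch.autograd.gradcheck(OutOfPlaceSquare.apply, (torch.randn(4, 3, dtype=torch.double, requires_grad=True),))` should return `True` for this version, which makes it a convenient reference when testing a memory-saving variant.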

  • full(TritonSquareFunction.apply)
1st backward input version 0
2nd backward input version 0
3rd backward input version 0
  • full_with_view(TritonSquareFunction.apply)
1st backward input version 0
2nd backward input version 1
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 8]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True)
  • full(SquareFunction.apply)
1st backward input version 1
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 8]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
  • full_with_view(SquareFunction.apply)
1st backward input version 1
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 8]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).