NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.

Home Page: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html

te.Checkpoint does not work for nested autocast

tohinz opened this issue

According to #438, we should be able to nest the BF16 and FP8 autocasts.

In our specific setting, the module consists of some torch.nn.Linear layers and some te.Linear layers (because some input sizes are not compatible with FP8 and padding is not an option in this case). When we wrap this module with te.checkpoint (following the fix in #776), we get errors in the backward pass because the BF16 autocast is not applied when the function is recomputed.

Concretely, for something like the following:

    criterion = torch.nn.MSELoss()
    model = torch.nn.Sequential()
    model.append(torch.nn.Linear(hidden_dim, hidden_dim))
    for _ in range(num_layers):
        model.append(te.Linear(hidden_dim, hidden_dim))

    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        with te.fp8_autocast(enabled=True):
            output = te.checkpoint(model, model_input, use_reentrant=False)

    loss = criterion(output, target)

    loss.backward()

we get the error:

    RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != float

since the torch.nn.Linear layer is not autocast to BF16 when the function is recomputed by te.checkpoint.
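
As a standalone illustration of where that error comes from (independent of TE and checkpointing, and using made-up tensor sizes), applying a float32 torch.nn.Linear to a BF16 input outside of any autocast raises the same RuntimeError:

    import torch

    lin = torch.nn.Linear(4, 4).cuda()  # float32 weights, no autocast active
    x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
    lin(x)  # RuntimeError: expected mat1 and mat2 to have the same dtype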

Looking at the PyTorch implementation, torch.utils.checkpoint records the active autocast state during the forward pass and re-enters it when the function is recomputed in the backward pass; the salient pieces are the autocast state capture and the re-entry of that context around the recomputation in torch/utils/checkpoint.py.

I've added the equivalent handling to my local TE branch and it seems to fix the issue, i.e., the code with the two nested autocasts now runs through and the gradient check returns True.
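
For reference, here is a minimal sketch of that idea (not the actual PyTorch or TE code; capture_autocast_state is a hypothetical helper name): record the current CUDA autocast state when the checkpointed forward runs and build an equivalent context to re-enter during recomputation.

    import contextlib

    import torch

    def capture_autocast_state():
        # Record whether CUDA autocast is currently active and which dtype it uses.
        enabled = torch.is_autocast_enabled()
        dtype = torch.get_autocast_gpu_dtype()
        # The original forward already runs inside the caller's autocast, so a
        # no-op context suffices there; the recompute context re-creates the same
        # autocast settings for the recomputation in the backward pass.
        fwd_ctx = contextlib.nullcontext()
        recomp_ctx = torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=dtype)
        return fwd_ctx, recomp_ctx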

Example code to reproduce the error (adapted from #438):

import torch
from torch.autograd import grad

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling


if __name__ == "__main__":
    num_layers = 5
    seq_length = 1024
    hidden_dim = 2048

    # Generate random model input and target for MSE loss
    model_input = torch.rand(8, seq_length, hidden_dim).cuda().to(dtype=torch.float)
    target = torch.rand(8, seq_length, hidden_dim).cuda().to(dtype=torch.bfloat16)
    criterion = torch.nn.MSELoss()

    # Define the model
    model = torch.nn.Sequential()
    model.append(torch.nn.Linear(hidden_dim, hidden_dim))
    for _ in range(num_layers):
        model.append(te.Linear(hidden_dim, hidden_dim))
    model.to(dtype=torch.float32).cuda()

    # Define FP8
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(
        fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
    )
    autocast_args = {"enabled": True, "fp8_recipe": fp8_recipe}
    autocast = te.fp8_autocast

    def inner(compare_grads):
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            with autocast(**autocast_args):
                output = te.checkpoint(model, model_input, use_reentrant=False)
                output2 = model(model_input)

        loss = criterion(output, target)
        loss2 = criterion(output2, target)

        if compare_grads:
            # compute gradients
            grads = grad(loss, model.parameters())
            grads2 = grad(loss2, model.parameters())

            # compare gradients
            print("Gradients are equal: ")
            print(torch.all(torch.eq(grads[0], grads2[0])))
            print(torch.all(torch.eq(grads[1], grads2[1])))

            # print gradients to check they are nonzero and not nan
            print("")
            print(grads[0])
            print(grads2[0])
            print("")
            print(grads[1])
            print(grads2[1])
        else:
            loss.backward()
            loss2.backward()

    # run model
    fp8_scaling_iters = 50

    # warmup iterations to get FP8 scaling parameters
    for _ in range(fp8_scaling_iters):
        inner(compare_grads=False)

    inner(compare_grads=True)

@tohinz I will take a look at how we can automatically handle this in the TE checkpoint tomorrow. In the meantime, you should be able to make this work via user context functions.

def torch_autocast_ctx():
    fwd_ctx = torch.amp.autocast(...)
    recomp_ctx = torch.amp.autocast(...)
    return fwd_ctx, recomp_ctx

te.distributed.checkpoint(..., context_fn=torch_autocast_ctx, ...)

The two autocast contexts would need to be configured consistently with how you want the forward pass and the recomputation to be done.
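
For example, a filled-in version that matches the BF16 setup in the reproduction above could look like the following (the device_type and dtype values are assumptions taken from that script, and this assumes te.checkpoint forwards the context_fn keyword the same way as te.distributed.checkpoint):

    import torch
    import transformer_engine.pytorch as te

    def torch_autocast_ctx():
        # Use identical settings for the forward pass and the recomputation so the
        # torch.nn.Linear layer runs in BF16 in both.
        fwd_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
        recomp_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
        return fwd_ctx, recomp_ctx

    output = te.checkpoint(model, model_input, use_reentrant=False,
                           context_fn=torch_autocast_ctx)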

Hi @tohinz -- could you confirm if PR #791 resolves your issue? Thanks!

Hi @denera , I can verify that PR #791 fixed it for me. We can close this issue once it's merged to main.

@tohinz -- we've merged PR #791 today so I'm closing the issue. Thank you for reporting it!