NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.

Home Page: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html

te.checkpoint does not work on nn.Module that consists of te blocks

tohinz opened this issue:

We have a model that consists of several te.Linear layers wrapped in, e.g., a torch.nn.Sequential():

num_layers = 5
seq_length = 1024
hidden_dim = 2048

model = torch.nn.Sequential()
for _ in range(num_layers):
    model.append(te.Linear(hidden_dim, hidden_dim))

Running te.checkpoint on a single te.Linear layer works as expected, but when we try to checkpoint the whole model via

output = te.checkpoint(model, model_input, use_reentrant=False)
loss = MSE(output, target)
loss.backward()

we get the following error:

torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: A different number of tensors was saved during the original forward and recomputation.
Number of tensors saved during forward: 20
Number of tensors saved during recomputation: 10

I believe this is because the _is_te_module check (run here) only recognizes modules that are themselves TE modules; since a wrapper like torch.nn.Sequential is not, te.checkpoint() falls back to the default PyTorch checkpoint function. For many more complex model definitions this will always be the case.

If I remove that _is_te_module check from te.checkpoint() so that it runs the TE checkpoint function instead of the torch one, the code above works. I also verified that the gradients are identical whether I run with or without checkpointing.

I don't know if this is intended behavior, but in many cases we want to apply checkpointing to larger parts of the model rather than to individual layers.

If this is not the intended behavior, it might make sense to update _is_te_module to recursively check all children of the module and see if any of them are TE modules.
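
As a rough illustration (not the library's actual implementation; the helper name and the set of TE classes checked here are assumptions), the recursive variant could look something like this:

import torch
import transformer_engine.pytorch as te

# Hypothetical recursive variant of the check: True if the module itself or
# any of its submodules is a TE module. The tuple of TE classes is
# illustrative, not exhaustive.
_TE_MODULE_TYPES = (te.Linear, te.LayerNormLinear, te.LayerNormMLP, te.TransformerLayer)

def _contains_te_module(module: torch.nn.Module) -> bool:
    # nn.Module.modules() yields the module itself plus all submodules, recursively.
    return any(isinstance(m, _TE_MODULE_TYPES) for m in module.modules())

With such a check, a torch.nn.Sequential of te.Linear layers would be dispatched to the TE checkpoint path instead of falling through to the torch one.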

Attached is code to reproduce the error (as a quick check, commenting out the _is_te_module check mentioned above makes it run and produce correct gradients):

import torch
from torch.autograd import grad

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling


if __name__ == "__main__":
    num_layers = 5
    seq_length = 1024
    hidden_dim = 2048

    # Generate random model input and target for MSE loss
    model_input = torch.rand(8, seq_length, hidden_dim).cuda().to(dtype=torch.float)
    target = torch.rand(8, seq_length, hidden_dim).cuda().to(dtype=torch.float)
    criterion = torch.nn.MSELoss()

    # Define the model
    model = torch.nn.Sequential()
    for _ in range(num_layers):
        model.append(te.Linear(hidden_dim, hidden_dim))
    model.to(dtype=torch.float32).cuda()

    # Define FP8
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(
        fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
    )
    autocast_args = {"enabled": True, "fp8_recipe": fp8_recipe}
    autocast = te.fp8_autocast

    def inner(compare_grads):
        with autocast(**autocast_args):
            output = te.checkpoint(model, model_input, use_reentrant=False)
            output2 = model(model_input)

        loss = criterion(output, target)
        loss2 = criterion(output2, target)

        if compare_grads:
            # compute gradients
            grads = grad(loss, model.parameters())
            grads2 = grad(loss2, model.parameters())

            # compare gradients
            print("Gradients are equal: ")
            print(torch.all(torch.eq(grads[0], grads2[0])))
            print(torch.all(torch.eq(grads[1], grads2[1])))

            # print gradients to check they are nonzero and not nan
            print("")
            print(grads[0])
            print(grads2[0])
            print("")
            print(grads[1])
            print(grads2[1])
        else:
            loss.backward()
            loss2.backward()

    # run model
    fp8_scaling_iters = 50

    # warmup iterations to get FP8 scaling parameters
    for _ in range(fp8_scaling_iters):
        inner(compare_grads=False)

    inner(compare_grads=True)

This makes sense. @denera could you take a look at this issue?

Hi @tohinz -- I'll be filing a PR to change the pass-through behavior shortly, but just to unblock you for the time being, could you try passing context_fn=te.distributed.get_activation_recompute_contexts into the checkpoint function?
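
For reference, applied to the reproducer above, that workaround would look roughly like this (a sketch of the suggested call, not a verified fix):

# Suggested workaround: pass TE's activation-recompute contexts to te.checkpoint
# so that the recomputation runs under the appropriate TE contexts even when the
# call falls through to the PyTorch checkpoint implementation.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = te.checkpoint(
        model,
        model_input,
        use_reentrant=False,
        context_fn=te.distributed.get_activation_recompute_contexts,
    )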

@tohinz -- we just merged a fix for this to TE main. Could you confirm that it resolves the issue for you? Thanks!

Can confirm that this fixes it for me. Thanks.