te.checkpoint does not work on an nn.Module that consists of TE blocks
tohinz opened this issue
We have a model that consists of several te.Linear layers wrapped in a container module such as torch.nn.Sequential().
For example:

```python
import torch
import transformer_engine.pytorch as te

num_layers = 5
seq_length = 1024
hidden_dim = 2048

model = torch.nn.Sequential()
for _ in range(num_layers):
    model.append(te.Linear(hidden_dim, hidden_dim))
```
When we run te.checkpoint on a single te.Linear layer, things work as expected, but when we try to checkpoint the whole model via

```python
output = te.checkpoint(model, model_input, use_reentrant=False)
loss = MSE(output, target)
loss.backward()
```
we get the following error:
```text
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: A different number of tensors was saved during the original forward and recomputation.
Number of tensors saved during forward: 20
Number of tensors saved during recomputation: 10
```
I believe this is caused by the _is_te_module check (run here): when the module passed in is not itself a TE module, te.checkpoint() falls back to the default PyTorch checkpoint function. For many more complex model definitions this will always be the case, because the TE layers sit inside a container module.
If I remove that _is_te_module check from te.checkpoint(), so that it runs the TE checkpoint implementation instead of the PyTorch one, the code above works, and I also verified that the gradients are the same with and without checkpointing.
I don't know if this is intended behavior, but in many cases we might want to apply checkpointing to larger parts of the model rather than to individual layers.
If this is not the intended behavior, it might make sense to update _is_te_module to recursively check all children of the module and see whether any of them is a TE module.
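A minimal sketch of the recursive check suggested above, assuming the goal is simply "does this module or any of its descendants belong to a given set of classes". This is not TE's actual implementation: the helper name contains_module_of_type is hypothetical, and torch.nn.Linear stands in for TE classes such as te.Linear so the example runs on CPU without TE installed.

```python
import torch


def contains_module_of_type(module: torch.nn.Module, classes) -> bool:
    """Return True if `module` or any of its descendants is an instance of `classes`."""
    # Module.modules() walks the module tree recursively and includes `module` itself,
    # so a bare layer and a container of layers are handled the same way.
    return any(isinstance(m, classes) for m in module.modules())


# Hypothetical stand-in: torch.nn.Linear plays the role of TE classes like te.Linear.
STAND_IN_CLASSES = (torch.nn.Linear,)

single_layer = torch.nn.Linear(16, 16)
nested = torch.nn.Sequential(
    torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()),
    torch.nn.Linear(16, 16),
)
no_match = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(0.1))

print(contains_module_of_type(single_layer, STAND_IN_CLASSES))  # True
print(contains_module_of_type(nested, STAND_IN_CLASSES))        # True
print(contains_module_of_type(no_match, STAND_IN_CLASSES))      # False
```

Relying on Module.modules() avoids hand-written recursion over children(), and because it also yields the root module, a check written this way would still accept a single TE layer passed in directly.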
Code to reproduce the error is attached below (as a quick check, commenting out the following lines makes it run and produce correct gradients):
```python
import torch
from torch.autograd import grad
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

if __name__ == "__main__":
    num_layers = 5
    seq_length = 1024
    hidden_dim = 2048

    # Generate random model input and target for MSE loss
    model_input = torch.rand(8, seq_length, hidden_dim).cuda().to(dtype=torch.float)
    target = torch.rand(8, seq_length, hidden_dim).cuda().to(dtype=torch.float)
    criterion = torch.nn.MSELoss()

    # Define the model
    model = torch.nn.Sequential()
    for _ in range(num_layers):
        model.append(te.Linear(hidden_dim, hidden_dim))
    model.to(dtype=torch.float32).cuda()

    # Define FP8
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(
        fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
    )
    autocast_args = {"enabled": True, "fp8_recipe": fp8_recipe}
    autocast = te.fp8_autocast

    def inner(compare_grads):
        with autocast(**autocast_args):
            output = te.checkpoint(model, model_input, use_reentrant=False)
            output2 = model(model_input)
            loss = criterion(output, target)
            loss2 = criterion(output2, target)
        if compare_grads:
            # compute gradients
            grads = grad(loss, model.parameters())
            grads2 = grad(loss2, model.parameters())
            # compare gradients
            print("Gradients are equal: ")
            print(torch.all(torch.eq(grads[0], grads2[0])))
            print(torch.all(torch.eq(grads[1], grads2[1])))
            # print gradients to check they are nonzero and not nan
            print("")
            print(grads[0])
            print(grads2[0])
            print("")
            print(grads[1])
            print(grads2[1])
        else:
            loss.backward()
            loss2.backward()

    # run model
    fp8_scaling_iters = 50
    # warmup iterations to get FP8 scaling parameters
    for _ in range(fp8_scaling_iters):
        inner(compare_grads=False)
    inner(compare_grads=True)
```
This makes sense. @denera could you take a look at this issue?
Hi @tohinz -- I'll be filing a PR to change the pass-through behavior shortly, but to unblock you for the time being, could you try passing context_fn=te.distributed.get_activation_recompute_contexts into the checkpoint function?
@tohinz -- we just merged a fix for this into TE main. Could you confirm that it resolves the issue for you? Thanks!
Can confirm that this fixes it for me. Thanks.