Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page: https://lightning.ai


XLA FSDP strategy has undocumented requirement for using activation checkpointing

ebreck opened this issue

Bug description

To use activation checkpointing with the XLA FSDP Fabric strategy, the set of modules to checkpoint must also be wrapped by the auto_wrap_policy.

That is, if auto_wrap_policy is provided to the strategy as a set W and activation_checkpointing_policy as a set C, any members of C that are not also members of W will not actually be checkpointed. This is because the activation checkpointing policy is implemented by adding "also wrap these classes in checkpoint_module" to the wrapper callable, but that wrapper is only applied where the auto_wrap_policy says to apply it.

This was surprising, since at first blush auto_wrap_policy and activation_checkpointing_policy appear to be independent parameters, and the silent failure made it harder to tell what was going on.
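For illustration, here is a minimal sketch of the configuration that triggers the surprise, assuming both policies are passed as sets of module classes to Fabric's XLAFSDPStrategy. Block and Head are hypothetical stand-ins for the submodules of a real model.

```python
# Minimal sketch of the surprising configuration (Block and Head are
# hypothetical submodule classes, not part of the original report).
import torch.nn as nn
import lightning as L
from lightning.fabric.strategies import XLAFSDPStrategy


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 128)

    def forward(self, x):
        return self.linear(x)


class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 10)

    def forward(self, x):
        return self.linear(x)


strategy = XLAFSDPStrategy(
    # W: only Block instances get wrapped by XLA FSDP
    auto_wrap_policy={Block},
    # C: the intent is to checkpoint both classes, but Head is never
    # wrapped, so its activations are silently not checkpointed
    activation_checkpointing_policy={Block, Head},
)
fabric = L.Fabric(accelerator="tpu", strategy=strategy)
```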

I suggest a few possible remedies (they are not mutually exclusive):

  1. Document this in the class docstring.
  2. If auto_wrap_policy is provided as a set, activation_checkpointing_policy is non-empty, and the former is not a superset of the latter, throw an error. If auto_wrap_policy is provided as a callable, evaluate it for each member of activation_checkpointing_policy and throw an error if it returns false for any of them.
  3. Union activation_checkpointing_policy into the auto_wrap_policy. Unlike 1 or 2, this would be a behavior change, though it is what I ended up doing manually in my own code anyway (see the sketch after this list).
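For reference, a sketch of the manual workaround behind remedy 3, under the same assumptions as above (set-valued policies, hypothetical Block/Head classes): union the checkpointing classes into the wrap policy so every class intended for checkpointing is actually wrapped. Remedy 2 would amount to performing the equivalent superset check inside the strategy and raising if it fails.

```python
# Sketch of the manual workaround (remedy 3), reusing the hypothetical
# Block/Head classes from the sketch above: ensure every class in the
# checkpointing policy is also covered by the wrap policy.
wrap_classes = {Block}
checkpoint_classes = {Block, Head}

strategy = XLAFSDPStrategy(
    auto_wrap_policy=wrap_classes | checkpoint_classes,
    activation_checkpointing_policy=checkpoint_classes,
)
```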

What version are you seeing the problem on?

v2.2

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response