Lightning-AI / litgpt

Pretrain, finetune, deploy 20+ LLMs on your own data. Uses state-of-the-art techniques: flash attention, FSDP, 4-bit, LoRA, and more.

Home Page: https://lightning.ai


Reducing VRAM usage with FSDP

batman-do opened this issue

Using the following configuration reduces the model's VRAM usage when training with FSDP:

from functools import partial

import torch
from lightning.fabric.strategies import FSDPStrategy
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from litgpt.model import Block

# Wrap each transformer Block as its own FSDP unit.
wrap_policy = partial(
    transformer_auto_wrap_policy, transformer_layer_cls={Block}
)

# Run parameters, gradient reduction, and buffers in bfloat16.
bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

strategy = FSDPStrategy(
    # The repo currently passes the Block class set directly (see below):
    # auto_wrap_policy={Block},
    # activation_checkpointing_policy={Block},
    auto_wrap_policy=wrap_policy,
    activation_checkpointing_policy=wrap_policy,
    state_dict_type="full",
    limit_all_gathers=True,
    cpu_offload=True,
    mixed_precision=bfSixteen,
)

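For context, here is a minimal sketch (not part of the original report) of how such a strategy could be handed to a Fabric launcher, assuming the usual lightning.Fabric setup used by the litgpt pretraining scripts; the model name and optimizer hyperparameters below are placeholders:

import lightning as L
import torch

from litgpt.config import Config
from litgpt.model import GPT

# Assumes `strategy` is the FSDPStrategy constructed above.
# Fabric's own `precision` flag is left at its default here; mixed precision
# is assumed to come from the MixedPrecision config passed to the strategy.
fabric = L.Fabric(devices="auto", strategy=strategy)
fabric.launch()

# Instantiate the model inside the strategy context so FSDP sharding applies.
with fabric.init_module(empty_init=True):
    model = GPT(Config.from_name("pythia-1b"))  # placeholder config name

model = fabric.setup(model)
optimizer = fabric.setup_optimizers(
    torch.optim.AdamW(model.parameters(), lr=3e-4)  # placeholder hyperparameters
)
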
This would replace the repo's current configuration, which wraps on {Block} but does not set mixed precision or CPU offload:

strategy = FSDPStrategy(
    auto_wrap_policy={Block},
    activation_checkpointing_policy={Block},
    state_dict_type="full",
    limit_all_gathers=True,
    cpu_offload=False,
)

Can somebody explain this? @rasbt