pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration

Home page: https://pytorch.org

[FSDP] FSDP with CPU offload consumes `1.65X` more GPU memory when training models with most of the params frozen

pacman100 opened this issue

πŸ› Describe the bug

Context: We increasingly run into situations where a large part of the model being trained is frozen. Since these are very large LLMs, we want to leverage FSDP with CPU offloading so that training them, with only a tiny fraction of the params trainable, fits on consumer GPUs. To this end, below is an example of fine-tuning bigscience/mt0-small with the LoRA parameter-efficient fine-tuning method, with and without FSDP.
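The "tiny fraction" can be made concrete with the layer shapes from the model printout below. This is a back-of-the-envelope sketch, assuming rank r = 8 LoRA adapters on `q` and `v` only (as the printout shows for the encoder self-attention) and counting only the weights of one encoder `T5Block` other than block 0 (which additionally holds the 32x6 `relative_attention_bias`). The LM head and the decoder are left out; the shared embedding is computed separately just for scale. The helper names `linear_params` and `lora_params` are hypothetical.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weight count of a bias-free nn.Linear (every Linear in the printout has bias=False)."""
    return in_features * out_features


def lora_params(in_features: int, out_features: int, r: int) -> int:
    """LoRA pair: lora_A maps in_features -> r, lora_B maps r -> out_features, no biases."""
    return linear_params(in_features, r) + linear_params(r, out_features)


D_MODEL, INNER, D_FF, R = 512, 384, 1024, 8  # shapes read off the printout

# Frozen weights of one encoder T5Block (blocks 1-7).
frozen = (
    3 * linear_params(D_MODEL, INNER)    # q, k, v
    + linear_params(INNER, D_MODEL)      # o
    + 2 * linear_params(D_MODEL, D_FF)   # wi_0, wi_1
    + linear_params(D_FF, D_MODEL)       # wo
    + 2 * D_MODEL                        # two FusedRMSNorm weight vectors
)

# Trainable weights in the same block: LoRA adapters on q and v only.
trainable = 2 * lora_params(D_MODEL, INNER, R)

# For scale: the frozen shared Embedding(250112, 512).
shared_embedding = 250_112 * D_MODEL

print(frozen, trainable, f"{trainable / (frozen + trainable):.2%}", shared_embedding)
```

Under these assumptions roughly 0.6% of each block's weights are trainable, while the frozen embedding alone is about 54 times larger than a whole block, which is why offloading the frozen params is attractive.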

To fine-tune with FSDP:

  1. Following huggingface/accelerate#807, we created a custom `auto_wrap_policy` to avoid `AssertionError: expects all parameters to have same requires_grad`, so that layers with trainable params are placed in separate FSDP units from the frozen ones. We are using Accelerate's FSDP integration, and the trainable params have the `lora_` prefix. The FSDP options and the resulting model are given below:
FullyShardedDataParallelPlugin(sharding_strategy=<ShardingStrategy.FULL_SHARD: 1>, backward_prefetch=<BackwardPrefetch.BACKWARD_PRE: 1>, mixed_precision_policy=None, auto_wrap_policy=functools.partial(<function _or_policy at 0x7f022c1a8430>, policies=[functools.partial(<function lambda_auto_wrap_policy at 0x7f022c1a8160>, lambda_fn=<function fsdp_auto_wrap_policy.<locals>.lambda_policy_fn at 0x7f01ec31e3b0>), functools.partial(<function transformer_auto_wrap_policy at 0x7f022c1a8310>, transformer_layer_cls=(<class 'pet.tuners.prefix_tuning.PrefixEncoder'>, <class 'pet.tuners.p_tuning.PromptEncoder'>, <class 'pet.tuners.prompt_tuning.PromptEmbedding'>, <class 'transformers.models.t5.modeling_t5.T5Block'>))]), cpu_offload=CPUOffload(offload_params=True), ignored_modules=None, state_dict_type=<StateDictType.FULL_STATE_DICT: 1>, state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), limit_all_gathers=False)
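The wrap policy in the options above combines a lambda policy and a transformer-layer policy with `_or_policy`. A minimal sketch of that shape follows. The body of `lambda_policy_fn` is an assumption (wrap any leaf module whose `weight` requires grad, which would catch `lora_A`/`lora_B` but not frozen `q`, `k` or `Dropout`), and `make_auto_wrap_policy` is a hypothetical helper name. `_or_policy` is a private helper in `torch.distributed.fsdp.wrap`, so it may move between PyTorch versions; the import is deferred so the predicate can be used and checked without torch.

```python
import functools


def lambda_policy_fn(module) -> bool:
    """Assumed predicate: True for leaf modules whose weight is trainable.

    Duck-typed: works on anything exposing named_children() and an optional
    `weight` attribute with a `requires_grad` flag (e.g. torch.nn.Module).
    """
    is_leaf = len(list(module.named_children())) == 0
    weight = getattr(module, "weight", None)
    return is_leaf and weight is not None and bool(weight.requires_grad)


def make_auto_wrap_policy(transformer_layer_cls):
    """Hypothetical helper: OR a trainable-leaf lambda policy with a
    transformer-block policy, mirroring the structure of the repr above."""
    # Deferred import; _or_policy is private API.
    from torch.distributed.fsdp.wrap import (
        _or_policy,
        lambda_auto_wrap_policy,
        transformer_auto_wrap_policy,
    )

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=set(transformer_layer_cls),
    )
    return functools.partial(_or_policy, policies=[lambda_policy, transformer_policy])
```

Passing something like `make_auto_wrap_policy((T5Block,))` as `auto_wrap_policy=` should give nesting similar to the printout, with each `lora_A`/`lora_B` in its own FSDP unit inside its `T5Block` unit, so that no single unit mixes `requires_grad=True` and `requires_grad=False` flat params.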


FullyShardedDataParallel(
  (_fsdp_wrapped_module): PETModelForSeq2SeqLM(
    (base_model): LoRAModel(
      (model): MT5ForConditionalGeneration(
        (shared): Embedding(250112, 512)
        (encoder): T5Stack(
          (embed_tokens): Embedding(250112, 512)
          (block): ModuleList(
            (0): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                      (relative_attention_bias): Embedding(32, 6)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (1): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (2): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (3): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (4): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (5): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (6): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (7): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
          )
          (final_layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (decoder): T5Stack(
          (embed_tokens): Embedding(250112, 512)
          (block): ModuleList(
            (0): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                      (relative_attention_bias): Embedding(32, 6)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (1): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (2): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (3): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (4): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (5): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (6): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (7): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
          )
          (final_layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (lm_head): Linear(in_features=512, out_features=250112, bias=False)
      )
    )
  )
)
  2. The number of trainable params is given below:
trainable params: 344064 || all params: 300520832 || trainable%: 0.11448923447676333
  3. The issue is that, compared to plain PyTorch, FSDP consumes 1.65X more GPU memory instead of substantially reducing it as expected, while also greatly increasing CPU memory usage. Screenshots for both runs are below. The hardware used is a single A100 80GB GPU.

Plain PyTorch:
[Screenshot, 2022-12-20 3:45:16 PM]

FSDP Full Shard with CPU offloading:
[Screenshot, 2022-12-20 4:14:04 PM]

  4. Using FSDP with CPU offloading on the bigscience/mt0-xxl model (13B params) on an A100 80GB GPU results in a GPU OOM error, whereas plain PyTorch consumes 56GB of GPU memory.

  5. Expected behaviour: frozen weights should be handled efficiently during training, so that large models can be properly offloaded to CPU or sharded across GPUs, with optimizer state stored only for the trainable parameters. For example, with plain PyTorch the mt0-xxl (13B params) model takes up 56GB on the GPU; it would be really helpful if FSDP with CPU offloading made it possible to train it on a 16GB or 24GB GPU.
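As a sanity check on the numbers above, the trainable count can be reproduced by hand from the module printout, and a rough memory budget shows why storing optimizer state only for trainable params cannot, by itself, shrink GPU memory much. This is a sketch under stated assumptions: rank-8 LoRA on the q and v projections (as printed), 8 encoder and 8 decoder blocks for mt0-small, fp32 (4 bytes) for weights and gradients, and an Adam-style optimizer that keeps two state tensors per trainable parameter. Activations are ignored.

```python
D_MODEL, INNER, RANK = 512, 384, 8   # from the printout: q/v are Linear(512 -> 384), LoRA rank 8
N_ENC, N_DEC = 8, 8                  # assumption: mt0-small has 8 encoder and 8 decoder blocks
TOTAL_PARAMS = 300_520_832           # "all params" from the trainable-params printout

# Each adapted Linear gets lora_A (512 -> 8) and lora_B (8 -> 384), both bias-free.
lora_per_linear = D_MODEL * RANK + RANK * INNER
# Encoder blocks adapt q/v in self-attention; decoder blocks also adapt q/v in cross-attention.
trainable = lora_per_linear * (2 * N_ENC + 4 * N_DEC)
print(trainable, 100 * trainable / TOTAL_PARAMS)


def memory_bytes(total, trainable, bytes_per_value=4, optimizer_states=2):
    """Rough memory for weights, grads and optimizer state, excluding activations."""
    return {
        "weights": total * bytes_per_value,                            # every param, frozen or not
        "grads": trainable * bytes_per_value,                          # only trainable params get .grad
        "optimizer": trainable * optimizer_states * bytes_per_value,   # e.g. Adam exp_avg, exp_avg_sq
    }


budget = memory_bytes(TOTAL_PARAMS, trainable)
for name, nbytes in budget.items():
    print(f"{name:>9}: {nbytes / 2**20:8.1f} MiB")

# Same weight-only estimate for a round 13B-parameter model (assumption: fp32 weights).
print(f"13B fp32 weights: {memory_bytes(13_000_000_000, 0)['weights'] / 2**30:.1f} GiB")
```

The weights term (about 1.1 GiB for mt0-small) is hundreds of times larger than gradients plus optimizer state (about 4 MiB), so meaningful savings have to come from sharding or offloading the frozen weights themselves.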

Versions

Collecting environment information...
PyTorch version: 1.14.0.dev20221117+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.0
Libc version: glibc-2.31

Python version: 3.10.4 (main, Mar 31 2022, 08:41:55) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-125-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA DGX Display
GPU 4: NVIDIA A100-SXM4-80GB

Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.7.1
[pip3] numpy==1.23.4
[pip3] torch==1.14.0.dev20221117+cu117
[pip3] torchaudio==0.14.0.dev20221117
[pip3] torchtriton==2.0.0+0d7e753227
[pip3] torchvision==0.15.0.dev20221117
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310h7f8727e_0
[conda] mkl_fft 1.3.1 py310hd6ae3a3_0
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] msgpack-numpy 0.4.7.1 pypi_0 pypi
[conda] numpy 1.23.4 pypi_0 pypi
[conda] pytorch-cuda 11.7 h67b0de4_0 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 1.14.0.dev20221117+cu117 pypi_0 pypi
[conda] torchaudio 0.14.0.dev20221117 py310_cu117 pytorch-nightly
[conda] torchtriton 2.0.0+0d7e753227 pypi_0 pypi
[conda] torchvision 0.15.0.dev20221117 py310_cu117 pytorch-nightly

cc @ezyang @gchanan @zou3519 @kadeng @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin

Thanks for providing so many details! This helps us out a lot.

We may need more time to investigate why CPU offload consumes so much more GPU memory. In the meantime, have you tried setting limit_all_gathers=True? We plan to make this the default soon. Without it, memory usage may be unexpectedly high whenever the CPU thread runs ahead of GPU execution, which is almost always the case. This may or may not be related to this issue.
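A minimal sketch of enabling the suggested flag together with CPU offload when constructing FSDP directly rather than through Accelerate. The helper name `wrap_with_fsdp` is hypothetical; it assumes the process group has already been initialized on each rank, and leaves the auto-wrap policy as a parameter.

```python
import inspect

import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP


def wrap_with_fsdp(model: nn.Module, auto_wrap_policy=None) -> FSDP:
    """Hypothetical helper: FSDP with parameter CPU offload and the all-gather rate limiter.

    Must be called after torch.distributed.init_process_group() on every rank.
    """
    return FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=True),
        # Limit in-flight all-gathers so the CPU thread cannot run far ahead of the GPU
        # and keep many unsharded parameter copies alive at once.
        limit_all_gathers=True,
    )


# Sanity check that the installed PyTorch exposes the flag on the FSDP constructor.
print("limit_all_gathers" in inspect.signature(FSDP.__init__).parameters)
```

When going through Accelerate, the same value is forwarded from the plugin's `limit_all_gathers` field into the FSDP constructor kwargs.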

IIUC, CPU offload was required for the 1T parameter training run (cc: @rohan-varma @zhaojuanmao), so the memory savings were proven there. Conceptually, I am not seeing how frozen parameters/layers affect the GPU memory usage aside from the optimizer state as you mentioned. In that case, if you think it is helpful, could you verify what optimizer states are being created when using FSDP? Are you seeing that optimizer states are being created even for FlatParameters that do not require gradient?
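A quick way to check which parameters actually acquire optimizer state, sketched on a hypothetical two-layer toy model rather than the FSDP-wrapped T5. Frozen parameters are deliberately passed to `torch.optim.AdamW` here: it skips any parameter whose `.grad` is `None`, so they should not get state even if they are included.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: a frozen "base" layer followed by a small trainable layer.
model = nn.Sequential(nn.Linear(16, 16, bias=False), nn.Linear(16, 4, bias=False))
model[0].weight.requires_grad_(False)

# Frozen params are passed in on purpose to show that they do not receive state.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
optimizer.step()

# Optimizer state is created lazily in step(), only for params that received a gradient.
state_ids = {id(p) for p, s in optimizer.state.items() if s}
for name, p in model.named_parameters():
    print(f"{name}: requires_grad={p.requires_grad}, has_state={id(p) in state_ids}")
```

The same loop over `optimizer.state` can be run against the parameters of the FSDP-wrapped model to see whether any state is attached to FlatParameters that do not require gradient.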

Hello @awgu, thanks for your quick reply. Apart from creating optimizer states only for the tiny portion of trainable params, offloading to CPU should further reduce GPU memory usage. With CPU offloading, I would expect only the model shards to be moved from CPU to GPU (one T5 block at a time, since each block is a separate FSDP unit), used for the forward computation, and then moved back from GPU to CPU, and similarly for the backward pass. The peak GPU memory consumption should therefore be roughly that of a single T5 block. Let me know if there are gaps in my understanding.
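To put "a single T5 block" into numbers, the per-block parameter count can be derived from the module printout. A sketch, assuming 8 encoder and 8 decoder blocks and a 32-bucket, 6-head relative position bias in the first block of each stack (neither is visible in the excerpt). Under these assumptions the breakdown reproduces the 300,520,832 total exactly. It also shows that the embedding and lm_head (about 128M params each, roughly 85% of the model together) dwarf any T5 block (about 2.4M or 3.1M params), so the largest per-unit footprint would likely be set by whichever FSDP unit owns them rather than by a block.

```python
# Dimensions read off the module printout (all Linear layers are bias=False).
D_MODEL, INNER, D_FF, VOCAB = 512, 384, 1024, 250_112
# Assumptions: 8 encoder + 8 decoder blocks, and a 32-bucket x 6-head relative
# position bias in the first block of each stack (not visible in the excerpt).
N_ENC, N_DEC, REL_BIAS = 8, 8, 32 * 6
LORA_PARAMS, TOTAL_PARAMS = 344_064, 300_520_832   # from the trainable-params printout

attn = 3 * D_MODEL * INNER + INNER * D_MODEL       # q, k, v, o
ff = 2 * D_MODEL * D_FF + D_FF * D_MODEL            # wi_0, wi_1, wo
# LoRA adapters are excluded here; they are wrapped as their own FSDP units.
enc_block = (attn + D_MODEL) + (ff + D_MODEL)       # each sublayer has an RMSNorm weight
dec_block = 2 * (attn + D_MODEL) + (ff + D_MODEL)   # adds cross-attention
embedding = VOCAB * D_MODEL                         # shared input embedding
lm_head = D_MODEL * VOCAB                           # output projection, counted separately

frozen = (N_ENC * enc_block + N_DEC * dec_block
          + 2 * (D_MODEL + REL_BIAS)                # final_layer_norm + relative bias per stack
          + embedding + lm_head)
assert frozen + LORA_PARAMS == TOTAL_PARAMS          # the breakdown accounts for every parameter

for name, n in [("encoder block", enc_block), ("decoder block", dec_block),
                ("embedding", embedding), ("lm_head", lm_head)]:
    print(f"{name:>13}: {n:>11,} params = {n * 4 / 2**20:7.1f} MiB in fp32")
```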

@pacman100 Do you have a recommendation for how to try to reproduce this?

@awgu From the source code, it seems that another problem with FSDP is that frozen parameters are not deallocated until the end of the backward pass. If most parameters are frozen, as in this case, the GPU ends up holding a nearly complete copy of the model parameters.

This is unfortunately a known problem :(

We would need to add special hooks to reshard the frozen parameters after they are used for gradient computation. We have not had any bandwidth to look into this though.

Is it possible to use a sized_policy to deal with those frozen parameters?

@hscspring I do not think so. The issue is more fundamental than simply choosing the wrapping policy and comes from when a parameter corresponds to a module whose output activations require gradient but itself does not require gradient -- in that case, FSDP's pre-backward hook runs but its post-backward hook does not.
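A minimal sketch of the mechanism described above, on two plain nn.Linear layers rather than FSDP. It assumes, as we understand FSDP's implementation around this time, that the post-backward hook is registered on each (flat) parameter's AccumulateGrad node, reached via `p.expand_as(p).grad_fn.next_functions[0][0]`. A parameter with requires_grad=False has no such node, so no hook can fire for it, even though gradients still flow through its module to earlier layers.

```python
import torch
import torch.nn as nn


def accumulate_grad_node(p: torch.Tensor):
    """Return the AccumulateGrad node of a leaf parameter, or None if it does not require grad."""
    tmp = p.expand_as(p)       # differentiable no-op whose grad_fn points back at p's AccumulateGrad
    if tmp.grad_fn is None:    # requires_grad=False: autograd never creates a node for p
        return None
    return tmp.grad_fn.next_functions[0][0]


trainable = nn.Linear(4, 4)
frozen = nn.Linear(4, 4).requires_grad_(False)

post_backward_calls = []
keep_alive = []  # hold node references so autograd reuses (and does not drop) the hooked nodes
for name, p in [("trainable.weight", trainable.weight), ("frozen.weight", frozen.weight)]:
    node = accumulate_grad_node(p)
    if node is not None:
        node.register_hook(lambda *unused, n=name: post_backward_calls.append(n))
        keep_alive.append(node)

out = frozen(trainable(torch.randn(2, 4)))  # out requires grad because trainable's output does
out.sum().backward()

print(post_backward_calls)         # only the trainable weight's hook ran
print(frozen.weight.grad is None)  # the frozen weight never receives a gradient
```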

@awgu thanks for the quick reply. You're right, the wrapping policy only affects how the model is sharded.
Hope it'll be optimized in the next version ;)

@awgu any idea about how to add the hook? e.g., torch.autograd.Function.
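One way to approach the torch.autograd.Function idea, sketched here as an assumption and not as what any later PyTorch change actually does: wrap the frozen module's input in an identity Function whose backward runs a callback. Because the wrapper runs before the module in forward, autograd runs its backward after the module's own backward, that is, once the frozen weights are no longer needed. The helper `attach_release_hook` and the `release_fn` callback are hypothetical; in FSDP the callback would be where frozen parameters get resharded. If the module's input does not require grad, no backward passes through it at all, so its parameters could instead be released right after forward.

```python
import torch
import torch.nn as nn


class _CallOnBackward(torch.autograd.Function):
    """Identity in forward; runs `callback` when autograd reaches this point in backward."""

    @staticmethod
    def forward(ctx, x, callback):
        ctx.callback = callback
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        ctx.callback()            # the wrapped module's backward has already run by now
        return grad_output, None  # pass the gradient through; no gradient for `callback`


def attach_release_hook(module: nn.Module, release_fn):
    """Hypothetical helper: call `release_fn` once `module` is done with its params in backward."""
    def pre_hook(mod, args):
        x = args[0]
        if not x.requires_grad:
            return None  # no gradient will flow back through this input
        return (_CallOnBackward.apply(x, release_fn),) + tuple(args[1:])
    return module.register_forward_pre_hook(pre_hook)


events = []
adapter = nn.Linear(8, 8)                       # trainable
frozen = nn.Linear(8, 8).requires_grad_(False)  # stands in for a frozen FSDP unit
attach_release_hook(frozen, lambda: events.append("release frozen params"))

out = frozen(adapter(torch.randn(4, 8)))
out.register_hook(lambda grad: events.append("grad reached frozen output"))
out.sum().backward()
print(events)
```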

@pengyanghua I am testing #101982 out. I am seeing memory savings in my experiments, but @HamidShojanazeri is not at the moment. We are investigating.

@awgu, just to update the thread as we discussed offline: I can also observe the memory savings.

Since #101982 landed, I would recommend people try things out again (e.g. with a nightly). There should be memory savings now.

Thank you @awgu and @HamidShojanazeri. I will try this out this week and report back. This should work with CPU offloading too, right?

@pacman100 Yes, it should work with CPU offloading. (It turns out that CPU offloading is connected to "freeing the parameters", so once we fixed "freeing the parameters", the parameters should be offloaded to CPU as well if enabled.)

I tried the latest nightly, '2.1.0.dev20230620'.

With 2 A100 GPUs, I was able to run bigscience/mt0-xxl, but there were no significant GPU memory savings compared to plain PyTorch, while CPU memory usage was very high.

GPU Memory before entering the train : 0
GPU Memory consumed at the end of the train (end-begin): 315
GPU Peak Memory consumed during the train (max-begin): 53360
GPU Total Peak Memory consumed during the train (max): 53360
CPU Memory before entering the train : 26816
CPU Memory consumed at the end of the train (end-begin): 33817
CPU Peak Memory consumed during the train (max-begin): 57895
CPU Total Peak Memory consumed during the train (max): 84711

As I stated earlier, the expected behaviour is:

Expected behaviour: frozen weights should be handled efficiently during training, so that large models can be properly offloaded to CPU or sharded across GPUs, with optimizer state stored only for the trainable parameters. For example, with plain PyTorch the mt0-xxl (13B params) model takes up 56GB on the GPU; it would be really helpful if FSDP with CPU offloading made it possible to train it on a 16GB or 24GB GPU.

Steps to reproduce:

  1. Install the latest nightly, '2.1.0.dev20230620'.
  2. Install peft, transformers and accelerate from source.
  3. Remove line 1316 (the ignored_parameters entry) from accelerator.py, as FSDP in the nightly doesn't support it:
 kwargs = {
     "sharding_strategy": fsdp_plugin.sharding_strategy,
     "cpu_offload": fsdp_plugin.cpu_offload,
     "auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
     "mixed_precision": fsdp_plugin.mixed_precision_policy,
     "sync_module_states": fsdp_plugin.sync_module_states,
     "backward_prefetch": fsdp_plugin.backward_prefetch,
     "forward_prefetch": fsdp_plugin.forward_prefetch,
     "use_orig_params": fsdp_plugin.use_orig_params,
     "param_init_fn": fsdp_plugin.param_init_fn,
     "ignored_modules": fsdp_plugin.ignored_modules,
-    "ignored_parameters": fsdp_plugin.ignored_parameters,
     "limit_all_gathers": fsdp_plugin.limit_all_gathers,
     "device_id": self.device,
 }
 model = FSDP(model, **kwargs)
  4. Follow the FSDP-related caveats at https://github.com/huggingface/peft#caveats. Use the example below to track GPU and CPU memory usage:

FSDP config:

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: false
  fsdp_transformer_layer_cls_to_wrap: MT5Block
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Code

import gc
import os
import sys
import threading

import numpy as np
import psutil
import torch
from accelerate import Accelerator
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup, set_seed

from peft import LoraConfig, TaskType, get_peft_model
from peft.utils.other import fsdp_auto_wrap_policy

# Converting Bytes to Megabytes
def b2mb(x):
    return int(x / 2**20)


# This context manager is used to track the peak memory usage of the process
class TorchTracemalloc:
    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()  # reset the peak gauges to zero
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()

        self.cpu_begin = self.cpu_mem_used()
        self.peak_monitoring = True
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        peak_monitor_thread.daemon = True
        peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        self.cpu_peak = -1

        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)

            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001) # 1msec

            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False

        gc.collect()
        torch.cuda.empty_cache()
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")


def main():
    accelerator = Accelerator()
    # model_name_or_path = "bigscience/T0_3B"
    model_name_or_path = "bigscience/mt0-xxl"  # or "facebook/bart-large"
    dataset_name = "twitter_complaints"
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
    )
    text_column = "Tweet text"
    label_column = "text_label"
    lr = 3e-3
    num_epochs = 5
    batch_size = 8
    seed = 42
    do_test = False
    set_seed(seed)

    dataset = load_dataset("ought/raft", dataset_name)
    classes = [k.replace("_", " ") for k in dataset["train"].features["Label"].names]
    dataset = dataset.map(
        lambda x: {"text_label": [classes[label] for label in x["Label"]]},
        batched=True,
        num_proc=1,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    target_max_length = max([len(tokenizer(class_label)["input_ids"]) for class_label in classes])

    def preprocess_function(examples):
        inputs = examples[text_column]
        targets = examples[label_column]
        model_inputs = tokenizer(inputs, truncation=True)
        labels = tokenizer(
            targets, max_length=target_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        labels = labels["input_ids"]
        labels[labels == tokenizer.pad_token_id] = -100
        model_inputs["labels"] = labels
        return model_inputs

    with accelerator.main_process_first():
        processed_datasets = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=1,
            remove_columns=dataset["train"].column_names,
            load_from_cache_file=True,
            desc="Running tokenizer on dataset",
        )
    accelerator.wait_for_everyone()

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["train"]
    test_dataset = processed_datasets["test"]

    def collate_fn(examples):
        return tokenizer.pad(examples, padding="longest", return_tensors="pt")

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True
    )
    eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True)
    test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True)

    # creating model
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # lr scheduler
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(len(train_dataloader) * num_epochs),
    )
    
    if getattr(accelerator.state, "fsdp_plugin", None) is not None:
        accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)


    model, train_dataloader, eval_dataloader, test_dataloader, optimizer, lr_scheduler = accelerator.prepare(
        model, train_dataloader, eval_dataloader, test_dataloader, optimizer, lr_scheduler
    )
    accelerator.print(model)

    for epoch in range(num_epochs):
        with TorchTracemalloc() as tracemalloc:
            model.train()
            total_loss = 0
            for step, batch in enumerate(tqdm(train_dataloader)):
                outputs = model(**batch)
                loss = outputs.loss
                total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
        # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
        accelerator.print("GPU Memory before entering the train : {}".format(b2mb(tracemalloc.begin)))
        accelerator.print("GPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used))
        accelerator.print("GPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.peaked))
        accelerator.print(
            "GPU Total Peak Memory consumed during the train (max): {}".format(
                tracemalloc.peaked + b2mb(tracemalloc.begin)
            )
        )

        accelerator.print("CPU Memory before entering the train : {}".format(b2mb(tracemalloc.cpu_begin)))
        accelerator.print("CPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.cpu_used))
        accelerator.print("CPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.cpu_peaked))
        accelerator.print(
            "CPU Total Peak Memory consumed during the train (max): {}".format(
                tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)
            )
        )
        train_epoch_loss = total_loss / len(train_dataloader)
        train_ppl = torch.exp(train_epoch_loss)
        accelerator.print(f"{epoch=}: {train_ppl=} {train_epoch_loss=}")

if __name__ == "__main__":
    main()
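As a lighter-weight cross-check for the CPU side of TorchTracemalloc, which busy-polls psutil's RSS from a background thread, the standard library's resource module reports the process's peak RSS directly. This is a hypothetical helper that is not used in the repro; it is Unix-only, and ru_maxrss is a lifetime high-water mark that cannot be reset, so it only matches the per-epoch cpu_peak when that epoch sets a new overall peak.

```python
import resource
import sys


def peak_rss_mib() -> float:
    """Peak resident set size of the current process since it started, in MiB.

    getrusage() reports ru_maxrss in kilobytes on Linux but in bytes on macOS.
    The value is a lifetime high-water mark and cannot be reset between epochs.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == "darwin":
        return peak / 2**20
    return peak / 2**10


print(f"peak RSS so far: {peak_rss_mib():,.0f} MiB")
```

Printing peak_rss_mib() right after the `with TorchTracemalloc()` block gives a number comparable to the "CPU Total Peak Memory consumed during the train (max)" line. Like psutil's memory_info() for the current process, RUSAGE_SELF does not include child processes such as DataLoader workers.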


output logs:

trainable params: 9,437,184 || all params: 12,930,494,464 || trainable%: 0.072983937515106
None
trainable params: 9,437,184 || all params: 12,930,494,464 || trainable%: 0.072983937515106
Found cached dataset json (/raid/sourab/.cache/huggingface/datasets/json/default-45aab44cd4d75e35/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|██████████████████████████████████████████████████| 2/2 [00:00<00:00, 1410.32it/s]
Found cached dataset json (/raid/sourab/.cache/huggingface/datasets/json/default-45aab44cd4d75e35/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|██████████████████████████████████████████████████| 2/2 [00:00<00:00, 1349.52it/s]
FSDP Warning: When using FSDP, it is efficient and recommended to call prepare for the model before creating the optimizer     
FullyShardedDataParallel(                                                                                                      
  (_fsdp_wrapped_module): PeftModelForSeq2SeqLM(
    (base_model): LoraModel(
      (model): MT5ForConditionalGeneration(
        (shared): Embedding(250112, 4096)
        (encoder): MT5Stack(
          (embed_tokens): Embedding(250112, 4096)
          (block): ModuleList(
            (0): FullyShardedDataParallel(
              (_fsdp_wrapped_module): MT5Block(
                (layer): ModuleList(
                  (0): MT5LayerSelfAttention(
                    (SelfAttention): MT5Attention(
                      (q): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (k): Linear(in_features=4096, out_features=4096, bias=False)
                      (v): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (o): Linear(in_features=4096, out_features=4096, bias=False)
                      (relative_attention_bias): Embedding(32, 64)
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): MT5LayerFF(
                    (DenseReluDense): MT5DenseGatedActDense(
                      (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
                      (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
                      (wo): Linear(in_features=10240, out_features=4096, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (1-23): 23 x FullyShardedDataParallel(
              (_fsdp_wrapped_module): MT5Block(
                (layer): ModuleList(
                  (0): MT5LayerSelfAttention(
                    (SelfAttention): MT5Attention(
                      (q): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (k): Linear(in_features=4096, out_features=4096, bias=False)
                      (v): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (o): Linear(in_features=4096, out_features=4096, bias=False)
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): MT5LayerFF(
                    (DenseReluDense): MT5DenseGatedActDense(
                      (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
                      (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
                      (wo): Linear(in_features=10240, out_features=4096, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
          )
          (final_layer_norm): MT5LayerNorm()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (decoder): MT5Stack(
          (embed_tokens): Embedding(250112, 4096)
          (block): ModuleList(
            (0): FullyShardedDataParallel(
              (_fsdp_wrapped_module): MT5Block(
                (layer): ModuleList(
                  (0): MT5LayerSelfAttention(
                    (SelfAttention): MT5Attention(
                      (q): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (k): Linear(in_features=4096, out_features=4096, bias=False)
                      (v): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (o): Linear(in_features=4096, out_features=4096, bias=False)
                      (relative_attention_bias): Embedding(32, 64)
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): MT5LayerCrossAttention(
                    (EncDecAttention): MT5Attention(
                      (q): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (k): Linear(in_features=4096, out_features=4096, bias=False)
                      (v): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (o): Linear(in_features=4096, out_features=4096, bias=False)
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): MT5LayerFF(
                    (DenseReluDense): MT5DenseGatedActDense(
                      (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
                      (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
                      (wo): Linear(in_features=10240, out_features=4096, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (1-23): 23 x FullyShardedDataParallel(
              (_fsdp_wrapped_module): MT5Block(
                (layer): ModuleList(
                  (0): MT5LayerSelfAttention(
                    (SelfAttention): MT5Attention(
                      (q): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (k): Linear(in_features=4096, out_features=4096, bias=False)
                      (v): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (o): Linear(in_features=4096, out_features=4096, bias=False)
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): MT5LayerCrossAttention(
                    (EncDecAttention): MT5Attention(
                      (q): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (k): Linear(in_features=4096, out_features=4096, bias=False)
                      (v): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (o): Linear(in_features=4096, out_features=4096, bias=False)
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): MT5LayerFF(
                    (DenseReluDense): MT5DenseGatedActDense(
                      (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
                      (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
                      (wo): Linear(in_features=10240, out_features=4096, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
          )
          (final_layer_norm): MT5LayerNorm()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (lm_head): Linear(in_features=4096, out_features=250112, bias=False)
      )
    )
  )
)
/home/sourab/miniconda3/envs/ml/lib/python3.11/site-packages/torch/cuda/memory.py:307: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
/home/sourab/miniconda3/envs/ml/lib/python3.11/site-packages/torch/cuda/memory.py:307: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
GPU Memory before entering the train : 0
GPU Memory consumed at the end of the train (end-begin): 315
GPU Peak Memory consumed during the train (max-begin): 53360
GPU Total Peak Memory consumed during the train (max): 53360
CPU Memory before entering the train : 26816
CPU Memory consumed at the end of the train (end-begin): 33817
CPU Peak Memory consumed during the train (max-begin): 57895
CPU Total Peak Memory consumed during the train (max): 84711
epoch=0: train_ppl=tensor(16.8360, device='cuda:0') train_epoch_loss=tensor(2.8235, device='cuda:0')
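Regarding step 3 (deleting the ignored_parameters line from accelerator.py by hand): one version-tolerant alternative is to drop any keyword argument the installed constructor does not declare. The helper below is hypothetical and not part of accelerate or PyTorch; it only relies on inspect.signature.

```python
import inspect
from typing import Any, Callable, Dict, List, Tuple


def split_supported_kwargs(
    fn: Callable[..., Any], kwargs: Dict[str, Any]
) -> Tuple[Dict[str, Any], List[str]]:
    """Split kwargs into (accepted by fn's signature, names fn does not declare).

    If fn takes **kwargs, everything is considered accepted.
    Positional-only parameters cannot be passed by name, so they are not accepted.
    """
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs), []
    accepted = {
        name
        for name, p in params.items()
        if p.kind is not inspect.Parameter.POSITIONAL_ONLY
    }
    kept = {k: v for k, v in kwargs.items() if k in accepted}
    dropped = [k for k in kwargs if k not in accepted]
    return kept, dropped
```

In accelerate this could be used as `kept, dropped = split_supported_kwargs(FSDP.__init__, kwargs)` followed by `model = FSDP(model, **kept)`. Returning the dropped names, rather than discarding them silently, makes it possible to log them, since silently ignoring an option can hide a real misconfiguration.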

@pacman100 Can you please share the code for manually wrapping LoRA? I am encountering the same problem. Many thanks!

@pacman100 Is this issue solved in recent PyTorch versions?
Or should I not bother trying FSDP to reduce memory in a fine-tuning task where most parameters are frozen?

I am still hitting OOM issues with CPU offload on a 7B model in mixed precision.

Hello, I came across this important discussion. With the latest PyTorch nightlies (for example, in the LLaMA-recipes implementation), running PEFT with FSDP does indeed save memory. For example, training a 13B model on 4 GPUs with a mini-batch size of 4 (pure BF16) uses 25GB of VRAM per GPU with PEFT vs. 55GB for the full model. Can we confirm that this issue has been resolved in the latest PyTorch nightlies? Many thanks.

Also, a note to future readers: there are some excellent discussions on this led by @awgu:
#104690
#100945

Also, looking at the LLaMA-recipes implementation, it seems they largely followed the recommendation from this post: they use the customized lambda_policy_fn only when using PEFT.
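For anyone asking how the LoRA layers can be wrapped manually: the FullyShardedDataParallelPlugin repr in the original report shows the policy from peft's fsdp_auto_wrap_policy(model) as an _or_policy that combines a lambda_auto_wrap_policy (with a lambda_policy_fn) and a transformer_auto_wrap_policy. Below is a minimal sketch of that idea. The predicate is an assumption (leaf modules whose weight requires grad), not a copy of peft's code, and _or_policy is a private PyTorch helper that may change between versions.

```python
import functools


def lambda_policy_fn(module) -> bool:
    """True for leaf modules whose weight is trainable (e.g. LoRA lora_A / lora_B).

    Giving those leaves their own FSDP units keeps trainable and frozen
    parameters in separate flat parameters.
    """
    is_leaf = len(list(module.named_children())) == 0
    weight = getattr(module, "weight", None)
    return bool(is_leaf and weight is not None and weight.requires_grad)


def build_auto_wrap_policy(transformer_layer_cls):
    """Wrap each transformer block, plus each trainable leaf inside it."""
    # Imported lazily so lambda_policy_fn can be used without torch installed.
    from torch.distributed.fsdp.wrap import (
        _or_policy,  # private helper; may change between PyTorch versions
        lambda_auto_wrap_policy,
        transformer_auto_wrap_policy,
    )

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_policy = functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls=set(transformer_layer_cls)
    )
    return functools.partial(_or_policy, policies=[lambda_policy, transformer_policy])
```

For mt0-xxl this would be passed as auto_wrap_policy=build_auto_wrap_policy({MT5Block}), with MT5Block taken from transformers.models.mt5.modeling_mt5. The printed model in the output logs shows the resulting layout: each MT5Block is one FSDP unit, and each lora_A/lora_B Linear is its own nested unit.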

@awgu There are a lot of open issues regarding FSDP + PEFT. I've been looking at many of these discussion threads and, unfortunately, even with CPU offloading and the wrapping policy suggested here, I'm still facing persistent OOM issues. This is another issue we'd like to hear updates on.

Is it possible to consolidate everything regarding this topic (current workarounds, PT nightlies, current problems, WIP if any) either on this thread or in a new issue on this repo?

That would really help a lot of us currently struggling to evade OOMs with FSDP + PEFT.