pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration

Home Page:https://pytorch.org

FSDP state dict OOM during model saving

wanchaol opened this issue

πŸ› Describe the bug

See the related user-reported issues in tatsu-lab/stanford_alpaca#81 and lm-sys/FastChat#256.

A workaround that the community is applying is:

Assuming you are using torch==1.13.0, change line 2224 of python/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py from state_dict[fqn] = state_dict[fqn].clone().detach() to state_dict[fqn] = state_dict[fqn].cpu().clone().detach().
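To see why moving each copy to the CPU helps, here is a toy accounting in plain Python (not FSDP code). It assumes, for simplicity, that every parameter is its own FSDP unit that is all-gathered, copied into the state dict, and resharded in turn, and it counts only the GPU bytes on top of the resident shards. The function name peak_extra_gpu_bytes is hypothetical.

```python
from typing import Iterable


def peak_extra_gpu_bytes(param_sizes: Iterable[int], offload_to_cpu: bool) -> int:
    """Toy model of extra GPU memory while building a full state dict.

    Simplification: each parameter is unsharded (all-gathered) on its own,
    copied into the state dict, then resharded. Resident shards are ignored.
    """
    in_use = 0
    peak = 0
    for size in param_sizes:
        in_use += size  # the all-gathered, unsharded parameter
        if not offload_to_cpu:
            in_use += size  # .clone() keeps a full-size copy on the GPU
        # with offload, .cpu().clone() puts the copy in host memory instead
        peak = max(peak, in_use)
        in_use -= size  # resharding frees the gathered buffer
    return peak


print(peak_extra_gpu_bytes([4, 4, 4], offload_to_cpu=False))  # 16
print(peak_extra_gpu_bytes([4, 4, 4], offload_to_cpu=True))   # 4
```

In this toy, keeping the clones on the GPU makes the peak grow with the total model size, while offloading keeps it at the size of the largest single parameter. Converting values to CPU only after state_dict() has returned does not lower the peak, because the GPU copies already exist by then.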

This is pretty manual monkey patching, and we should really fix this in PyTorch directly.

@fegin @awgu @rohan-varma @zhaojuanmao

Versions

This has been happening since PyTorch 1.13, and I don't think we have fixed it so far.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu

The recommended solution is to turn on CPU offload (offload_to_cpu) for the state dict. An example can be found at https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type

I am closing this given @fegin's comment.

@fegin @awgu I tried using:

import torch.distributed.fsdp as FSDP
from torch.distributed.fsdp import (
    LocalOptimStateDictConfig,
    LocalStateDictConfig,
    StateDictType,
)

with FSDP.FullyShardedDataParallel.state_dict_type(
    trainer.model,
    StateDictType.LOCAL_STATE_DICT,  # or any other StateDictType
    LocalStateDictConfig(offload_to_cpu=True),  # or without this line
    LocalOptimStateDictConfig(offload_to_cpu=True),  # or without this line
):
    state_dict = trainer.model.state_dict()

The program gets stuck in this state for a very long time and eventually times out.

I finally managed to save the model with Python 3.10 and torch==2.0 by changing line 309 of /python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py from state_dict[fqn] = state_dict[fqn].clone().detach() to state_dict[fqn] = state_dict[fqn].cpu().clone().detach(). It works really well.

@alanxmay Many thanks to you! Absolutely awesome!!!

I'm still facing this problem; my process hangs forever.

@alanxmay Could you share how you call model.state_dict()? Did you use with FSDP.FullyShardedDataParallel.state_dict_type(...), or did you just change the line you mentioned?

@alanxmay I would like to understand the issue in more detail. The code you attached uses LOCAL_STATE_DICT, but the fix you mentioned is not used by LOCAL_STATE_DICT; it only applies to FULL_STATE_DICT. Does LOCAL_STATE_DICT also cause the hang/OOM? Does the fix also solve the hang for LOCAL_STATE_DICT? How do you call FULL_STATE_DICT?
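The distinction between the two state dict types can be illustrated with a toy model in plain Python. Lists stand in for tensors, the helper names are hypothetical, and real FSDP flat parameters carry more metadata than this. A local state dict returns each rank's own shard without communication, whereas a full state dict must gather every shard into one unpadded copy, which is the copy the .cpu() change moves off the GPU.

```python
from typing import List


def shard(flat_param: List[float], world_size: int) -> List[List[float]]:
    """Split a flattened parameter into equal chunks, zero-padding the last one."""
    chunk = -(-len(flat_param) // world_size)  # ceiling division
    padded = flat_param + [0.0] * (chunk * world_size - len(flat_param))
    return [padded[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def local_shard(shards: List[List[float]], rank: int) -> List[float]:
    # LOCAL_STATE_DICT analogue: a rank returns only its own shard, no communication.
    return list(shards[rank])


def full_param(shards: List[List[float]], numel: int) -> List[float]:
    # FULL_STATE_DICT analogue: gather every shard, then drop the padding.
    gathered = [x for s in shards for x in s]
    return gathered[:numel]


param = [1.0, 2.0, 3.0, 4.0, 5.0]
shards = shard(param, world_size=2)
print(local_shard(shards, rank=1))           # [4.0, 5.0, 0.0]
print(full_param(shards, numel=len(param)))  # [1.0, 2.0, 3.0, 4.0, 5.0]
```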

@JulioZhao97 Just modify that line.

@fegin With the fix, the save function is the same as in FastChat:

import transformers

def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dumps it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa

Without the fix, following the FSDP tutorial, I tried both FULL_STATE_DICT and LOCAL_STATE_DICT; both hang:

import torch.distributed.fsdp as FSDP
import transformers
from torch.distributed.fsdp import (
    LocalOptimStateDictConfig,
    LocalStateDictConfig,
    StateDictType,
)

def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dumps it to disk."""
    with FSDP.FullyShardedDataParallel.state_dict_type(
        trainer.model,
        StateDictType.LOCAL_STATE_DICT,  # or any other StateDictType
        LocalStateDictConfig(offload_to_cpu=True),  # or without this line
        LocalOptimStateDictConfig(offload_to_cpu=True),  # or without this line
    ):
        state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa

Thanks for sharing, I will try right away.

My issue is finally fixed. My simplified code is as follows:

import torch
import torch.distributed

def save_model(self, model):
    # save model
    model = unwrap_model(model)  # unwrap_model is the user's own helper (not shown)
    state_dict = model.state_dict()
    cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
    del state_dict
    torch.save(cpu_state_dict, 'model.pth')

def sync_logger(self):
    ...
    torch.distributed.barrier()
    torch.distributed.all_reduce(...)  # arguments omitted in the original
    ...

def train(self):
    for epoch in range(total_epochs):
        for step in range(num_iters):
            # model training code here
            ...
        self.save_model(model)
        ...
        self.sync_logger()  # a barrier used to synchronize the loggers
        ...

It turns out that the hang is caused by torch.distributed.barrier(): when I comment out this line, checkpoint saving works fine. For anyone else facing this hang with FSDP, this may be the reason.

I further tested the behavior of fsdp_model.state_dict() as follows:

import torch
import torch.distributed as dist

def save_model(self, model):
    print(f'rank {dist.get_rank()} enter _save_checkpoint')
    # save model
    model = unwrap_model(model)  # unwrap_model is the user's own helper (not shown)
    print(f'rank {dist.get_rank()} before call state_dict()')
    state_dict = model.state_dict()
    cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
    del state_dict
    torch.save(cpu_state_dict, 'model.pth')
    print(f'rank {dist.get_rank()} save model')

def train(self):
    for epoch in range(total_epochs):
        for step in range(num_iters):
            # model training code here
            ...
        print(f'now is rank {dist.get_rank()}')
        self.save_model(model)
        print(f'rank {dist.get_rank()} save complete')
        ...

And the output is this:

now is rank 5
now is rank 6
now is rank 3
now is rank 2
now is rank 4
now is rank 0
now is rank 1
now is rank 7
rank 0 enter _save_checkpoint
rank 0 before call state_dict()
rank 0 save model
rank 4 save complete
rank 2 save complete
rank 3 save complete
rank 6 save complete
rank 5 save complete
rank 0 save complete
rank 1 save complete
rank 7 save complete

It seems that all the other processes are blocked when state_dict() is called (or skip the state_dict() code entirely). This is the weirdest thing; can anyone enlighten me?

@JulioZhao97 That is very interesting. Given the code you showed, I cannot understand how that can happen. FSDP's .state_dict() call requires all ranks to participate, which means that the non-zero ranks must have also entered save_model().

I am not sure why a barrier in your previous code snippet would cause hangs either. I wonder if the two issues are related somehow.
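The requirement that every rank participate can be illustrated with a toy simulation: threads stand in for ranks, and a threading.Barrier sized for the whole world stands in for the all-gathers inside FSDP's state_dict(). This is an assumption-laden sketch, not how torch.distributed works internally, and run_collective is a hypothetical name. In the toy, a rank that enters the collective alone can never complete it and hits the timeout instead.

```python
import threading
from typing import Dict, Set


def run_collective(world_size: int, participating: Set[int], timeout: float) -> Dict[int, str]:
    """Toy collective: threads play ranks, a Barrier plays FSDP's all-gathers.

    The barrier expects all `world_size` ranks. Ranks outside `participating`
    skip the call (like non-zero ranks under a rank-0-only code path); ranks
    that do call it can only finish if every other rank calls it too.
    """
    barrier = threading.Barrier(world_size)
    status: Dict[int, str] = {}

    def rank_fn(rank: int) -> None:
        if rank not in participating:
            status[rank] = "skipped"
            return
        try:
            barrier.wait(timeout=timeout)
            status[rank] = "completed"
        except threading.BrokenBarrierError:
            # wait() breaks the barrier and raises once the timeout expires
            status[rank] = "timed out"

    threads = [threading.Thread(target=rank_fn, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return status


print(sorted(run_collective(4, {0, 1, 2, 3}, timeout=1.0).items()))
print(sorted(run_collective(4, {0}, timeout=0.2).items()))
```

The first call reports every rank as completed. The second leaves rank 0 timed out and ranks 1 to 3 skipped, which is the shape of a save path that only rank 0 executes.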

@alanxmay Can you confirm that the entire model is wrapped by FSDP, or are only parts of the model wrapped?

My bad: I just noticed that there is a @main_process decorator on my _save_checkpoint function. I will test again whether this decorator or torch.distributed.barrier() causes the hang.

    @main_process  # runs only on the main process, so the other ranks never call state_dict()
    def _save_checkpoint(self, cur_epoch, is_best=False):
        """
        Save the checkpoint at the current epoch.
        """
        print(f'rank {dist.get_rank()} save checkpoint')
        model_no_ddp = self.unwrap_dist_model(self.model)
        param_grad_dic = {
            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
        }
        state_dict = model_no_ddp.state_dict()
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        for k in list(cpu_state_dict.keys()):
            if k in param_grad_dic.keys() and not param_grad_dic[k]:
                # delete parameters that do not require gradient
                del cpu_state_dict[k]
        save_obj = {
            "model": cpu_state_dict,
            "optimizer": self.optimizer.state_dict(),
            "config": self.config.to_dict(),
            "scaler": self.scaler.state_dict() if self.scaler else None,
            "epoch": cur_epoch,
        }
        save_to = os.path.join(
            self.output_dir,
            "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
        )
        print(f'rank {dist.get_rank()} save checkpoint complete')
        logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
        torch.save(save_obj, save_to)

Case closed: the hang was caused by the @main_process decorator, not by torch.distributed.barrier().
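The pattern that avoids this is to call the collective on every rank and restrict only the disk write to rank 0, rather than guarding the whole save function. A minimal threaded sketch of that pattern follows, again with threads standing in for ranks and a barrier standing in for the all-gathers inside state_dict(); save_on_all_ranks is a hypothetical name and the list stands in for the checkpoint file.

```python
import threading
from typing import List


def save_on_all_ranks(world_size: int, timeout: float = 1.0) -> List[str]:
    """Toy fixed pattern: every rank joins the collective, only rank 0 writes.

    The barrier stands in for the all-gathers inside FSDP's state_dict(), and
    the `written` list stands in for the checkpoint file on disk.
    """
    barrier = threading.Barrier(world_size)
    written: List[str] = []
    failed: List[int] = []

    def rank_fn(rank: int) -> None:
        try:
            barrier.wait(timeout=timeout)  # collective: no rank-0-only guard here
        except threading.BrokenBarrierError:
            failed.append(rank)
            return
        if rank == 0:  # only the disk write is restricted to rank 0
            written.append(f"checkpoint written by rank {rank}")

    threads = [threading.Thread(target=rank_fn, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    if failed:
        raise RuntimeError(f"collective failed on ranks {sorted(failed)}")
    return written


print(save_on_all_ranks(world_size=4))  # ['checkpoint written by rank 0']
```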

@fegin Thanks for your valuable suggestion; I will check later.

Did you apply the fix, or did you just change the save function as you mentioned? I'd like to dig further into the problem I'm having. Thanks!

I also applied the fix you mentioned, but did not dig any deeper.

Original:

import transformers

def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {
            key: value.cpu()
            for key, value in state_dict.items()
        }
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa

Fixed:

import transformers
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    model = trainer.model
    # Gather the full state dict, offloading each tensor to CPU and keeping it on rank 0 only.
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        cpu_state_dict = model.state_dict()
        if trainer.args.should_save:
            trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
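The rank0_only=True setting matters for host memory as well: with offload_to_cpu=True alone, every rank materializes its own full copy in CPU RAM, so a node running several ranks holds several copies. A rough estimate, using a hypothetical 7-billion-parameter model stored in 16-bit precision on an 8-GPU node (illustrative numbers, not taken from this thread); host_bytes_on_rank0_node is a hypothetical helper:

```python
def host_bytes_on_rank0_node(num_params: int, bytes_per_param: int,
                             ranks_per_node: int, rank0_only: bool) -> int:
    """Rough host-RAM estimate for an offloaded full state dict on rank 0's node.

    Without rank0_only, each local rank keeps one full copy in host memory;
    with rank0_only, only rank 0 does. Other process memory is ignored.
    """
    copies = 1 if rank0_only else ranks_per_node
    return num_params * bytes_per_param * copies


print(host_bytes_on_rank0_node(7_000_000_000, 2, 8, rank0_only=False))  # 112000000000
print(host_bytes_on_rank0_node(7_000_000_000, 2, 8, rank0_only=True))   # 14000000000
```

That is roughly 112 GB versus 14 GB of host RAM for the state dict alone in this example.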

@KooSung Did you only change the save function, or did you also change _state_dict_utils.py as @alanxmay did?

@JACKHAHA363 Did you solve it with KooSung's method? I tried @alanxmay's approach, but it also raised an OOM error:
  File "/miniconda3/envs/LLAMA2/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1753, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.18 GiB (GPU 3; 79.32 GiB total capacity; 75.41 GiB already allocated; 891.56 MiB free; 76.51 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

  File "/miniconda3/envs/LLAMA2/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1154, in _optim_state_dict_impl
    return _optim_state_dict(
  File "/miniconda3/envs/LLAMA2/lib/python3.9/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1463, in _optim_state_dict
    unflat_state = _unflatten_optim_state(
  File "/miniconda3/envs/LLAMA2/lib/python3.9/site-packages/torch/distributed/fsdp/_optim_utils.py", line 136, in _unflatten_optim_state
    consolidated_state = _communicate_optim_state(
  File "/miniconda3/envs/LLAMA2/lib/python3.9/site-packages/torch/distributed/fsdp/_optim_utils.py", line 208, in _communicate_optim_state
    tensor_buffer = value.new_zeros(*buffer_size)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.18 GiB (GPU 2; 79.32 GiB total capacity; 75.41 GiB already allocated; 931.56 MiB free; 76.47 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
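This traceback fails in the optimizer-state path (optim_state_dict), when _communicate_optim_state allocates a GPU buffer with value.new_zeros. The figures in the error message can be rechecked with a little arithmetic. explain_oom is a hypothetical helper, and the inputs are copied from the GPU 3 message:

```python
MIB_PER_GIB = 1024


def explain_oom(requested_gib: float, free_mib: float,
                allocated_gib: float, reserved_gib: float) -> dict:
    """Re-derive the key numbers in a CUDA OOM message (units as printed)."""
    requested_mib = requested_gib * MIB_PER_GIB
    return {
        "requested_mib": round(requested_mib, 2),
        "fits_in_free_memory": requested_mib <= free_mib,
        # Memory held by the caching allocator but not currently allocated.
        "cached_but_unallocated_gib": round(reserved_gib - allocated_gib, 2),
    }


print(explain_oom(1.18, 891.56, 75.41, 76.51))
```

About 1208 MiB is requested with 891.56 MiB free, and only about 1.1 GiB of the 76.51 GiB reserved is cached but unallocated. That suggests the max_split_size_mb hint, which targets the case where reserved memory is much larger than allocated memory, is unlikely to help much here; the device is simply close to full when the optimizer-state buffer is allocated.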

Mark. Thanks @ntlm1686!