FSDP state dict OOM during model saving
wanchaol opened this issue · comments
🐛 Describe the bug
see related user reporting issues in tatsu-lab/stanford_alpaca#81 and lm-sys/FastChat#256
A workaround that the community is applying is:
Assuming you are using torch==1.13.0, change python/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:2224 from state_dict[fqn] = state_dict[fqn].clone().detach() to state_dict[fqn] = state_dict[fqn].cpu().clone().detach()
This is pretty manual monkey patching, and we should really fix this in PyTorch directly.
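For reference, the workaround is a one-line change that moves each unsharded tensor to CPU before cloning, so the full copy no longer accumulates on GPU (the exact line number varies across torch versions):
# Before (torch 1.13.0, fully_sharded_data_parallel.py:2224): the cloned copy stays on GPU.
state_dict[fqn] = state_dict[fqn].clone().detach()
# After (community workaround): move to CPU first so the GPU copy can be freed.
state_dict[fqn] = state_dict[fqn].cpu().clone().detach()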
@fegin @awgu @rohan-varma @zhaojuanmao
Versions
This happens since PyTorch 1.13, and I don't think we have fixed it so far.
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu
The recommended solution is to turn on cpu_offload for state_dict. An example can be found at https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
from torch.distributed import fsdp as FSDP
from torch.distributed.fsdp import (
    StateDictType,
    LocalStateDictConfig,
    LocalOptimStateDictConfig,
)

with FSDP.FullyShardedDataParallel.state_dict_type(
    trainer.model,
    StateDictType.LOCAL_STATE_DICT,  # or any other StateDictType
    LocalStateDictConfig(offload_to_cpu=True),  # or without this line
    LocalOptimStateDictConfig(offload_to_cpu=True),  # or without this line
):
    state_dict = trainer.model.state_dict()
The program will be stuck in this state for a very long time.
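An aside on LOCAL_STATE_DICT: each rank's state_dict then holds only its own shard, so rank 0 cannot simply torch.save the whole model. A hedged sketch of one way to persist the sharded state, assuming torch >= 2.0's torch.distributed.checkpoint module and a hypothetical checkpoint_dir path:
import torch.distributed.checkpoint as dist_cp
from torch.distributed import fsdp as FSDP
from torch.distributed.fsdp import StateDictType

with FSDP.FullyShardedDataParallel.state_dict_type(
    trainer.model, StateDictType.LOCAL_STATE_DICT
):
    state_dict = trainer.model.state_dict()
# All ranks participate; each rank writes its own shard into checkpoint_dir.
dist_cp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter("checkpoint_dir"),
)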
I finally managed to save the model with python3.10 and torch==2.0 by changing /python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py on line 309 from state_dict[fqn] = state_dict[fqn].clone().detach() to state_dict[fqn] = state_dict[fqn].cpu().clone().detach(). It works really well.
@alanxmay Many thanks to you! Absolutely awesome!!
Still facing this problem, my process hangs forever.
Can I ask how you call model.state_dict()? Did you use with FSDP.FullyShardedDataParallel.state_dict_type, or just change the line you mentioned?
@alanxmay I would like to understand the issue here in more detail. The code you attached uses LOCAL_STATE_DICT, but the fix you mentioned is actually not used by LOCAL_STATE_DICT, only by FULL_STATE_DICT. Does LOCAL_STATE_DICT also cause the hang/OOM? Does the fix also solve the hang for LOCAL_STATE_DICT? How do you call FULL_STATE_DICT?
@JulioZhao97 Just modify the line.
@fegin With the fix, the save function is the same as in FastChat:
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
Without the fix, following the FSDP tutorial, I tried both FULL_STATE_DICT and LOCAL_STATE_DICT; both cause a hang:
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    with FSDP.FullyShardedDataParallel.state_dict_type(
        trainer.model,
        StateDictType.LOCAL_STATE_DICT,  # or any other StateDictType
        LocalStateDictConfig(offload_to_cpu=True),  # or without this line
        LocalOptimStateDictConfig(offload_to_cpu=True),  # or without this line
    ):
        state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
Thanks for sharing, I will try right away.
Finally, my issue was fixed; my simplified code is as follows:
def save_model(self, model):
    # save model
    model_no_ddp = unwrap_model(model)
    state_dict = model_no_ddp.state_dict()
    cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
    del state_dict
    torch.save(cpu_state_dict, 'model.pth')

def sync_logger(self):
    ...
    torch.distributed.barrier()
    torch.distributed.all_reduce(...)  # pseudocode: reduce logged metrics across ranks
    ...

def train(self):
    for epoch in range(total_epochs):
        for iter in range(num_iters):
            # model training code here
            ...
            self.save_model(model)
            ...
            self.sync_logger()  # a barrier set for synchronizing loggers
            ...
It turns out that the cause of the hang is torch.distributed.barrier(); when I comment out this line, the checkpoint saving works fine. For anyone else facing this hanging issue in FSDP, this can be a possible reason.
I further tested the behavior of fsdp_model.state_dict() as follows:
def save_model(self, model):
    print(f'rank {dist.get_rank()} enter _save_checkpoint')
    # save model
    model_no_ddp = unwrap_model(model)
    print(f'rank {dist.get_rank()} before call state_dict()')
    state_dict = model_no_ddp.state_dict()
    cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
    del state_dict
    torch.save(cpu_state_dict, 'model.pth')
    print(f'rank {dist.get_rank()} save model')

def train(self):
    for epoch in range(total_epochs):
        for iter in range(num_iters):
            # model training code here
            ...
            print(f'now is rank {dist.get_rank()}')
            self.save_model(model)
            print(f'rank {dist.get_rank()} save complete')
            ...
And the output is this:
now is rank 5
now is rank 6
now is rank 3
now is rank 2
now is rank 4
now is rank 0
now is rank 1
now is rank 7
rank 0 enter _save_checkpoint
rank 0 before call state_dict()
rank 0 save model
rank 4 save complete
rank 2 save complete
rank 3 save complete
rank 6 save complete
rank 5 save complete
rank 0 save complete
rank 1 save complete
rank 7 save complete
It seems that all other processes are blocked when state_dict() is called (or skip the state_dict() code entirely). This is the weirdest thing; can anyone enlighten me on this?
@JulioZhao97 That is very interesting. Given the code you showed, I cannot understand how that can happen. FSDP's .state_dict() call requires all ranks to participate, which means that the non-zero ranks must have also entered save_model().
I am not sure why a barrier in your previous code snippet would cause hangs either. I wonder if the two issues are related somehow.
@alanxmay Can you confirm that the entire model is wrapped by FSDP? Or only parts of the model are wrapped by FSDP?
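To illustrate why all ranks must participate: guarding a collective call behind a rank check is a classic deadlock pattern. A minimal sketch of the anti-pattern, with hypothetical names (fsdp_model is any FSDP-wrapped module):
import torch.distributed as dist

# ANTI-PATTERN: state_dict() on an FSDP model is a collective operation,
# so calling it on rank 0 only leaves rank 0 waiting for peers that never
# arrive, while the other ranks block on the barrier below.
if dist.get_rank() == 0:
    state_dict = fsdp_model.state_dict()  # hangs: peers never enter the collective
dist.barrier()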
My bad, I just noticed that there is a @main_process decorator on my save_checkpoint function. I will test again whether this decorator or torch.distributed.barrier() causes the hang.
@main_process  # user-defined decorator that runs the body on the main process only
def _save_checkpoint(self, cur_epoch, is_best=False):
    """
    Save the checkpoint at the current epoch.
    """
    print(f'rank {dist.get_rank()} save checkpoint')
    model_no_ddp = self.unwrap_dist_model(self.model)
    param_grad_dic = {
        k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
    }
    state_dict = model_no_ddp.state_dict()
    cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
    del state_dict
    for k in list(cpu_state_dict.keys()):
        if k in param_grad_dic.keys() and not param_grad_dic[k]:
            # delete parameters that do not require gradient
            del cpu_state_dict[k]
    save_obj = {
        "model": cpu_state_dict,
        "optimizer": self.optimizer.state_dict(),
        "config": self.config.to_dict(),
        "scaler": self.scaler.state_dict() if self.scaler else None,
        "epoch": cur_epoch,
    }
    save_to = os.path.join(
        self.output_dir,
        "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
    )
    print(f'rank {dist.get_rank()} save checkpoint complete')
    logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
    torch.save(save_obj, save_to)
Case closed: the hang is caused by the @main_process decorator, not by torch.distributed.barrier().
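For anyone hitting the same hang, a minimal sketch of the corrected pattern (hypothetical names): every rank calls state_dict(), and only the filesystem write is restricted to rank 0.
import torch
import torch.distributed as dist

def save_checkpoint(fsdp_model, path):
    # Every rank participates in the collective state_dict() call...
    state_dict = fsdp_model.state_dict()
    cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
    del state_dict
    # ...but only rank 0 touches the filesystem.
    if dist.get_rank() == 0:
        torch.save(cpu_state_dict, path)
    dist.barrier()  # safe: all ranks reach this point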
Did you apply the fix, or just fix the save function as you mentioned? I'd like to dig further into the problem I'm having. Thanks!
I also applied the fix you mentioned, but did not dive in deeper.
origin:
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
    cpu_state_dict = {
        key: value.cpu()
        for key, value in state_dict.items()
    }
    del state_dict
    trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
fixed:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

model = trainer.model
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state_dict = model.state_dict()
if trainer.args.should_save:
    trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
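One note on the design (my understanding, not verified in this thread): all ranks still participate in the gather, but with rank0_only=True only rank 0 keeps the full, CPU-offloaded copy, so the non-zero ranks never hold the unsharded weights; the should_save guard then ensures only rank 0 writes to disk.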
@JACKHAHA363 Did you solve it via KooSung's method? I tried @alanxmay's approach, but it also raised an OOM error:
File "/miniconda3/envs/LLAMA2/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1753, in optim_state_dict
return FullyShardedDataParallel._optim_state_dict_impl(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.18 GiB (GPU 3; 79.32 GiB total capacity; 75.41 GiB already allocated; 891.56 MiB free; 76.51 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
File "/miniconda3/envs/LLAMA2/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1154, in _optim_state_dict_impl
return _optim_state_dict(
File "/miniconda3/envs/LLAMA2/lib/python3.9/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1463, in _optim_state_dict
unflat_state = _unflatten_optim_state(
File "/miniconda3/envs/LLAMA2/lib/python3.9/site-packages/torch/distributed/fsdp/_optim_utils.py", line 136, in _unflatten_optim_state
consolidated_state = _communicate_optim_state(
File "/miniconda3/envs/LLAMA2/lib/python3.9/site-packages/torch/distributed/fsdp/_optim_utils.py", line 208, in _communicate_optim_state
tensor_buffer = value.new_zeros(*buffer_size)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.18 GiB (GPU 2; 79.32 GiB total capacity; 75.41 GiB already allocated; 931.56 MiB free; 76.47 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
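The traceback above is from optim_state_dict, which by default allocates the gather buffers on GPU (tensor_buffer = value.new_zeros(*buffer_size)). Not from this thread, but a hedged sketch of the analogous CPU offload for optimizer state, assuming torch >= 2.0 (model and optimizer are hypothetical names):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    StateDictType,
    FullStateDictConfig,
    FullOptimStateDictConfig,
)

with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    model_state = model.state_dict()
    # Gathers the optimizer state; with the config above it is offloaded
    # to CPU and fully materialized only on rank 0.
    optim_state = FSDP.optim_state_dict(model, optimizer)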
Marking this. Thanks @ntlm1686!