Getting AssertionError when saving a FSDP strat trained model with 16-mixed precision
JimenezBarreroDavid opened this issue
JimenezBarreroDavid commented
Hello, I am trying to train a model with 16-mixed precision using the full.py
script. When the model is saved, it throws the following error:
```
File "lib/python3.10/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 186, in _unshard_fsdp_state_params
    assert
AssertionError: Expects the handle training to be IDLE but got HandleTrainingState.FORWARD
```
This error happens in this part of the lit_llama/lit_llama/utils.py file:

```python
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    state_dict = model._forward_module.state_dict()
```
This did not happen with 32-true precision. Do you have an idea of what might be going wrong?
JimenezBarreroDavid commented
After making this change, I was able to save the models and avoid the error:

```python
from torch.distributed.fsdp import MixedPrecision

mixed_precision = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True, mixed_precision=mixed_precision)
```
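For context, a minimal sketch of how this strategy could be wired into the training setup. This is a hedged example, not the issue author's exact code: the `lightning.fabric` import path, the `lit_llama.model.Block` import, and the `Fabric` arguments are assumptions based on typical lit-llama usage.

```python
import functools

import torch
import lightning as L
from lightning.fabric.strategies import FSDPStrategy
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from lit_llama.model import Block  # assumption: the transformer block class used by lit-llama

# Wrap each transformer Block in its own FSDP unit, mirroring the script's policy.
auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})

# Passing an explicit MixedPrecision config (instead of relying on precision="16-mixed")
# is the workaround described above.
mixed_precision = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)

strategy = FSDPStrategy(
    auto_wrap_policy=auto_wrap_policy,
    activation_checkpointing=Block,
    limit_all_gathers=True,
    mixed_precision=mixed_precision,
)

fabric = L.Fabric(devices="auto", strategy=strategy)
fabric.launch()
```

Because FSDP only activates under distributed launch, this fragment is configuration rather than something that runs standalone on a single process.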