Resuming training fails
hidoba opened this issue
I removed the --overwrite_output_dir flag so that training would resume from the last checkpoint, and I'm getting the following error:
```
04/01/2024 00:30:01 - INFO - __main__ - max_steps is given, it will override any value given in num_train_epochs
04/01/2024 00:30:04 - INFO - __main__ - ***** Running training *****
04/01/2024 00:30:04 - INFO - __main__ - Num examples = 4800000
04/01/2024 00:30:04 - INFO - __main__ - Instantaneous batch size per device = 8
04/01/2024 00:30:04 - INFO - __main__ - Gradient accumulation steps = 1
04/01/2024 00:30:04 - INFO - __main__ - Total train batch size (w. parallel & distributed) = 8
04/01/2024 00:30:04 - INFO - __main__ - Total optimization steps = 600000
Train steps ... : 0%| | 0/600000 [00:00<?, ?it/s]04/01/2024 00:30:04 - INFO - accelerate.accelerator - Loading states from ./checkpoint-5000-epoch-0
Traceback (most recent call last):
  File "/home/vlad/distil-whisper/training/run_distillation.py", line 1682, in <module>
    main()
  File "/home/vlad/distil-whisper/training/run_distillation.py", line 1484, in main
    accelerator.load_state(checkpoint)
  File "/home/vlad/distil-whisper/.venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 2966, in load_state
    load_accelerator_state(
  File "/home/vlad/distil-whisper/.venv/lib/python3.10/site-packages/accelerate/checkpointing.py", line 205, in load_accelerator_state
    models[i].load_state_dict(state_dict, **load_model_func_kwargs)
  File "/home/vlad/distil-whisper/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for WhisperForConditionalGeneration:
	Missing key(s) in state_dict: "proj_out.weight".
```
At the same time, the evaluation script works just fine with the same checkpoint.
I'm on Ubuntu 22 with an RTX 3090 Ti.
I've also observed this warning in the log:
```
04/01/2024 00:35:47 - WARNING - accelerate.utils.other - Removed shared tensor {'proj_out.weight'} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading
```
Any updates on this? I'm facing the same problem.
Here's a temporary fix, following https://huggingface.co/docs/safetensors/torch_shared_tensors.
Modify load_accelerator_state() (https://github.com/huggingface/accelerate/blob/main/src/accelerate/checkpointing.py#L153):
```diff
-from safetensors.torch import load_file
+from safetensors.torch import load_model
 ...
         if input_model_file.exists():
-            state_dict = load_file(input_model_file, device=str(map_location))
+            load_model(models[i], input_model_file, device=str(map_location), **load_model_func_kwargs)
         else:
             # Load with torch
             input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
             state_dict = torch.load(input_model_file, map_location=map_location)
-        models[i].load_state_dict(state_dict, **load_model_func_kwargs)
+            models[i].load_state_dict(state_dict, **load_model_func_kwargs)
```

Note that the final load_state_dict call also has to move inside the else branch, so that the safetensors path goes only through load_model.