EleutherAI / gpt-neox

An implementation of model parallel autoregressive transformers on GPUs, based on the Megatron and DeepSpeed libraries

Home Page: https://www.eleuther.ai/

Converting Pythia checkpoint from HF to NeoX fails

malteos opened this issue

Describe the bug

Converting a Pythia checkpoint from HF to NeoX fails with a missing-key error for the rotary embedding buffer (attention.rotary_emb.inv_freq).

To Reproduce
Steps to reproduce the behavior:

I am running this command to convert the Pythia 410M checkpoint to NeoX (for continued pretraining):

OMPI_COMM_WORLD_RANK=0 CUDA_VISIBLE_DEVICES=0 python $NEOX_DIR/tools/ckpts/convert_hf_to_sequential.py \
    --hf-model-name pythia-410m \
    --revision 143000 \
    --output-dir $BASE_DIR/data/pythia-410m/neox_converted_checkpoints/ \
    --cache-dir $TRANSFORMERS_CACHE \
    --config $BASE_DIR/neox_configs/continued-pythia-410m_pegasus.yml \
    --test

Error trace:

Traceback (most recent call last):
  File "/netscratch/experiments/gpt-neox/tools/ckpts/convert_hf_to_sequential.py", line 581, in <module>
    load_checkpoint(
  File "/netscratch/experiments/gpt-neox/megatron/checkpointing.py", line 390, in load_checkpoint
    checkpoint_name, state_dict = model.load_checkpoint(
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 2599, in load_checkpoint
    load_path, client_states = self._load_checkpoint(load_dir,
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 2662, in _load_checkpoint
    self.load_module_state_dict(checkpoint=checkpoint,
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/engine.py", line 1274, in load_module_state_dict
    self.module.load_state_dir(load_dir=self._curr_ckpt_path,
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/module.py", line 598, in load_state_dir
    layer.load_state_dict(checkpoint)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1667, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ParallelTransformerLayerPipe:
        Missing key(s) in state_dict: "attention.rotary_emb.inv_freq".

Expected behavior

Conversion to NeoX completes without any errors.

Proposed solution

From my understanding, attention.rotary_emb.inv_freq is not a trainable parameter and thus should not need to be loaded from the state dict.
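
For illustration, here is a minimal PyTorch sketch (not the actual NeoX code) of why the error appears: a buffer registered with the default persistent=True is expected in the state dict, so strict loading fails when the converted checkpoint omits it, while persistent=False sidesteps the check.

import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self, persistent=True):
        super().__init__()
        # A persistent buffer is included in state_dict(); a non-persistent
        # one is simply recomputed at construction time.
        self.register_buffer("inv_freq", torch.ones(4), persistent=persistent)

empty_ckpt = {}  # stands in for the converted HF checkpoint, which omits inv_freq

try:
    WithBuffer(persistent=True).load_state_dict(empty_ckpt)
except RuntimeError as err:
    print(err)  # Missing key(s) in state_dict: "inv_freq".

WithBuffer(persistent=False).load_state_dict(empty_ckpt)  # loads cleanly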

Environment (please complete the following information):

Thanks for your amazing project!

Hi! You can get around this for now by adding persistent=False to the register_buffer("inv_freq", ...) calls in the NeoX library.
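
For reference, a hedged sketch of what that change looks like in a rotary embedding module; the real class in GPT-NeoX (megatron/model/positional_embeddings.py in recent versions) is larger, and the exact code may differ between releases.

import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        # Standard RoPE inverse-frequency table.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # persistent=False keeps inv_freq out of state_dict(), so the
        # converted HF checkpoint no longer needs to supply the key.
        self.register_buffer("inv_freq", inv_freq, persistent=False)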

What's your Hugging Face Transformers version? It seems the culprit is this change, huggingface/transformers@253f9a3, which made inv_freq non-persistent on the HF side. I was under the impression they had reverted this change, but it seems I was wrong about that.

Will probably update this buffer to non-persistent in GPT-NeoX, but will need to check that this does not break others' existing checkpoints.
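
For context, a hedged sketch of the compatibility concern: if the buffer becomes non-persistent, an older NeoX checkpoint that still contains inv_freq would hit an unexpected-key error under strict loading, so the stale key would need to be dropped (or loading relaxed). The names below are illustrative only, not the actual fix.

import torch
import torch.nn as nn

class NewRotary(nn.Module):
    def __init__(self, dim=8, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

# An older checkpoint would still carry the buffer (under some prefix);
# the key name here is illustrative.
old_ckpt = {"inv_freq": torch.ones(4)}

module = NewRotary()
# Strict loading would report inv_freq as an unexpected key, so drop stale
# rotary buffers before loading (or pass strict=False).
filtered = {k: v for k, v in old_ckpt.items() if not k.endswith("inv_freq")}
module.load_state_dict(filtered)  # loads cleanly; inv_freq stays as computed in __init__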

Thanks for the quick response. Adding persistent=False to the register_buffer calls fixed the problem!

Reopening this to track it since we haven't merged a fix yet!