intel-analytics / ipex-llm

Accelerate local LLM inference and finetuning (LLaMA, Mistral, ChatGLM, Qwen, Baichuan, Mixtral, Gemma, Phi, etc.) on Intel CPU and GPU (e.g., local PC with iGPU, discrete GPU such as Arc, Flex and Max); seamlessly integrate with llama.cpp, Ollama, HuggingFace, LangChain, LlamaIndex, DeepSpeed, vLLM, FastChat, Axolotl, etc.

Shape Mismatch with Checkpoint for Deepspeed Zero3

Uxito-Ada opened this issue

Tried to load the model and distribute it across devices layer by layer, using the DeepSpeed ZeRO-3 context manager as below:

with ds.zero.Init(config_dict_or_path=deepspeed):
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        config=model_config,
        torch_dtype=torch.bfloat16,
        ignore_mismatched_sizes=True,
        cpu_embedding=False
    )

Then, the following error occurred when loading:

Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /mnt/disk1/models/Llama-2-7b-hf and are newly initialized because the shapes did not match:
- model.embed_tokens.weight: found shape torch.Size([32000, 4096]) in the checkpoint and torch.Size([0]) in the model instantiated
- model.layers.0.input_layernorm.weight: found shape torch.Size([4096]) in the checkpoint and torch.Size([0]) in the model instantiated
- model.layers.0.mlp.down_proj.weight: found shape torch.Size([4096, 11008]) in the checkpoint and torch.Size([0]) in the model instantiated
- model.layers.0.mlp.gate_proj.weight: found shape torch.Size([11008, 4096]) in the checkpoint and torch.Size([0]) in the model instantiated
- model.layers.0.mlp.up_proj.weight: found shape torch.Size([11008, 4096]) in the checkpoint and torch.Size([0]) in the model instantiated
- model.layers.0.post_attention_layernorm.weight: found shape torch.Size([4096]) in the checkpoint and torch.Size([0]) in the model instantiated
......
       You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

After setting ignore_mismatched_sizes=True as suggested, native Hugging Face transformers works well for QLoRA training. By contrast, the BigDL (ipex-llm) model still reports an error during the forward pass; profiling shows that base_layer is still empty with shape [0] and therefore computes an empty result as well:

File "/home/arda/heyang/ipex-llm/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/./alpaca_qlora_zero3_finetuning.py", line 253, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/transformers/trainer.py", line 1854, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/transformers/trainer.py", line 2723, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/transformers/trainer.py", line 2746, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 1822, in forward
    loss = self.module(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/peft/peft_model.py", line 1129, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1174, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/ipex_llm/transformers/models/llama.py", line 123, in llama_model_forward_4_36
    return llama_model_forward_4_36_internal(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/ipex_llm/transformers/models/llama.py", line 1851, in llama_model_forward_4_36_internal
    layer_outputs = self._gradient_checkpointing_func(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 230, in forward
    outputs = run_function(*args)
              ^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/ipex_llm/transformers/models/llama.py", line 254, in llama_decoder_forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/ipex_llm/transformers/models/llama.py", line 938, in llama_attention_forward_4_36
    return forward_function(
           ^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/ipex_llm/transformers/models/llama.py", line 1269, in llama_attention_forward_4_36_original
    query_states = self.q_proj(hidden_states)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/ipex_llm/transformers/qlora.py", line 121, in forward
    result = self.base_layer.forward(x)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/heyang-zero/lib/python3.11/site-packages/ipex_llm/transformers/low_bit_linear.py", line 711, in forward
    result = result.view(new_shape)
             ^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[8, 256, 4096]' is invalid for input of size 0
  0%|          | 0/1164 [00:02<?, ?it/s]
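
The final RuntimeError is consistent with the base layer's weight still being an empty ZeRO-3 placeholder: the q_proj output has 0 elements, so it cannot be reshaped to (batch, seq_len, hidden_size). A minimal sketch that reproduces just the reshape failure (illustration only, not the actual ipex-llm code path):

import torch

# The projection output is empty because the partitioned weight holds 0 elements,
# so reshaping to the expected (8, 256, 4096) fails with the same message as above.
empty_result = torch.empty(0)
empty_result.view(8, 256, 4096)  # RuntimeError: shape '[8, 256, 4096]' is invalid for input of size 0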

To reproduce the issue, you can use the code below in LLM-finetuning/QLoRA/alpaca:

code
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/tloen/alpaca-lora/blob/main/finetune.py
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List
os.environ["ACCELERATE_USE_XPU"] = "true"
import fire
import torch
from datasets import load_dataset
import accelerate
import transformers

from transformers import AutoTokenizer, BitsAndBytesConfig, AutoConfig
from ipex_llm.transformers import AutoModelForCausalLM
from peft import (
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)

current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..', '..')
import sys
sys.path.append(common_util_path)
from common.utils import Prompter, get_int_from_env, wandb_check, get_train_val_data

from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
    LoraConfig
from ipex_llm.utils.common import invalidInputError
import deepspeed as ds
 
local_rank = get_int_from_env(["LOCAL_RANK","MPI_LOCALRANKID"], "0")
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
port = get_int_from_env(["MASTER_PORT"], 29500)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(local_rank)
os.environ["MASTER_PORT"] = str(port)

def train(
    # model/data params
    base_model: str = "meta-llama/Llama-2-7b-hf",  # the only required argument, default to be "meta-llama/Llama-2-7b-hf"
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "./ipex-deepspeed-zero3-qlora-alpaca",
    # training hyperparams
    bf16: bool = True,  # default to bf16
    batch_size: int = 128,
    micro_batch_size: int = 2,  # default to be 2, limited by GPU memory
    num_epochs: int = 3,
    learning_rate: float = 3e-5,  # default to be 3e-5 to avoid divergence
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "up_proj",
        "down_proj",
        "gate_proj"
    ],  # according to the QLoRA paper (https://arxiv.org/pdf/2305.14314.pdf), it's suggested to fine tune all linear layers
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
    gradient_checkpointing: bool = False,
    deepspeed: str = None,
    training_mode: str = "qlora",
):
    invalidInputError(training_mode == "qlora",
                      f"This example is for qlora training mode, but got training_mode={training_mode}.")
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
            f"training_mode: {training_mode}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    prompter = Prompter(prompt_template_name)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    # Check if parameter passed or if set within environ
    use_wandb = wandb_check(wandb_project, wandb_watch, wandb_log_model)

    
    model_config = AutoConfig.from_pretrained(base_model)
    with ds.zero.Init(config_dict_or_path=deepspeed):
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            config=model_config,
            torch_dtype=torch.bfloat16,
            load_in_low_bit="nf4",
            ignore_mismatched_sizes=True,
        )
        
    def move_to_xpu(module):
        # Move any module weights still on CPU to the local XPU device
        if hasattr(module, "weight") and module.weight.data.device.type == "cpu":
            module.weight.data = module.weight.data.to(f"xpu:{local_rank}")

    from transformers import LlamaTokenizer
    tokenizer = LlamaTokenizer.from_pretrained(base_model, trust_remote_code=True)
    print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")

    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    tokenizer.padding_side = "left"  # Allow batched inference

    print(model)

    # Prepare an IPEX-LLM compatible PEFT model
    model = prepare_model_for_kbit_training(model,
                                            use_gradient_checkpointing=gradient_checkpointing)
                                            #enable_deepspeed_zero3=True)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        training_mode=training_mode,
    )
    print(f"Lora Config: {config}")
    model = get_peft_model(model, config) #enable_deepspeed_zero3=True)
    
    model.apply(move_to_xpu)

    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

    train_data, val_data = get_train_val_data(data, tokenizer, prompter, train_on_inputs,
                                              add_eos_token, cutoff_len, val_set_size, seed=42)

    def print_device(module):
        if hasattr(module, "weight"):
            print(module.weight.data.device)

    model.apply(print_device)
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # warmup_ratio=0.03,
            # warmup_steps=100,
            max_grad_norm=0.3,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            lr_scheduler_type="cosine",
            bf16=True,  # ensure training more stable
            logging_steps=1,
            optim="adamw_torch",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=100 if val_set_size > 0 else None,
            save_steps=100,
            output_dir=output_dir,
            save_total_limit=100,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
            gradient_checkpointing=gradient_checkpointing,
            ddp_backend="ccl",
            deepspeed=deepspeed,
            save_safetensors=False,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


if __name__ == "__main__":
    fire.Fire(train)
deepspeed_config
{
    "zero_optimization": {
      "stage": 3,
      "contiguous_gradients": true,
      "overlap_comm": true,
      "offload_optimizer": {"device": "cpu"},
      "offload_param": {"device": "cpu"}
    },
    "bf16": {
      "enabled": true
    },
    "world_size":2,
    "train_batch_size": 128,
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 8
}
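
As a sanity check on this config (added for clarity, not part of the original report): DeepSpeed expects train_batch_size to equal train_micro_batch_size_per_gpu × gradient_accumulation_steps × number of ranks, which holds for the values above assuming 2 ranks:

# Assumes 2 data-parallel ranks, matching the "world_size" field above.
train_micro_batch_size_per_gpu = 8
gradient_accumulation_steps = 8
world_size = 2
assert train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size == 128  # train_batch_size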

@leonardozcm @glorysdj

This can be fixed by disabling DeepSpeed's post_init hook and the hardcoded qtype; WIP in #11048.
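
Until that lands, a hedged diagnostic sketch (not from this report; GatheredParameters is a public DeepSpeed API, while ds_numel is an internal attribute that may differ across versions) to check whether a module's weights are still ZeRO-3 placeholders after loading:

import deepspeed

def report_partitioned_weights(model, limit=5):
    # Under ZeRO-3, partitioned parameters show numel() == 0 locally;
    # ds_numel (if present) is the full, unpartitioned element count.
    for i, (name, param) in enumerate(model.named_parameters()):
        if i >= limit:
            break
        print(name, "local numel:", param.numel(), "ds_numel:", getattr(param, "ds_numel", None))

# Temporarily gather one parameter to confirm the checkpoint weights were actually loaded.
first_param = next(model.parameters())
with deepspeed.zero.GatheredParameters([first_param], modifier_rank=None):
    print("gathered shape:", tuple(first_param.shape))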