horseee / LLM-Pruner

[NeurIPS 2023] LLM-Pruner: On the Structural Pruning of Large Language Models. Support LLaMA, Llama-2, BLOOM, Vicuna, Baichuan, etc.

Home Page: https://arxiv.org/abs/2305.11627

Use LLM-Pruner for Baichuan model

Daisy5296 opened this issue

Hi, I am trying to use LLM-Pruner on the Baichuan-13B model (https://github.com/baichuan-inc/Baichuan-13B). It is also LLaMA-structured, so I thought it would work out of the box, but I got some errors... I am still trying to debug, but slowly... Any help or advice would be greatly appreciated!

Specifically, I ran "CUDA_VISIBLE_DEVICES=0,1 python hf_prune_baichuan.py --base_model models/baichuan-13b-chat --pruning_ratio 0.25 --device cpu --eval_device cuda --block_wise --block_mlp_layer_start 4 --block_mlp_layer_end 30 --block_attention_layer_start 4 --block_attention_layer_end 30 --save_ckpt_log_name baichuan_13b_chat_0.2 --pruner_type taylor --test_after_train --taylor param_first --save_model",

and I got the following output:
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [02:14<00:00, 44.74s/it]
2023-07-17 02:29:09 - INFO : Use taylor pruner...
2023-07-17 02:29:09 - INFO : Pruning Attention Layer = [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
2023-07-17 02:29:09 - INFO : Pruning MLP Layer = [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
/dfs/data/LLM-Pruner/LLMPruner/torch_pruning/dependency.py:362: UserWarning: Unwrapped parameters detected: ['model.layers.10.input_layernorm.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.34.input_layernorm.weight', 'model.layers.22.input_layernorm.weight', 'model.layers.29.post_attention_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.31.input_layernorm.weight', 'model.layers.38.post_attention_layernorm.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.26.input_layernorm.weight', 'model.layers.33.post_attention_layernorm.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.21.post_attention_layernorm.weight', 'model.layers.35.input_layernorm.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.23.input_layernorm.weight', 'model.layers.30.post_attention_layernorm.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.18.input_layernorm.weight', 'model.layers.25.post_attention_layernorm.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.32.input_layernorm.weight', 'model.layers.39.post_attention_layernorm.weight', 'model.layers.34.post_attention_layernorm.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.27.input_layernorm.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.22.post_attention_layernorm.weight', 'model.layers.36.input_layernorm.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.24.input_layernorm.weight', 'model.layers.31.post_attention_layernorm.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.19.input_layernorm.weight', 'model.layers.26.post_attention_layernorm.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.28.input_layernorm.weight', 'model.layers.35.post_attention_layernorm.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.23.post_attention_layernorm.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.18.post_attention_layernorm.weight', 'model.layers.37.input_layernorm.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.25.input_layernorm.weight', 'model.layers.32.post_attention_layernorm.weight', 'model.norm.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.20.input_layernorm.weight', 'model.layers.27.post_attention_layernorm.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.29.input_layernorm.weight', 'model.layers.36.post_attention_layernorm.weight', 'model.layers.17.input_layernorm.weight', 'model.layers.24.post_attention_layernorm.weight', 'model.layers.38.input_layernorm.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.19.post_attention_layernorm.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.21.input_layernorm.weight', 'model.layers.28.post_attention_layernorm.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.33.input_layernorm.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.30.input_layernorm.weight', 
'model.layers.37.post_attention_layernorm.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.20.post_attention_layernorm.weight', 'model.layers.39.input_layernorm.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.8.post_attention_layernorm.weight'].
Torch-Pruning will prune the last non-singleton dimension of a parameter. If you wish to customize this behavior, please provide an unwrapped_parameters argument.
warnings.warn("Unwrapped parameters detected: {}.\n Torch-Pruning will prune the last non-singleton dimension of a parameter. If you wish to customize this behavior, please provide an unwrapped_parameters argument.".format([_param_to_name[p] for p in unwrapped_detected]))
2023-07-17 02:30:02 - INFO : Start Pruning
2023-07-17 02:30:02 - WARNING : Found cached dataset bookcorpus (/dfs/data/data/bookcorpus/bookcorpus/plain_text/1.0.0/eddee3cae1cc263a431aa98207d4d27fd8a73b0a9742f692af0e6c65afa4d75f)
2023-07-17 02:30:45 - INFO : Start Backwarding in iterative steps = 0...
2023-07-17 02:33:56 - INFO : Loss = 3.644896984100342
Traceback (most recent call last):
File "hf_prune_baichuan.py", line 299, in
main(args)
File "hf_prune_baichuan.py", line 136, in main
pruner.step()
File "/dfs/data/LLM-Pruner/LLMPruner/torch_pruning/pruner/algorithms/metapruner.py", line 179, in step
for group in self.prune_local():
File "/dfs/data/LLM-Pruner/LLMPruner/torch_pruning/pruner/algorithms/metapruner.py", line 238, in prune_local
imp = self.estimate_importance(group, ch_groups=ch_groups, consecutive_groups=consecutive_groups)
File "/dfs/data/LLM-Pruner/LLMPruner/torch_pruning/pruner/algorithms/metapruner.py", line 183, in estimate_importance
return self.importance(group, ch_groups=ch_groups, consecutive_groups=consecutive_groups)
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
return func(*args, **kwargs)
File "/dfs/data/LLM-Pruner/LLMPruner/pruner/hf_baichuan_pruner.py", line 306, in call
local_norm = local_norm[idxs]
IndexError: index 10240 is out of bounds for dimension 0 with size 5120
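
For context, 10240 is exactly 2 × 5120 (the hidden size of Baichuan-13B), so my guess is that the offending indices come from the fused W_pack output dimension (3 × hidden_size) but are being applied to a tensor whose first dimension is only hidden_size. A rough sanity check of the shapes involved (the numbers are just read off the error message and the Baichuan-13B config):

hidden_size = 5120                 # Baichuan-13B hidden size
w_pack_out = 3 * hidden_size       # fused Q/K/V projection output dim = 15360
bad_index = 10240                  # the failing index == 2 * hidden_size
# an index drawn from the fused W_pack range cannot address a tensor of length hidden_size
assert bad_index >= hidden_size    # holds, hence the IndexError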

I modified "hf_prune_llama.py" and "LLMPruner/pruner/hf_llama_pruner.py":
1、replacing the model loading part as:
tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)
model.generation_config = GenerationConfig.from_pretrained(args.base_model)
2、replacing all "q_proj, k_proj, v_proj" with "W_pack"
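
For reference, Baichuan fuses the query/key/value projections into a single linear layer. A minimal sketch of how W_pack relates to LLaMA's separate projections (the shapes are my assumption based on Baichuan-13B's hidden size of 5120):

import torch.nn as nn

hidden_size = 5120
# Baichuan replaces LLaMA's q_proj / k_proj / v_proj with one fused projection
W_pack = nn.Linear(hidden_size, 3 * hidden_size, bias=False)

# conceptually, the fused weight splits back into the three LLaMA-style projections
q_w, k_w, v_w = W_pack.weight.split(hidden_size, dim=0)   # each of shape (5120, 5120)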

Do you have any advice on a quick fix? Thank you very much!

To resolve the bug, a simple solution is to include the hf_rmsnorm_pruner in the customized_pruners kwargs for both input_layernorm and post_attention_layernorm.

I reviewed the structure of the Baichuan model and found some potential issues that could break the dependency graph. I am currently investigating them and will get back to you with feedback ASAP. If you come across any parts of the code that are not functioning as expected, please feel free to let me know.

Thanks a lot! I'll update my progress here as well.

I modified the customized_pruners kwargs as follows:
kwargs = {
    "importance": imp,
    "global_pruning": args.global_pruning,
    "iterative_steps": args.iterative_steps,
    "ch_sparsity": args.pruning_ratio,
    "ignored_layers": [],
    "channel_groups": {},
    "consecutive_groups": {
        layer.self_attn.W_pack: layer.self_attn.head_dim for layer in model.model.layers
    },
    "customized_pruners": {
        RMSNorm: baichuan_pruner.hf_rmsnorm_pruner,
    },
    "root_module_types": None,
    "root_instances": [model.model.layers[i].self_attn.W_pack for i in range(args.block_attention_layer_start, args.block_attention_layer_end)] +
                      [model.model.layers[i].mlp.gate_proj for i in range(args.block_mlp_layer_start, args.block_mlp_layer_end)],
}

in which RMSNorm is imported from the official Baichuan model code:
from LLMPruner.models.hf_baichuan.modeling_baichuan import RMSNorm
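
The pruner itself is then built the same way as in the LLaMA script, roughly like this (paraphrasing my script; forward_prompts stands for the tokenized calibration batch):

import LLMPruner.torch_pruning as tp

pruner = tp.pruner.MetaPruner(model, forward_prompts, **kwargs)
model.zero_grad()
# for the Taylor pruner, a backward pass over the calibration loss is run first,
# then each step() prunes one iterative step according to the estimated importance
pruner.step()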

But the error was still the same as before.

I found that the idxs for o_proj were calculated incorrectly, as can be seen here:
[screenshot: the incorrectly calculated idxs for o_proj]

But I am not sure which part of the code is wrong currently...

Sorry for interrupting... but I am kind of stuck here, and I am really keen to get this working... Thanks again!
This "idxs" is calculated somewhere in torch_pruning/dependency.py, right? Or am I even going in the right direction with my debugging?

Hi! I've been tackling the LLM-Pruner bug on Baichuan, and it's been quite a complex issue related to the dependency graph. I've managed to solve the problem, and now I'm just double-checking the output.

I'll be releasing the code no later than this afternoon, and maybe even sooner. However, I'm not sure which dataset/benchmark to use to validate the model's performance, so that part is missing from the code.

That's really exciting news! Thanks for the rapid response!

Hi! The code is released.

If you come across any issues or have any experimental results to share, we would greatly appreciate your contribution to our repository🥳

Note: If you're using the Taylor pruner, you can manually select the calibration data at Line 132. The current code still uses English samples, which may not yield optimal results.
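
For example, something roughly like this (just an illustration, not the exact code at that line) to calibrate on Chinese text instead:

# illustrative only: build the calibration prompts from a Chinese passage instead of
# the default English bookcorpus samples
text = "人工智能正在改变我们的生活方式。大语言模型的结构化剪枝可以显著降低推理成本。"
example_prompts = tokenizer(text, return_tensors="pt").input_ids[:, :64].to(args.device)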

Thanks! I've tried out the new version of the code, but unfortunately a similar error still exists:

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [02:03<00:00, 41.21s/it]
Some weights of BaiChuanForCausalLM were not initialized from the model checkpoint at models/baichuan-13b-chat and are newly initialized: ['model.layers.23.self_attn.rotary_emb.inv_freq', 'model.layers.12.self_attn.rotary_emb.inv_freq', 'model.layers.36.self_attn.rotary_emb.inv_freq', 'model.layers.27.self_attn.rotary_emb.inv_freq', 'model.layers.20.self_attn.rotary_emb.inv_freq', 'model.layers.5.self_attn.rotary_emb.inv_freq', 'model.layers.10.self_attn.rotary_emb.inv_freq', 'model.layers.25.self_attn.rotary_emb.inv_freq', 'model.layers.39.self_attn.rotary_emb.inv_freq', 'model.layers.24.self_attn.rotary_emb.inv_freq', 'model.layers.30.self_attn.rotary_emb.inv_freq', 'model.layers.6.self_attn.rotary_emb.inv_freq', 'model.layers.18.self_attn.rotary_emb.inv_freq', 'model.layers.7.self_attn.rotary_emb.inv_freq', 'model.layers.26.self_attn.rotary_emb.inv_freq', 'model.layers.16.self_attn.rotary_emb.inv_freq', 'model.layers.22.self_attn.rotary_emb.inv_freq', 'model.layers.34.self_attn.rotary_emb.inv_freq', 'model.layers.33.self_attn.rotary_emb.inv_freq', 'model.layers.28.self_attn.rotary_emb.inv_freq', 'model.layers.4.self_attn.rotary_emb.inv_freq', 'model.layers.29.self_attn.rotary_emb.inv_freq', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.35.self_attn.rotary_emb.inv_freq', 'model.layers.31.self_attn.rotary_emb.inv_freq', 'model.layers.1.self_attn.rotary_emb.inv_freq', 'model.layers.38.self_attn.rotary_emb.inv_freq', 'model.layers.17.self_attn.rotary_emb.inv_freq', 'model.layers.2.self_attn.rotary_emb.inv_freq', 'model.layers.3.self_attn.rotary_emb.inv_freq', 'model.layers.15.self_attn.rotary_emb.inv_freq', 'model.layers.32.self_attn.rotary_emb.inv_freq', 'model.layers.19.self_attn.rotary_emb.inv_freq', 'model.layers.11.self_attn.rotary_emb.inv_freq', 'model.layers.9.self_attn.rotary_emb.inv_freq', 'model.layers.37.self_attn.rotary_emb.inv_freq', 'model.layers.13.self_attn.rotary_emb.inv_freq', 'model.layers.21.self_attn.rotary_emb.inv_freq', 'model.layers.8.self_attn.rotary_emb.inv_freq', 'model.layers.14.self_attn.rotary_emb.inv_freq']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2023-07-18 06:14:01 - INFO : Use taylor pruner...
2023-07-18 06:14:01 - INFO : Pruning Attention Layer = [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
2023-07-18 06:14:01 - INFO : Pruning MLP Layer = [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
2023-07-18 06:14:39 - INFO : Start Pruning
2023-07-18 06:14:39 - WARNING : Found cached dataset bookcorpus (/dfs/data/data/bookcorpus/bookcorpus/plain_text/1.0.0/eddee3cae1cc263a431aa98207d4d27fd8a73b0a9742f692af0e6c65afa4d75f)
2023-07-18 06:14:45 - INFO : Start Backwarding in iterative steps = 0...
2023-07-18 06:17:19 - INFO : Loss = 4.370472431182861
Traceback (most recent call last):
File "baichuan.py", line 254, in
main(args)
File "baichuan.py", line 155, in main
pruner.step()
File "/dfs/data/LLM-Pruner/LLMPruner/torch_pruning/pruner/algorithms/metapruner.py", line 186, in step
for group in self.prune_local():
File "/dfs/data/LLM-Pruner/LLMPruner/torch_pruning/pruner/algorithms/metapruner.py", line 245, in prune_local
imp = self.estimate_importance(group, ch_groups=ch_groups, consecutive_groups=consecutive_groups)
File "/dfs/data/LLM-Pruner/LLMPruner/torch_pruning/pruner/algorithms/metapruner.py", line 190, in estimate_importance
return self.importance(group, ch_groups=ch_groups, consecutive_groups=consecutive_groups)
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
return func(*args, **kwargs)
File "/dfs/data/LLM-Pruner/LLMPruner/pruner/hf_baichuan_pruner.py", line 181, in call
local_norm = local_norm[idxs]
IndexError: index 10240 is out of bounds for dimension 0 with size 5120

The only difference is that I use the Baichuan-13B-chat model. Could that make a difference?

Can you check whether your code uses the model file in LLMPruner/models/hf_baichuan/modeling_baichuan.py or downloads a model file from the Hugging Face Hub?
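
For example, something like this will print the source file of the model class actually in use:

import inspect
print(type(model))                   # e.g. BaiChuanForCausalLM, from LLMPruner or from the hub cache
print(inspect.getfile(type(model)))  # path of the modeling_baichuan.py being used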

It uses the file in LLMPruner/models/hf_baichuan/modeling_baichuan.py

Can you provide the versions of PyTorch and huggingface/transformers?

Sure.
torch 1.11.0a0+bfe5ad2
huggingface-hub 0.14.1
transformers 4.30.2

Hi. Can you try Baichuan-7B?
I tested the code on Baichuan-7B and it works well. I'm not sure if the bug is caused by a different PyTorch version or by implementation differences between Baichuan-7B and Baichuan-13B.

Sure, I'll try it out and let you know soon

Would you like to join our WeChat or Telegram group? Our communication can be more efficient in that way.

Hi. The pruning & post-training code for Baichuan has been updated 😆
The pruning code now supports the latest version of Baichuan-13B-chat, and the bug RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn in post_training.py has been fixed. Please refer to the instructions here.