Cannot flatten integer dtype tensors
jaywongs opened this issue
Thank you for your work!
I was fine-tuning Llama 3 70B with FSDP + QLoRA on 8× A100 80GB, and I ran into this error:
Traceback (most recent call last):
File "/mnt/209180/qishi/project/alignment-handbook/scripts/run_sft.py", line 233, in <module>
main()
File "/mnt/209180/qishi/project/alignment-handbook/scripts/run_sft.py", line 188, in main
train_result = trainer.train(resume_from_checkpoint=checkpoint)
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 361, in train
output = super().train(*args, **kwargs)
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
return inner_training_loop(
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 2002, in _inner_training_loop
self.model = self.accelerator.prepare(self.model)
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/accelerate/accelerator.py", line 1292, in prepare
result = tuple(
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/accelerate/accelerator.py", line 1293, in <genexpr>
self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/accelerate/accelerator.py", line 1169, in _prepare_one
return self.prepare_model(obj, device_placement=device_placement)
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/accelerate/accelerator.py", line 1459, in prepare_model
model = FSDP(model, **kwargs)
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 463, in __init__
_auto_wrap(
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
_recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
[Previous line repeated 2 more times]
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 555, in _recursive_wrap
return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 484, in _wrap
return wrapper_cls(module, **kwargs)
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 487, in __init__
_init_param_handle_from_module(
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 519, in _init_param_handle_from_module
_init_param_handle_from_params(state, managed_params, fully_sharded_module)
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 531, in _init_param_handle_from_params
handle = FlatParamHandle(
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 537, in __init__
self._init_flat_param_and_metadata(
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 585, in _init_flat_param_and_metadata
) = self._validate_tensors_to_flatten(params)
File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 720, in _validate_tensors_to_flatten
raise ValueError("Cannot flatten integer dtype tensors")
ValueError: Cannot flatten integer dtype tensors
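From the traceback, FSDP's FlatParamHandle refuses to flatten any parameter with an integer dtype, and with bitsandbytes 4-bit quantization the packed weights are stored as uint8 by default, so FSDP cannot wrap the quantized modules. If I understand the FSDP + QLoRA support in transformers >= 4.39 correctly, the packed weights need to be stored in a floating-point dtype via `bnb_4bit_quant_storage`. A minimal sketch of what I think the model loading has to look like (paths and dtypes taken from my config below, not the exact run_sft.py code path):

```python
# Minimal sketch, assuming BitsAndBytesConfig from transformers >= 4.39: store the
# packed 4-bit weights in bfloat16 (a dtype FSDP can flatten) instead of the
# default uint8 storage.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    # Key setting for FSDP + QLoRA: keep the packed weights in a float dtype so
    # FlatParamHandle does not raise "Cannot flatten integer dtype tensors".
    bnb_4bit_quant_storage=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "./models/Meta-Llama-3-70B",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,            # matches torch_dtype in my config
    attn_implementation="flash_attention_2",
)
```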
My config:
# Model arguments
model_name_or_path: ./models/Meta-Llama-3-70B
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true
# LoRA arguments
load_in_4bit: true
use_peft: true
lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
# Data training arguments
chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
dataset_mixer:
  HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
- train_sft
- test_sft
preprocessing_num_workers: 12
# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: steps
eval_steps: 300
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id:
learning_rate: 2.0e-04
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
max_steps: -1
num_train_epochs: 1
output_dir: ./output/Meta-Llama-3-70B-sft
overwrite_output_dir: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 1
push_to_hub: false
report_to:
- tensorboard
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1
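Since the config above sets load_in_4bit: true but nothing about the quant storage dtype, I suspect the packed weights stay in uint8. As a quick sanity check (my own debugging snippet, not part of alignment-handbook), this lists the parameters with an integer storage dtype after loading the quantized model, i.e. exactly the tensors the FSDP flattening rejects:

```python
# Hypothetical helper: report every parameter whose storage dtype is not
# floating point; any hit here is what FSDP will refuse to flatten.
def find_integer_params(model):
    bad = []
    for name, param in model.named_parameters():
        if not param.dtype.is_floating_point:
            bad.append((name, param.dtype))
    return bad

# Usage (after loading the 4-bit model, before accelerator.prepare(model)):
# for name, dtype in find_integer_params(model):
#     print(name, dtype)
```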
My pip list:
Package Version
----------------------- -----------
absl-py 2.1.0
accelerate 0.30.1
aiohttp 3.9.5
aiosignal 1.3.1
alignment-handbook 0.4.0.dev0
annotated-types 0.6.0
async-timeout 4.0.3
attrs 23.2.0
bitsandbytes 0.43.1
Brotli 1.0.9
certifi 2024.2.2
charset-normalizer 2.0.4
datasets 2.19.1
deepspeed 0.12.2
dill 0.3.8
docstring_parser 0.16
einops 0.8.0
evaluate 0.4.0
filelock 3.13.1
flash-attn 2.5.8
frozenlist 1.4.1
fsspec 2024.3.1
gmpy2 2.1.2
grpcio 1.63.0
hf_transfer 0.1.6
hjson 3.1.0
huggingface-hub 0.23.0
idna 3.7
Jinja2 3.1.3
Markdown 3.6
markdown-it-py 3.0.0
MarkupSafe 2.1.3
mdurl 0.1.2
mkl-fft 1.3.8
mkl-random 1.2.4
mkl-service 2.4.0
mpmath 1.3.0
multidict 6.0.5
multiprocess 0.70.16
networkx 3.1
ninja 1.11.1.1
numpy 1.26.4
packaging 24.0
pandas 2.2.2
peft 0.11.1.dev0
pillow 10.3.0
pip 24.0
protobuf 3.20.2
psutil 5.9.8
py-cpuinfo 9.0.0
pyarrow 16.1.0
pyarrow-hotfix 0.6
pydantic 2.7.1
pydantic_core 2.18.2
Pygments 2.18.0
pynvml 11.5.0
PySocks 1.7.1
python-dateutil 2.9.0.post0
pytz 2024.1
PyYAML 6.0.1
regex 2024.5.15
requests 2.31.0
responses 0.18.0
rich 13.7.1
safetensors 0.4.3
scipy 1.13.0
sentencepiece 0.2.0
setuptools 69.5.1
shtab 1.7.1
six 1.16.0
sympy 1.12
tensorboard 2.16.2
tensorboard-data-server 0.7.2
tokenizers 0.19.1
torch 2.1.2
torchaudio 2.1.2
torchvision 0.16.2
tqdm 4.66.4
transformers 4.40.2
triton 2.1.0
trl 0.8.6
typing_extensions 4.11.0
tyro 0.8.4
tzdata 2024.1
urllib3 2.2.1
Werkzeug 3.0.3
wheel 0.43.0
xxhash 3.4.1
yarl 1.9.4
+1