huggingface / alignment-handbook

Robust recipes to align language models with human and AI preferences

Home Page: https://huggingface.co/HuggingFaceH4


Cannot flatten integer dtype tensors

jaywongs opened this issue

Thank you guys for your work!

I was using FSDP + QLoRA to fine-tune Llama 3 70B on 8x A100 80GB GPUs, and I encountered this error:

Traceback (most recent call last):
  File "/mnt/209180/qishi/project/alignment-handbook/scripts/run_sft.py", line 233, in <module>
    main()
  File "/mnt/209180/qishi/project/alignment-handbook/scripts/run_sft.py", line 188, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 361, in train
    output = super().train(*args, **kwargs)
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 2002, in _inner_training_loop
    self.model = self.accelerator.prepare(self.model)
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/accelerate/accelerator.py", line 1292, in prepare
    result = tuple(
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/accelerate/accelerator.py", line 1293, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/accelerate/accelerator.py", line 1169, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/accelerate/accelerator.py", line 1459, in prepare_model
    model = FSDP(model, **kwargs)
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 463, in __init__
    _auto_wrap(
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  [Previous line repeated 2 more times]
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 555, in _recursive_wrap
    return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 484, in _wrap
    return wrapper_cls(module, **kwargs)
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 487, in __init__
    _init_param_handle_from_module(
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 519, in _init_param_handle_from_module
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 531, in _init_param_handle_from_params
    handle = FlatParamHandle(
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 537, in __init__
    self._init_flat_param_and_metadata(
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 585, in _init_flat_param_and_metadata
    ) = self._validate_tensors_to_flatten(params)
  File "/root/anaconda3/envs/handbook/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 720, in _validate_tensors_to_flatten
    raise ValueError("Cannot flatten integer dtype tensors")
ValueError: Cannot flatten integer dtype tensors
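The last frames show FSDP's `FlatParamHandle` rejecting the parameters of a wrapped module. As far as I understand it, FSDP only flattens floating-point tensors, while bitsandbytes 4-bit layers store their packed weights as `uint8` by default, which would trip exactly this check. The sketch below is a simplified, hypothetical re-creation of that validation in plain Python (dtypes are given as strings, not real tensors, and the error text adds the parameter name); it is not PyTorch's actual code.

```python
# Hypothetical, simplified model of the dtype check performed in
# torch.distributed.fsdp.flat_param (not the real implementation).
FLOAT_DTYPES = {"float16", "bfloat16", "float32", "float64"}


def validate_tensors_to_flatten(param_dtypes):
    """Raise like FSDP does if any parameter dtype is not floating point.

    param_dtypes maps a parameter name to its dtype name, e.g. "uint8".
    """
    for name, dtype in param_dtypes.items():
        if dtype not in FLOAT_DTYPES:
            raise ValueError(
                f"Cannot flatten integer dtype tensors ({name}: {dtype})"
            )
    return True


# A 4-bit layer whose packed weight uses the default uint8 storage...
default_storage = {"q_proj.weight": "uint8", "input_layernorm.weight": "bfloat16"}
# ...versus the same layer with the packed weight stored as bfloat16.
bf16_storage = {"q_proj.weight": "bfloat16", "input_layernorm.weight": "bfloat16"}

try:
    validate_tensors_to_flatten(default_storage)
except ValueError as err:
    print(err)

print(validate_tensors_to_flatten(bf16_storage))
```

The point of the toy check is that the failure depends only on the storage dtype of the quantized weights, not on their shape or on the LoRA adapters.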

My config:

# Model arguments
model_name_or_path: ./models/Meta-Llama-3-70B
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true

# LoRA arguments
load_in_4bit: true
use_peft: true
lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj

# Data training arguments
chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
dataset_mixer:
  HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
- train_sft
- test_sft
preprocessing_num_workers: 12

# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: steps
eval_steps: 300
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: 
learning_rate: 2.0e-04
log_level: info
logging_steps: 5  
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
max_steps: -1
num_train_epochs: 1
output_dir: ./output/Meta-Llama-3-70B-sft
overwrite_output_dir: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 1
push_to_hub: false
report_to:
- tensorboard
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1
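This config turns on 4-bit loading (`load_in_4bit: true`) with `torch_dtype: bfloat16`. The Hugging Face FSDP + QLoRA guides I am aware of also set `bnb_4bit_quant_storage` on `transformers.BitsAndBytesConfig` (available in recent transformers and bitsandbytes releases, including the versions listed below) to the same float dtype as the model, so that the packed weights can be flattened. The sketch below shows that mapping from the YAML values; the helper `build_bnb_kwargs` is hypothetical, and whether this version of the recipe reads a `bnb_4bit_quant_storage` key from the YAML is an assumption.

```python
# Hypothetical helper: turns the recipe's YAML values into keyword arguments
# for transformers.BitsAndBytesConfig. Only plain Python runs here.
def build_bnb_kwargs(recipe):
    if not recipe.get("load_in_4bit"):
        return None
    dtype = recipe.get("torch_dtype", "bfloat16")
    return {
        "load_in_4bit": True,
        "bnb_4bit_compute_dtype": dtype,
        # Assumption: store packed 4-bit weights in the model's float dtype
        # so FSDP can flatten them (the bitsandbytes default is uint8).
        "bnb_4bit_quant_storage": recipe.get("bnb_4bit_quant_storage", dtype),
    }


# Values taken from the config above.
recipe = {"load_in_4bit": True, "torch_dtype": "bfloat16", "use_peft": True}
print(build_bnb_kwargs(recipe))
```

In real code the resulting dict would be passed as `BitsAndBytesConfig(**kwargs)`; keeping the storage dtype equal to `torch_dtype` is a design choice so that the quantized and non-quantized parameters in one FSDP unit share a dtype.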

My pip list:

Package                 Version
----------------------- -----------
absl-py                 2.1.0
accelerate              0.30.1
aiohttp                 3.9.5
aiosignal               1.3.1
alignment-handbook      0.4.0.dev0
annotated-types         0.6.0
async-timeout           4.0.3
attrs                   23.2.0
bitsandbytes            0.43.1
Brotli                  1.0.9
certifi                 2024.2.2
charset-normalizer      2.0.4
datasets                2.19.1
deepspeed               0.12.2
dill                    0.3.8
docstring_parser        0.16
einops                  0.8.0
evaluate                0.4.0
filelock                3.13.1
flash-attn              2.5.8
frozenlist              1.4.1
fsspec                  2024.3.1
gmpy2                   2.1.2
grpcio                  1.63.0
hf_transfer             0.1.6
hjson                   3.1.0
huggingface-hub         0.23.0
idna                    3.7
Jinja2                  3.1.3
Markdown                3.6
markdown-it-py          3.0.0
MarkupSafe              2.1.3
mdurl                   0.1.2
mkl-fft                 1.3.8
mkl-random              1.2.4
mkl-service             2.4.0
mpmath                  1.3.0
multidict               6.0.5
multiprocess            0.70.16
networkx                3.1
ninja                   1.11.1.1
numpy                   1.26.4
packaging               24.0
pandas                  2.2.2
peft                    0.11.1.dev0
pillow                  10.3.0
pip                     24.0
protobuf                3.20.2
psutil                  5.9.8
py-cpuinfo              9.0.0
pyarrow                 16.1.0
pyarrow-hotfix          0.6
pydantic                2.7.1
pydantic_core           2.18.2
Pygments                2.18.0
pynvml                  11.5.0
PySocks                 1.7.1
python-dateutil         2.9.0.post0
pytz                    2024.1
PyYAML                  6.0.1
regex                   2024.5.15
requests                2.31.0
responses               0.18.0
rich                    13.7.1
safetensors             0.4.3
scipy                   1.13.0
sentencepiece           0.2.0
setuptools              69.5.1
shtab                   1.7.1
six                     1.16.0
sympy                   1.12
tensorboard             2.16.2
tensorboard-data-server 0.7.2
tokenizers              0.19.1
torch                   2.1.2
torchaudio              2.1.2
torchvision             0.16.2
tqdm                    4.66.4
transformers            4.40.2
triton                  2.1.0
trl                     0.8.6
typing_extensions       4.11.0
tyro                    0.8.4
tzdata                  2024.1
urllib3                 2.2.1
Werkzeug                3.0.3
wheel                   0.43.0
xxhash                  3.4.1
yarl                    1.9.4
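To rule the environment in or out, the installed versions can be compared against the minimums that, as I recall, the PEFT documentation's FSDP + QLoRA guide asks for: bitsandbytes>=0.43.0, accelerate>=0.28.0, transformers>4.38.2, trl>0.7.11 and peft>0.9.0. These numbers are quoted from memory and should be verified. The sketch uses a simple numeric parser that ignores pre-release and local tags such as `.dev0`.

```python
import re

# Minimums as recalled from the PEFT FSDP + QLoRA guide (assumption: verify
# against the current docs). Each entry is (operator, version).
REQUIREMENTS = {
    "bitsandbytes": (">=", "0.43.0"),
    "accelerate": (">=", "0.28.0"),
    "transformers": (">", "4.38.2"),
    "trl": (">", "0.7.11"),
    "peft": (">", "0.9.0"),
}

# Installed versions copied from the pip list above.
INSTALLED = {
    "bitsandbytes": "0.43.1",
    "accelerate": "0.30.1",
    "transformers": "4.40.2",
    "trl": "0.8.6",
    "peft": "0.11.1.dev0",
}


def release_tuple(version):
    """Keep the leading numeric release, e.g. '0.11.1.dev0' -> (0, 11, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    return tuple(int(part) for part in match.group().split("."))


def compare(a, b):
    """Return -1, 0 or 1, padding the shorter release with zeros."""
    ta, tb = release_tuple(a), release_tuple(b)
    width = max(len(ta), len(tb))
    ta += (0,) * (width - len(ta))
    tb += (0,) * (width - len(tb))
    return (ta > tb) - (ta < tb)


def check(installed, requirements):
    """List every package that is missing or below its minimum."""
    problems = []
    for name, (op, minimum) in requirements.items():
        have = installed.get(name)
        if have is None:
            problems.append(f"{name}: not installed")
            continue
        result = compare(have, minimum)
        ok = result >= 0 if op == ">=" else result > 0
        if not ok:
            problems.append(f"{name} {have} does not satisfy {op}{minimum}")
    return problems


print(check(INSTALLED, REQUIREMENTS) or "all listed versions satisfy the minimums")
# prints: all listed versions satisfy the minimums
```

If every check passes, as it does for the versions above, that suggests the error comes from the quantized weights' storage dtype rather than from an outdated package.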

+1