This issue actually still persists. My Python environment:
python 3.10
accelerate 0.25.0
bitsandbytes 0.41.3.post2
black 23.7.0
datasets 2.14.7
flash-attn 2.4.2
huggingface-hub 0.17.3
lion-pytorch 0.1.2
networkx 3.1
numpy 1.26.3
pandas 2.1.4
pip 23.3.1
safetensors 0.4.1
tokenizers 0.14.1
torch 2.1.2
transformers 4.34.1
triton 2.1.0
My training setup is single-node multi-GPU with FSDP. The FSDP config is:
fsdp_config:
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
mixed_precision: bf16
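For context, a minimal sketch of the script side under such an Accelerate config; the model, optimizer, filenames, and launch command below are placeholders, not my actual training code:

```python
# Sketch only: assumes the YAML above is saved as fsdp_config.yaml and the
# script is started with
#   accelerate launch --config_file fsdp_config.yaml train.py
# FSDP sharding and bf16 mixed precision come from that config, not from code.
import torch
from accelerate import Accelerator

model = torch.nn.Linear(16, 16)                    # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters())  # stand-in optimizer

accelerator = Accelerator()                        # reads fsdp_config / mixed_precision
model, optimizer = accelerator.prepare(model, optimizer)
```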
That is, FULL_SHARD FSDP with bf16 AMP. In this setup, use_triton=True causes problems when resuming training from a checkpoint: the gradient scale explodes along with the loss. If I train with use_triton=False, save, and then resume, there is no problem.
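For reference, a minimal sketch of the setup being described, assuming use_triton refers to the lion-pytorch Lion optimizer flag (which matches the environment above); the model, learning rate, and checkpoint path are placeholders:

```python
# Sketch of the failing path: Lion with use_triton=True under FSDP + bf16,
# saved and then resumed via Accelerate's checkpointing helpers.
import torch
from accelerate import Accelerator
from lion_pytorch import Lion

model = torch.nn.Linear(16, 16)              # placeholder model
optimizer = Lion(model.parameters(), lr=1e-4,
                 weight_decay=1e-2,
                 use_triton=True)            # with use_triton=False the resume works fine

accelerator = Accelerator()                  # FSDP + bf16 come from the launch config
model, optimizer = accelerator.prepare(model, optimizer)

accelerator.save_state("ckpt")               # save mid-training
accelerator.load_state("ckpt")               # resume: grad scale / loss explode with use_triton=True
```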
Originally posted by @syncdoth in #20 (comment)