h2oai / h2o-llmstudio

H2O LLM Studio - a framework and no-code GUI for fine-tuning LLMs. Documentation: https://h2oai.github.io/h2o-llmstudio/

Home Page: https://gpt-gm.h2o.ai


[BUG] Mixed precision not working with bfloat16

maxjeblick opened this issue · comments

πŸ› Bug

The GradScaler used for mixed precision is not compatible with bfloat16 (see the logs below).

To Reproduce

  • Start a default experiment, but set dtype=bfloat16
  • The experiment runs fine when mixed precision is disabled

2024-03-07 11:04:34,276 - INFO: Evaluation step: 1105
2024-03-07 11:04:34,816 - INFO: Stop token ids: [tensor([ 529, 29989, 5205, 29989, 29958]), tensor([ 529, 29989, 14032, 415, 29989, 29958]), tensor([ 529, 29989, 12011, 29989, 29958])]
2024-03-07 11:04:36,333 - ERROR: Exception occurred during H2O LLM Studio run:
Traceback (most recent call last):
  File "/data/maxjeblick/PyCharmProjects/h2o-llmstudio/train_wave.py", line 106, in
    run(cfg=cfg)
  File "/data/maxjeblick/PyCharmProjects/h2o-llmstudio/train.py", line 634, in run
    val_loss, val_metric = run_train(
  File "/data/maxjeblick/PyCharmProjects/h2o-llmstudio/train.py", line 296, in run_train
    scaler.step(optimizer)  # type: ignore
  File "/home/maxjeblick/.local/share/virtualenvs/h2o-llmstudio-07dpqO7E/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 446, in step
    self.unscale_(optimizer)
  File "/home/maxjeblick/.local/share/virtualenvs/h2o-llmstudio-07dpqO7E/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 336, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
  File "/home/maxjeblick/.local/share/virtualenvs/h2o-llmstudio-07dpqO7E/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 277, in _unscale_grads_
    torch._amp_foreach_non_finite_check_and_unscale_(
RuntimeError: "_amp_foreach_non_finite_check_and_unscale_cuda" not implemented for 'BFloat16'
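
For reference, a minimal standalone sketch of the same incompatibility. This is an assumption of what the experiment effectively does when dtype=bfloat16 and mixed precision are combined: the model weights, and therefore the gradients, are bfloat16, which GradScaler cannot unscale on the PyTorch version shown in the traceback. The tiny model, data, and loss are placeholders.

import torch

# Hypothetical minimal repro, not LLM Studio code.
model = torch.nn.Linear(8, 1).cuda().to(torch.bfloat16)  # bfloat16 weights -> bfloat16 grads
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()                     # enabled by "mixed precision"

x = torch.randn(4, 8, device="cuda", dtype=torch.bfloat16)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(x).float().mean()

scaler.scale(loss).backward()
scaler.step(optimizer)  # fails in unscale_: "_amp_foreach_non_finite_check_and_unscale_cuda" not implemented for 'BFloat16'
scaler.update()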


LLM Studio version

5a7d2d4

I tried to research this a bit but am not sure how to solve it.

The scaler is not needed for bfloat16; we can just disable it there.
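
A sketch of that workaround (not the actual patch; the helper name and signature are illustrative). Because bfloat16 has the same exponent range as float32, loss scaling is unnecessary, so the scaler can simply be constructed with enabled=False. A disabled GradScaler turns scale(), step(), and update() into pass-throughs, so the surrounding training loop stays unchanged.

import torch
from torch.cuda.amp import GradScaler

def get_scaler(mixed_precision: bool, dtype: torch.dtype) -> GradScaler:
    # Hypothetical helper: keep loss scaling for float16, skip it for bfloat16.
    return GradScaler(enabled=mixed_precision and dtype is not torch.bfloat16)

# The training loop can stay as-is:
#   scaler.scale(loss).backward()
#   scaler.step(optimizer)   # with enabled=False this just calls optimizer.step()
#   scaler.update()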