Issues using FP8 for MPT baselines on H100
prigoyal opened this issue
Hello,
I am trying to train MPT models using FP8 and am currently hitting issues similar to what was reported in #271.
The changes I made are: installing flash-attn and TransformerEngine:
pip install flash-attn==1.0.7 --no-build-isolation
pip install git+https://github.com/NVIDIA/TransformerEngine.git@v0.10
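For reference, the installed versions can be confirmed with a quick check like this (a minimal sketch; the __version__ attribute may be missing in some builds, hence the getattr fallback):
# Quick sanity check of the installed packages (sketch only).
import flash_attn
import transformer_engine

print("flash-attn:", getattr(flash_attn, "__version__", "unknown"))
print("TransformerEngine:", getattr(transformer_engine, "__version__", "unknown"))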
and then making the following changes to our config files (following the tutorials):
precision: amp_fp8
model:
  fc_type: te
  ffn_config_defaults:
    ffn_type: te_ln_mlp
The llm-foundry version I am using is 0.3.0.
I would appreciate it if you could share any insights into what might be missing from our setup to use FP8 successfully. cc @growlix
It seems that your TransformerEngine version is outdated.
Try:
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
If that doesn't work, try building from source:
git clone https://github.com/NVIDIA/TransformerEngine.git
cd TransformerEngine
pip install -e .
cd ..
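Once TransformerEngine is reinstalled, it may also be worth confirming that the environment actually sees FP8-capable hardware. A minimal sketch (the check_fp8_support helper may not exist in older TE releases, hence the try/except):
import torch

major, minor = torch.cuda.get_device_capability()
print(f"GPU compute capability: {major}.{minor}")  # H100 reports 9.0
print("FP8-capable hardware:", (major, minor) >= (8, 9))  # FP8 needs compute capability >= 8.9 (Ada/Hopper)

try:
    # Newer TE releases expose a helper that also reports why FP8 is unavailable.
    from transformer_engine.pytorch.fp8 import check_fp8_support
    print(check_fp8_support())
except ImportError:
    pass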
Thanks @j316chuck, will give that a shot. It might be very helpful to update the README.md as well, in case others run into the same issue.