[BUG] FP8+PP+Recompute+GA>1, loss = nan
jingjie01ai opened this issue
jingjie01ai commented
Describe the bug
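With FP8 training, the loss becomes NaN only when pipeline parallelism (PP), activation recomputation, and gradient accumulation (GA) steps > 1 are all enabled together. Removing any one of them, or replacing PP with tensor parallelism (TP), gives a normal loss:
FP8 + PP + recompute + GA > 1: loss is NaN
FP8 + PP + GA > 1 (no recompute): loss is normal
FP8 + PP + recompute + GA = 1: loss is normal
FP8 + TP + recompute + GA > 1: loss is normal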
jingjie01ai commented
See also: NVIDIA/TransformerEngine#539