Loss spikes and explodes during MPT-1B model pretraining
sagorbrur opened this issue
Hi,
We are training an MPT-1B model on 230 GB of bilingual (Bangla + English) data.
We kept the default config, changing only the batch size, and started training on an 8×H100 machine.
The loss was decreasing, but after 42k iterations it spiked and started exploding (see attachment).
We tried several mitigations:
- Stopping and resuming training
- Reducing the LR and increasing warmup from 100 to 1000 steps
- Changing weight decay from 0.0 to 0.1
- Loading only the model weights (not the optimizer state) and restarting with a small LR
None of the above mitigated the exploding loss.
If anyone has suggestions for fixing this, please let me know.
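For reference, the tweaks above correspond roughly to the following fields in an LLM Foundry-style training YAML. This is an illustrative sketch, not our exact config: the field names follow the MPT example configs, and the values shown are just the changes described above.

```yaml
# Illustrative fragment of an LLM Foundry-style train config.
# Values reflect the mitigation attempts described above, not the defaults.
optimizer:
  name: decoupled_adamw
  lr: 1.0e-4          # reduced from the original LR as one mitigation attempt
  weight_decay: 0.1   # changed from 0.0

scheduler:
  name: cosine_with_warmup
  t_warmup: 1000ba    # increased from 100ba
```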