microsoft / Megatron-DeepSpeed

Ongoing research training transformer language models at scale, including: BERT & GPT-2


Loss curve changes after checkpoint load

okoge-kaz opened this issue

Overview

I am training a 175B model using Megatron-DeepSpeed.
The loss changes when a checkpoint is loaded and training is resumed.

Training settings

Training script

175B: https://github.com/llm-jp/Megatron-DeepSpeed/blob/feature/fujii/abci/scripts/abci/175B/49node/dp2_tp4_pp49_zero1_flash_attn_rottary_bf16_100K.sh
175B ds_config: https://github.com/llm-jp/Megatron-DeepSpeed/blob/feature/fujii/abci/examples_deepspeed/rebase/ds_config_gpt_TEMPLATE_bf16.json
deepspeed version: deepspeed 0.11.1

1.3B: https://github.com/llm-jp/Megatron-DeepSpeed/blob/feature/fujii/abci/scripts/abci/1.3B/dp16_tp1_pp1_zero1.sh
1.3B ds_config: https://github.com/llm-jp/Megatron-DeepSpeed/blob/feature/fujii/abci/examples_deepspeed/rebase/ds_config_gpt_TEMPLATE_bf16.json
deepspeed version: deepspeed 0.11.1

Cluster Info

175B: A100 (40GB) x 8 GPUs x 49 nodes
1.3B: A100 (40GB) x 8 GPUs x 2 nodes

Details

1.3B

wandb project: https://wandb.ai/okoge/megatron-deepspeed-loss-spike/workspace

experiments:

At first glance, it looks fine.
[screenshot]

But when enlarged, the loss curves are misaligned.
[screenshots]

175B

This experiment's wandb log is private, sorry.
But the situation is as shown in the screenshots below.

[screenshots]

Supplementary Info

In the 1.3B training, I confirmed that using Rotary Positional Embeddings and Flash-Attention causes the loss curve to change.

Without RoPE and Flash Attention

wandb runs:

[screenshots]

With RoPE and Flash Attention

wandb runs

[screenshot]

Related Issue: #248

[screenshot]

deepspeed version: 0.11.1

@okoge-kaz Thank you for the detailed report. We reproduced the mismatch of loss values after loading a checkpoint. After running small-scale experiments, we found that some kernels do not reproduce exactly the same results. Could you try the following to confirm whether you are seeing the same issue? I observed that the loss values matched exactly after these changes.

  • Add --no-bias-dropout-fusion and --no-bias-gelu-fusion (what the unfused fallback computes is sketched below)
  • Remove --use-flash-attn
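
For context on what these flags change: Megatron's fused path JIT-scripts a bias + dropout + residual-add helper, while --no-bias-dropout-fusion falls back to the plain eager version. Roughly (my own paraphrase, not the exact Megatron source):

```python
import torch


def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor,
                     prob: float, training: bool) -> torch.Tensor:
    # Eager fallback used with --no-bias-dropout-fusion: residual + dropout(x + bias)
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + out


# The fused path scripts the same computation so it can run as a single fused kernel.
bias_dropout_add_fused = torch.jit.script(bias_dropout_add)
```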

@okoge-kaz Just for clarification, can you let us know the version of Megatron-DeepSpeed you used?
I didn't find --pp-partition-method used in the training script in the current version.


In our project, the vocab size is large (about 100K), so we have added a process that treats the embedding layer and the lm-head as transformer blocks, based on the implementation in the BLOOM project.

This is the implementation in BLOOM
https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/main/megatron/model/gpt_model.py#L310-L319
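
For reference, a minimal toy sketch of the idea (my own illustration, not the llm-jp or BLOOM code): the embedding and lm-head are wrapped in classes whose names match the partition regex, so DeepSpeed's type:transformer|embedding partition method counts them like transformer blocks when balancing pipeline stages. The class and function names below are made up; it assumes a deepspeed launch (e.g. deepspeed --num_gpus 2 script.py) and only illustrates the partitioning, not a functional lm-head.

```python
import torch.nn as nn
import deepspeed
from deepspeed.pipe import PipelineModule, LayerSpec


class EmbeddingPipe(nn.Embedding):
    """Name contains 'embedding', so it is counted by the type: regex below."""


class TransformerLayerPipe(nn.TransformerEncoderLayer):
    """Name contains 'transformer', so it is counted as well."""


def build_model(vocab_size=100_000, hidden=1024, heads=16, n_layers=24, num_stages=2):
    specs = (
        [LayerSpec(EmbeddingPipe, vocab_size, hidden)]
        + [LayerSpec(TransformerLayerPipe, hidden, heads) for _ in range(n_layers)]
        + [LayerSpec(EmbeddingPipe, vocab_size, hidden)]  # stand-in for the lm-head
    )
    # With a ~100K vocab the embedding/lm-head layers are very heavy, so counting
    # them like transformer blocks gives a more balanced pipeline partition.
    return PipelineModule(layers=specs,
                          num_stages=num_stages,
                          partition_method="type:transformer|embedding")


if __name__ == "__main__":
    deepspeed.init_distributed()  # run under the `deepspeed` launcher
    model = build_model()
```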


@tohtana
Thank you very much for your reply!
I will try it as soon as I have some computing resources available.

By the way, is it possible to avoid the checkpoint issue while still using flash-attention?

As additional information, I also saw the loss curve change after a checkpoint load when using Megatron-LM for continual training of Llama-2. I hope this information is useful for debugging.

[screenshot]

@tohtana
Sorry for the late reply.
After adding --no-bias-dropout-fusion and --no-bias-gelu-fusion and removing --use-flash-attn, the loss matches exactly!

[screenshot]

Thank you very much!
By the way, is there any way to use flash-attention and RoPE at the same time and still have the loss match exactly?

This is a very important issue for us as we would like to use Megatron-DeepSpeed when we create our next model.

Thanks.

@okoge-kaz Let me share some related information.

  • The backward pass of FlashAttention is non-deterministic, meaning we cannot reproduce exactly the same results as long as we use FlashAttention (a quick standalone check is sketched right after this list).
  • According to its README, FlashAttention produces larger numerical errors than the PyTorch reference implementation. This could potentially impact training stability, especially in very large models.
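
To check whether this matters on a given setup, a quick standalone test could look like the sketch below (my own example, assuming flash-attn 2.x and a CUDA GPU): it runs the same forward/backward twice with identical inputs and reports whether the input gradients are bitwise identical.

```python
import torch
from flash_attn import flash_attn_func


def backward_grads(seed: int = 0):
    """Run one FlashAttention forward/backward and return the input gradients."""
    torch.manual_seed(seed)
    # (batch, seqlen, nheads, headdim), bf16 on GPU as required by flash-attn.
    q, k, v = (torch.randn(2, 512, 16, 64, device="cuda", dtype=torch.bfloat16,
                           requires_grad=True) for _ in range(3))
    out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
    out.sum().backward()
    return q.grad.clone(), k.grad.clone(), v.grad.clone()


if __name__ == "__main__":
    grads_a = backward_grads()
    grads_b = backward_grads()  # identical seed, hence identical inputs
    for name, a, b in zip("qkv", grads_a, grads_b):
        print(f"d{name} bitwise identical across runs: {torch.equal(a, b)}")
```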

Unfortunately, there's no definitive solution to these issues yet, but I do have a couple of suggestions:

  • If you've only experimented with FlashAttention V1, consider trying out V2 or the Triton version. This document provides guidance on using the Triton version with Megatron-DeepSpeed.
  • The primary cause of the training instability is still unclear. Given that the instability was observed after loading a checkpoint, I recommend adding --no-bias-dropout-fusion while keeping FlashAttention enabled. This will still behave non-deterministically, but the stability might change if the issue is related to the fused dropout.

@tohtana Thank you for your detailed investigation of the loss instability issue.

We will have the opportunity to train the 175B model in early December (but only for a few hours), so I will use the --no-bias-dropout-fusion argument you suggested to check the loss instability.

Also, I was using --pp-partition-method, which is not among the original argument options of Megatron-DeepSpeed, so I will also investigate the behavior without it.

I will keep the issue open to share findings here.

Again, thanks for your detailed investigation.

Let me share the results of my experiments.

1.3B DP=2, PP=2, TP=2 ZeRO Stage1 (A100 40GB x 8)
Flash Attention version 2.3.0, DeepSpeed version 0.11.1

  1. without --pp-partition-method, with --no-bias-dropout-fusion
    script: llm-jp/Megatron-DeepSpeed@d6648e9#diff-511554af37bf0b364e3a9cb65bac8e8ea4499385c3827510a29b531c74111fba

[screenshots]

Observing about 100 iterations after the checkpoint load, the loss is consistently higher after the load. Therefore, non-determinism does not seem to be the only cause.

  2. without --pp-partition-method, without --no-bias-dropout-fusion
    script: llm-jp/Megatron-DeepSpeed@d6648e9#diff-e61698a0b9307e2d7c1db383d4720d4d9b1bdc2d7cefb1bc7ed7ec4617edb53b

[screenshots]

Same as above.

I will also run experiments with the DeepSpeed version brought up to date, and with flash-attention version 1.x, to observe what happens.

(all logging data: https://wandb.ai/okoge/megatron-deepspeed-loss-spike/workspace?workspace=user-okoge)

1.3B DP=2, PP=2, TP=2 ZeRO Stage1 (A100 40GB x 8)
Flash Attention version 2.3.3, DeepSpeed version 0.12.3
PyTorch 2.1.0+cu118

  1. without --pp-partition-method, with --no-bias-dropout-fusion
    [screenshots]

wandb run1: https://wandb.ai/okoge/megatron-deepspeed-loss-spike/runs/mn07xosl
wandb run2: https://wandb.ai/okoge/megatron-deepspeed-loss-spike/runs/wgmtemzr
wandb run3: https://wandb.ai/okoge/megatron-deepspeed-loss-spike/runs/ngs5uu7j

  2. without --pp-partition-method, without --no-bias-dropout-fusion
    [screenshots]

wandb run1: https://wandb.ai/okoge/megatron-deepspeed-loss-spike/runs/k6z91rag
wandb run2: https://wandb.ai/okoge/megatron-deepspeed-loss-spike/runs/jo52aqd7
wandb run3: https://wandb.ai/okoge/megatron-deepspeed-loss-spike/runs/4zf8454c

I changed the DeepSpeed version to 0.12.3, but the loss is still high after a checkpoint load.

1.3B DP=2, PP=2, TP=2 ZeRO Stage1 (A100 40GB x 8)
Flash Attention version 2.3.3, DeepSpeed version 0.12.3
PyTorch 2.1.0+cu118

  1. with --no-bias-dropout-fusion and with --no-bias-gelu-fusion
    [screenshots]

wandb run1: https://wandb.ai/okoge/megatron-deepspeed-loss-spike/runs/g9whmwxb
wandb run2: https://wandb.ai/okoge/megatron-deepspeed-loss-spike/runs/1rzvz4s4

1.3B DP=2, PP=2, TP=2 ZeRO Stage1 (A100 40GB x 8)
Flash Attention version 1.0.4, DeepSpeed version 0.12.3
PyTorch 2.1.0+cu118

with --no-bias-dropout-fusion and with --no-bias-gelu-fusion

[screenshots]

I tried changing the version of FlashAttention, but that did not solve the issue.

In Megatron-LM, this phenomenon does not occur (only much smaller loss-curve gaps appear) when flash-attention is enabled. Why is that?

wandb: run1 https://wandb.ai/okoge/megatron-deepspeed-loss-spike/runs/7tat5r6d
wandb: run2 https://wandb.ai/okoge/megatron-deepspeed-loss-spike/runs/pdtwiloa

There was a bug related to setting the dropout probability for FlashAttention.
It has been addressed in this PR.
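
For context: flash_attn_func applies attention dropout itself via its dropout_p argument, so the calling module has to pass the configured probability through consistently and zero it when not training. A generic sketch of that pattern (a simplified toy, not the actual Megatron-DeepSpeed module or the exact fix):

```python
import torch
from flash_attn import flash_attn_func


class SimpleFlashSelfAttention(torch.nn.Module):
    """Toy wrapper showing how the attention dropout probability is plumbed through."""

    def __init__(self, attention_dropout: float = 0.1, causal: bool = True):
        super().__init__()
        self.attention_dropout = attention_dropout
        self.causal = causal

    def forward(self, q, k, v):
        # Apply attention dropout only in training mode; in eval/resume paths
        # the probability must be zero so the attention output is not rescaled.
        dropout_p = self.attention_dropout if self.training else 0.0
        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=self.causal)
```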

The loss no longer shifts significantly after a checkpoint load, even with flash-attn enabled.
Thank you very much! Very helpful!

[screenshots]

Thanks, @tohtana, for investigating the checkpoint-load bug so thoroughly.