huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.

Home Page: https://huggingface.co/transformers



Mistral loss instability

teknium1 opened this issue

System Info

Hello, I've been working with dhokas, who fine-tuned Mistral's official instruct model. I have been trying to fine-tune Mistral on several datasets over dozens of ablations. Training this model with transformers shows severe loss instability that never seems to appear in his training runs, which do not use the HF Trainer.

I am opening this so we can get to the bottom of this. Here are some of my runs using axolotl with some datasets.

With hermes 2.0 dataset (unpublished):
https://wandb.ai/teknium1/hermes2.0-mistral-7b?workspace=user-teknium1

With Teknium/GPT4-LLM-CLEANED dataset
https://wandb.ai/teknium1/gpt4llm-mistral-7b

With a 5-sequences run to ensure loss goes to 0 (that memorization is occurring):
https://wandb.ai/teknium1/5seq-mistral-7b?workspace=user-teknium1

With OpenHermes dataset teknium1/openhermes:
https://wandb.ai/teknium1/hermes-mistral-7b

As can be seen, the loss charts across all these ablations are unreliable, and they generally produce bad results no matter which hyperparameters are changed.

The Mistral dev who worked with me trained Mistral on GPT4-LLM-CLEANED and got this result:
image

@younesbelkada @muellerz

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Train Mistral on any of the above datasets with Mistral's own fine-tuning hyperparameters (as reported in Mistral's Discord) and observe that the loss fails to decrease as expected.

Expected behavior

A smooth or downward trajectory for the loss.

I have tried:
  • learning rates of 2e-5, 1e-5, 8e-6, 6e-6, and 4e-6
  • with and without flash attention / xformers / neither
  • with and without packing
  • weight decay of 0.1 and 0.01
  • long, medium, and short warmups (warmup steps between 0.01% and 80% of total steps)
  • the Hermes 2.0, Hermes 1.0 (which has been used to fine-tune Llama successfully on several occasions), and GPT4LLM datasets
  • FSDP, and DeepSpeed ZeRO-2 and ZeRO-3
  • with and without group_by_length
  • updated Adam beta and epsilon (adam_beta2: 0.95, adam_epsilon: 0.00001)
  • with and without max_grad_norm: 1.0

I've basically run out of hyperparameters to tune, and I've run several of these in fresh venvs.
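`max_grad_norm` clips gradients by their global L2 norm before each optimizer step. Below is a minimal pure-Python sketch of that rule. Plain float lists stand in for per-parameter gradient tensors. PyTorch's `torch.nn.utils.clip_grad_norm_` works on real tensors but applies the same scaling idea.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float) -> List[List[float]]:
    """Scale all gradients by one shared factor so their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Only ever scale down; the small epsilon guards against division by zero.
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [[g * scale for g in grad] for grad in grads]


# Two hypothetical "parameters" whose gradients have a joint norm of 5.
clipped = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
print(clipped)
```

Every gradient is scaled by the same factor, so clipping changes the step size but not the update direction.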

I have also come across an issue involving an irregular loss curve for finetuning mistral 7b.
unusual_loss

For reference some of my loss charts:
image
image
image

I am facing the same issue: the loss goes up while fine-tuning on the Dolly-15k dataset.

Same for me with the garage-bAInd/Open-Platypus Dataset. Though mine was extremely weird
image

Continued pre-training on a Chinese/Mandarin corpus:
IMG_7827

Optimizer adamw
lr: 2.5e-5
Warmup: 4%
Bs 2
Seq Len 1024
Used flash attention from the PR

Are you using any specific library for continued pre-training?

I am using SFTTrainer from trl. Note that both runs failed: the orange one cannot converge, and the green one dropped to loss=0.0 but in fact the model produced garbage.

image
Same with fine tuning. The output is pure garbage even with all the standard hyperparams I used for fine tuning llama.

@teknium1 the GPT4-LLM-CLEANED and 5-sequence wandb links above both 404 😞

Sorry, my projects default to private; I've made them public.

How did you load your model?

with transformers? or do you mean precision?

I was just wondering if you used one of the HuggingFace AutoModel classes or if you loaded it using the Mistral reference implementation.

MistralForCausalLM

I see. I guess one idea to sanity check could be to load the model using the reference implementation and ensure it behaves similarly to the HuggingFace version.

Do you mean outside of Hugging Face / the HF Trainer? The Mistral dev did do this: we get totally different training results when he trains on the same dataset with the same hyperparams without the HF Trainer.

Yeah, I mean just making sure both models behave similarly for a single forward/backward pass on the same data, without the Trainer. If they are the same, my guess is that it narrows the problem down to the Trainer.

Indeed, they are not the same. They are actually completely inverse lol

interesting.

image

Training on the Pippa-ShareGPT dataset from Hugging Face, the loss is high:
https://wandb.ai/undis95/pippa-sharegpt-13b-qlora?workspace=user-undis95
I trained on other datasets, but I don't have screenshots of the loss or the wandb.ai data, since I just learned all this.
The data and datasets can be seen at the source; the original datasets are always linked:

https://huggingface.co/Undi95/Mistral-pippa-sharegpt-7b-qlora
https://huggingface.co/Undi95/Mistral-7B-smoll_pippa-lora
https://huggingface.co/Undi95/Mistral-7B-roleplay_alpaca-lora

The results are not what I expected, and I can't find a way to train properly.

I made a script that compares the last hidden state embeddings of both implementations:

Sampled values from Mistral embedding: [[-1.635 0.4966 -1.647 ]
[ 0.1438 0.2181 0.0925 ]
[ 0.2527 0.8457 0.8496 ]
[ 0.1675 0.07324 1.037 ]
[ 0.881 -0.614 0.1123 ]]
Sampled values from Hugging Face embedding: [[-1.7 0.5347 -1.733 ]
[ 1.075 1.69 0.7036]
[ 1.983 6.86 6.73 ]
[ 1.353 0.615 8.5 ]
[ 9.23 -6.65 1.188 ]]
Embedding difference (L2 norm): inf

see comparison script at https://github.com/bdytx5/mistral7B_finetune/blob/main/train/dev/cmp_models.py

Also, you will have to add the following method to the `Transformer` class of the reference implementation:

# Assumes torch, typing.List, and RotatingBufferCache (from mistral.cache) are in scope,
# as they are in the reference implementation's model module.
def get_last_hidden_state(
    self,
    input_ids: torch.Tensor,
    cache: RotatingBufferCache,
    seqlens: List[int],
) -> torch.Tensor:
    assert len(seqlens) <= self.args.max_batch_size, f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}"
    assert sum(seqlens) == input_ids.shape[0], (sum(seqlens), input_ids.shape[0])

    input_metadata = cache.get_input_metadata(seqlens)
    h = self.tok_embeddings(input_ids)
    freqs_cis = self.freqs_cis[input_metadata.positions]

    for layer_id, layer in enumerate(self.layers):
        h = layer(h, freqs_cis, cache.get_view(layer_id, input_metadata))

    cache.update_seqlens(seqlens)

    return h  # Hidden states before the final norm and the output layer.
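The sampled values above are all small and finite, yet the reported L2 norm is `inf`. That pattern suggests the reduction overflowed, rather than the models genuinely differing by an infinite amount. One plausible cause is that the sum of squares was accumulated in float16, whose largest finite value is 65504. This is an assumption; the linked script is not reproduced here.

The stdlib sketch below computes the same two metrics in float64 and applies them to the 5x3 samples printed above.

```python
import math
from typing import List, Tuple


def compare_hidden_states(ref: List[List[float]], other: List[List[float]]) -> Tuple[float, float]:
    """Return (max absolute difference, L2 norm of the difference) for two equal-shape 2-D arrays."""
    if len(ref) != len(other) or any(len(a) != len(b) for a, b in zip(ref, other)):
        raise ValueError("shape mismatch")
    # Python floats are float64, so the sum of squares cannot overflow at these magnitudes.
    diffs = [x - y for row_a, row_b in zip(ref, other) for x, y in zip(row_a, row_b)]
    max_abs = max(abs(d) for d in diffs)
    l2 = math.sqrt(sum(d * d for d in diffs))
    return max_abs, l2


# The 5x3 samples printed above (reference implementation vs Hugging Face).
mistral_sample = [
    [-1.635, 0.4966, -1.647],
    [0.1438, 0.2181, 0.0925],
    [0.2527, 0.8457, 0.8496],
    [0.1675, 0.07324, 1.037],
    [0.881, -0.614, 0.1123],
]
hf_sample = [
    [-1.7, 0.5347, -1.733],
    [1.075, 1.69, 0.7036],
    [1.983, 6.86, 6.73],
    [1.353, 0.615, 8.5],
    [9.23, -6.65, 1.188],
]
max_abs, l2 = compare_hidden_states(mistral_sample, hf_sample)
print(f"max |diff| = {max_abs:.4f}, L2 = {l2:.4f}")
```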

So is this the cause of the loss issues or just a cleaner more proper implementation?

It's definitely possible that a difference in initial weights is causing the strange training behavior. I might try using the official weights and converting it with their script to make sure the weights on huggingface are the same as the official weights.

One thing I have noticed is the config class for the model has default "rms_norm_eps": 1e-06 where the config used on huggingface hub uses 1e-05. I'm not sure if this matters but I might try converting the weights to make sure that they were originally converted using the right config. You can find the default config here https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/configuration_mistral.py
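For reference, `rms_norm_eps` only enters the model inside RMSNorm, as the small constant added to the mean square before the square root. The minimal pure-Python sketch below applies that formula to a plain list standing in for one hidden vector. Moving eps from 1e-5 to 1e-6 barely matters for activations of ordinary magnitude. It only matters when the mean square is itself close to eps.

```python
import math
from typing import List


def rms_norm(x: List[float], weight: List[float], eps: float) -> List[float]:
    """RMSNorm: x / sqrt(mean(x^2) + eps), scaled elementwise by a learned weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


ones = [1.0, 1.0]
# Ordinary-magnitude activations: the two eps values give almost identical outputs.
typical_a = rms_norm([3.0, 4.0], ones, eps=1e-5)
typical_b = rms_norm([3.0, 4.0], ones, eps=1e-6)
# Tiny activations (mean square 1e-6, comparable to eps): the choice of eps matters.
tiny_a = rms_norm([1e-3, 1e-3], ones, eps=1e-5)
tiny_b = rms_norm([1e-3, 1e-3], ones, eps=1e-6)
print(typical_a, typical_b)
print(tiny_a, tiny_b)
```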

To follow up, Tek: after looking a little closer at the final-layer embeddings:

Sampled values from Mistral embedding: [[-1.635 0.4966 -1.647 2.324 -0.1011 ]
[ 0.1438 0.2181 0.0925 -1.136 0.2788 ]
[ 0.2527 0.8457 0.8496 -0.4353 -0.3838 ]
[ 0.1675 0.07324 1.037 -1.225 0.158 ]
[ 0.881 -0.614 0.1123 -1.201 0.2915 ]]
Sampled values from Hugging Face embedding: [[-1.706 0.593 -2.016 2.396 -0.05334]
[ 2.277 0.762 0.0974 -8.88 3.088 ]
[ 2.75 5.703 6.695 -4.22 -2.928 ]
[ 1.782 -0.5884 8.914 -9.2 1.583 ]
[ 7.8 -5.42 1.145 -9.29 4.605 ]]
Embedding difference (L2 norm): inf

The huggingface outputs seem pretty high in comparison to the official ones which does seem suspicious...

Hi @teknium1 @bdytx5

Reading through the thread and the options you have tried, I first suspected that the issue might come from the new window causal mask.
On my end I have tried to FT mistral-7b using QLoRA, with 2 different approaches:

1- Using vanilla causal mask
2- Using the window attention mask

I fine-tuned the 7B with QLoRA using this script, with model_id changed to Mistral 7B and packing enabled: https://gist.github.com/younesbelkada/9f7f75c94bdc1981c8ca5cc937d4a4da. I used a context length of 512 and a sliding window size of 256 to make sure the sliding window mask behaves correctly. Here is the behaviour of the losses:

Screenshot 2023-10-03 at 13 52 24

Although the model does not converge as "nicely" as the ideal loss curve you shared, it manages to produce generations that are coherent with the Guanaco dataset:

# input: ### Human: Can you write a short introduction about the relevance of the term "monopsony" in economics? Please use examples related to potential monopsonies in the labour market and cite relevant research.### Assistant:

>>> '### Human: Can you write a short introduction about the relevance of the term "monopsony" in economics? Please use examples related to potential monopsonies in the labour market and cite relevant research.### Assistant: Monopsony is a market structure where there is only one buyer of a good or service. In the context of the labour market, a monopsony occurs when there is only one employer in a particular industry or region. This can happen for a variety of reasons, such as government regulation, natural monopolies, or the existence of a single large firm that dominates the market.\n\nThe concept of monopsony in the labour market has gained increasing attention in recent years'

Model weights here: https://huggingface.co/ybelkada/mistral-7b-guanaco

What @bdytx5 said makes sense, there might be some differences between original model's logits and ours, indeed HF version uses 1e-5: https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json#L16 whereas mistral uses 1e-6: https://github.com/mistralai/mistral-src/blob/main/mistral/model.py#L129

@teknium1 can you try to run a training with this version of the model instead: https://huggingface.co/mistralai/Mistral-7B-v0.1/discussions/35 just pass revision="refs/pr/35" when calling from_pretrained

Reading through the thread and the options you have tried I suspected that the issue might come from the new window causal mask

I haven't looked into it in much detail yet, but the mask seems to unconditionally attend to cached keys/values. Shouldn't the sliding window apply to cached keys/values as well?

mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)

(In the case of generating a batch of single tokens at a time, there is also https://github.com/huggingface/transformers/blob/ae9a344cce52ff244f721425f660b55ebc522b88/src/transformers/models/mistral/modeling_mistral.py#L795C30-L795C30, which skips applying the window to the k/v cache.)
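For comparison, here is a minimal pure-Python sketch of a sliding-window causal mask that also applies the window to cached positions. This differs from the concatenation above, which makes every cached key visible. Boolean lists stand in for the additive float mask used in the model.

The window convention is an assumption: here a query sees at most `window` keys, including itself. Implementations differ on whether that boundary is inclusive.

```python
from typing import List


def sliding_window_causal_mask(tgt_len: int, past_len: int, window: int) -> List[List[bool]]:
    """Return a tgt_len x (past_len + tgt_len) mask where True means "may attend".

    Query i sits at absolute position past_len + i. It may attend to key j only if
    j <= that position (causal) and position - j < window (sliding window),
    so cached keys that fall out of the window are masked too.
    """
    src_len = past_len + tgt_len
    mask = []
    for i in range(tgt_len):
        q_pos = past_len + i
        mask.append([j <= q_pos and q_pos - j < window for j in range(src_len)])
    return mask


mask = sliding_window_causal_mask(tgt_len=2, past_len=2, window=3)
for row in mask:
    print(["x" if allowed else "." for allowed in row])
```

Here the second query (absolute position 3) can no longer see cached key 0. With an all-visible block prepended for the past keys, it would.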

Next time I try a full finetune I will. I actually did succeed at training airoboros' dataset over mistral 7b, with a qlora. Leading me to one of two conclusions:

One (or more) of the datasets for hermes 2.0 is malformed, or, qlora is the only way to get the reliable training/good loss curves that I want atm. Will try with the revision next full finetune I try.

On a side note about Mistral, @younesbelkada,

When I run inference with Mistral 7B on a 4090 with just a 2k max sequence length, it uses >24GB of VRAM: it hits 23.3GB of VRAM used and then starts offloading to CPU.

image

The code I run to make this happen:

import time

import torch
from transformers import LlamaTokenizer, MistralForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("./collectivecognition-run6")
model = MistralForCausalLM.from_pretrained(
    "./collectivecognition-run6",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
benchmarks = [
    "Hello, tell me about the history of the United States",
    "Roleplay as a scientist, who just discovered artificial general intelligence. What do you think about this discovery? What possibilities are there now?",
]

for index, obj in enumerate(benchmarks, start=1):
    start_time = time.time()  # Start timing
    prompt = f"USER:\n{obj}\n\nASSISTANT:\n"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    # Generate up to 2048 new tokens
    generated_ids = model.generate(input_ids, max_new_tokens=2048, temperature=None)
    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(
        generated_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    print(f"Response {index}: {response}")

    elapsed_time = time.time() - start_time  # Time taken for this prompt
    print(f"Time taken for Response {index}: {elapsed_time:.4f} seconds")
    print(f"tokens total: {len(tokenizer.encode(response))}")
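For a rough sense of scale, the weight and KV-cache footprints can be estimated from the model config. This sketch makes the following assumptions:

* the published Mistral-7B-v0.1 config values: hidden size 4096, 32 layers, 32 query heads, 8 key/value heads, intermediate size 14336, vocabulary 32000, and untied input/output embeddings;
* 2 bytes per bfloat16 element.

It ignores activations, temporary buffers, and allocator overhead.

```python
def mistral_7b_param_count(
    vocab: int = 32000,
    hidden: int = 4096,
    layers: int = 32,
    heads: int = 32,
    kv_heads: int = 8,
    intermediate: int = 14336,
) -> int:
    """Parameter count of a Mistral-style decoder (no biases), under the assumed config."""
    head_dim = hidden // heads
    attn = 2 * hidden * hidden + 2 * hidden * kv_heads * head_dim  # q/o + k/v projections
    mlp = 3 * hidden * intermediate  # gate, up and down projections
    norms = 2 * hidden  # two RMSNorm weight vectors per layer
    embeddings = 2 * vocab * hidden  # input embedding + untied lm_head
    return layers * (attn + mlp + norms) + embeddings + hidden  # + final RMSNorm


def kv_cache_bytes(
    seq_len: int,
    layers: int = 32,
    kv_heads: int = 8,
    head_dim: int = 128,
    bytes_per_elem: int = 2,
    batch: int = 1,
) -> int:
    """Bytes for cached keys and values (the leading factor 2) across all layers."""
    return 2 * layers * kv_heads * head_dim * seq_len * bytes_per_elem * batch


GIB = 2**30
print(f"weights: {mistral_7b_param_count() * 2 / GIB:.2f} GiB (bfloat16)")
print(f"KV cache for 2048 tokens: {kv_cache_bytes(2048) / 2**20:.0f} MiB")
```

Under these assumptions, the bfloat16 weights come to about 13.5 GiB. The KV cache adds about 256 MiB per 2,048 tokens at batch size 1. Comparing these figures with the observed usage helps separate what the weights and cache account for from other overheads.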

@teknium1
I believe this is because the vanilla implementation we currently have in transformers does not support cache slicing as in the original repository.
To benefit from a fixed-size cache and memory-efficient generation, you can use the Flash Attention 2 version of the model:

import time

import torch
from transformers import LlamaTokenizer, MistralForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("./collectivecognition-run6")
model = MistralForCausalLM.from_pretrained(
    "./collectivecognition-run6",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # Flash Attention 2 path; requires the flash-attn package and a supported GPU
    use_flash_attention_2=True,
)
benchmarks = [
    "Hello, tell me about the history of the United States",
    "Roleplay as a scientist, who just discovered artificial general intelligence. What do you think about this discovery? What possibilities are there now?",
]

for index, obj in enumerate(benchmarks, start=1):
    start_time = time.time()  # Start timing
    prompt = f"USER:\n{obj}\n\nASSISTANT:\n"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    # Generate up to 2048 new tokens
    generated_ids = model.generate(input_ids, max_new_tokens=2048, temperature=None)
    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(
        generated_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    print(f"Response {index}: {response}")

    elapsed_time = time.time() - start_time  # Time taken for this prompt
    print(f"Time taken for Response {index}: {elapsed_time:.4f} seconds")
    print(f"tokens total: {len(tokenizer.encode(response))}")

Check the results of my benchmark here: #26464 (comment)

@teknium1 for full fine-tuning with DS how do you create the packed dataset ? Do you use the SFTTrainer with packing=True ?
See this PR from @lewtun : https://huggingface.co/mistralai/Mistral-7B-v0.1/discussions/26 and https://twitter.com/jon_durbin/status/1709147204915523929?s=20 for reference. If you use the SFTTrainer the eos token is correctly added at the end of each chunk: https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L577 but if you pre-tokenize your dataset manually you will never get any EOS token properly encoded I think
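Here is a minimal sketch of the packing scheme described above. Tokenized examples are concatenated with an EOS id after each one, then cut into fixed-length chunks, so document boundaries stay marked inside every packed sequence. The token ids and `eos_id` below are placeholders. The sketch is similar in spirit to the linked TRL utility, not a copy of it.

```python
from typing import Iterable, List


def pack_with_eos(examples: Iterable[List[int]], eos_id: int, seq_len: int) -> List[List[int]]:
    """Concatenate tokenized examples with eos_id after each, then cut fixed-length chunks.

    The trailing partial chunk is dropped.
    """
    stream: List[int] = []
    for ids in examples:
        stream.extend(ids)
        stream.append(eos_id)  # mark the end of this example
    n_full = len(stream) // seq_len
    return [stream[k * seq_len:(k + 1) * seq_len] for k in range(n_full)]


packed = pack_with_eos([[1, 2], [3], [4, 5, 6]], eos_id=0, seq_len=3)
print(packed)
```

If the examples were packed without the EOS step, the model would never see a token telling it where one conversation ends and the next begins.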

I use Axolotl, but @winglian would have to explain the implementation. However, when I print the tokenized dataset with Axolotl, it appears fine. Also, I have now done several QLoRAs with Axolotl on Mistral with new datasets, and they are turning out perfect, like, astounding. So it's either dataset-specific (all 3 of the datasets I tried full fine-tunes on), or it only affects full fine-tuning.

I took a look at tuning Mistral 7B with TRL's SFTTrainer and DeepSpeed ZeRO-3 on a subset of the UltraChat dataset and the loss seems to converge as expected:

Screenshot 2023-10-05 at 13 09 39

Here's a gist of the tweaks I made to the TRL example in case it's useful to others: https://gist.github.com/lewtun/b9d46e00292d9ecdd6fd9628d53c2814

Overall, I think the divergences some people are reporting could be due to dataset issues (e.g. how you format the chat template) and/or choice of hyperparameters. As far as I can tell, there is no issue in the SFTTrainer or Trainer from transformers.

@younesbelkada

Indeed, there are two mechanisms in the original Mistral repo: 1) sliding window attention and 2) a rolling buffer cache.

I have the impression that in HF you implemented only the sliding window attention, by acting only on the attention mask and ONLY at training time, which means that at inference the full length is taken into account. Am I correct?
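The two mechanisms do different jobs. The sliding window limits which keys a query may attend to. A rolling buffer cache bounds memory by storing the key/value for absolute position p in slot p % window, overwriting the oldest entry.

Below is a toy sketch of the buffer alone. Integers stand in for per-token key/value tensors, and this is not Mistral's `RotatingBufferCache`.

```python
from typing import List


class RollingBufferCache:
    """Toy fixed-size cache: the entry for absolute position p lives in slot p % window."""

    def __init__(self, window: int):
        self.window = window
        self.slots: List[int] = [0] * window
        self.next_pos = 0  # absolute position of the next token to be cached

    def append(self, value: int) -> None:
        # Overwrites the entry from position next_pos - window once the buffer is full.
        self.slots[self.next_pos % self.window] = value
        self.next_pos += 1

    def ordered(self) -> List[int]:
        """Return the cached entries oldest-first (at most `window` of them)."""
        n = min(self.next_pos, self.window)
        start = self.next_pos - n
        return [self.slots[p % self.window] for p in range(start, self.next_pos)]


cache = RollingBufferCache(window=4)
for pos in range(6):
    cache.append(pos)
print(cache.slots, cache.ordered())
```

Memory stays fixed at `window` entries no matter how many tokens are generated. The price is that the physical slot order must be un-rotated before the entries can be read in position order.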

I trained several models successfully with QLoRA, all of them on datasets that make up Hermes 2.0, and all but one turned out excellent in terms of the loss graph. I removed that bad dataset and have cleaned up Hermes 2.0 dramatically since then, and today I returned to full fine-tuning:

This is with LR 2e-5 and 300 warmup steps:
image

Same dataset as a qlora, working perfectly fine:
image

An additional information point:
@winglian said today:
caseus — Today at 12:39 PM
I had posted this the other day. there is something slightly amiss with mistral finetunes, my hunch is it's a transformers issue somewhere. Vastly different LR's (6.5x) but set warmup steps so they followed the same LR trajectory. One would expect that with the same LR at the same step, the loss and gradient should be identical?

image
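The idea of matching LR trajectories can be checked with simple arithmetic. Under linear warmup, lr(t) = peak_lr * t / warmup_steps. Scaling the peak LR and the warmup length by the same factor therefore gives identical LRs at every step of the shorter warmup.

The peak values and step counts below are hypothetical, since the thread does not give the actual settings. Any decay after warmup is omitted.

```python
import math


def linear_warmup_lr(step: int, peak_lr: float, warmup_steps: int) -> float:
    """LR under linear warmup: 0 at step 0, rising to peak_lr at warmup_steps, then held flat."""
    return peak_lr * min(1.0, step / warmup_steps)


# Hypothetical runs: run B has a 6.5x higher peak LR and a 6.5x longer warmup.
run_a = [linear_warmup_lr(s, 2e-5, 100) for s in range(101)]
run_b = [linear_warmup_lr(s, 1.3e-4, 650) for s in range(101)]
same_trajectory = all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-15) for a, b in zip(run_a, run_b))
print(same_trajectory)
```

The two runs only coincide while both are still warming up. Past step 100, run A plateaus at 2e-5 while run B keeps climbing.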

Full fine-tune run at LR 4e-6, still a nope:
image

@teknium1 were you able to try the lower rms_norm?

Will attempt to do so in about 30 mins

I am having what appears to be a potentially successful run at a 1e-6 LR, so I won't try the revision yet and will let this one play out for now. But I've never seen an LR that low used for fine-tuning a model before. I will try the revision if this one does end up spiking.
image

welp..
image

Will try the new revision... lol

Okay, so the revision didn't help either. All the runs starting with higher loss use the revision: one with LR 4e-6, several with 1e-6, one with gradient clipping 1.0, another with 0.3, and one with weight decay 20% (vs 0.3% in all others).
image

The next image is a full fine-tune on Llama-2 13B, which trains flawlessly:
image

And here is the qlora on mistral with the same dataset as well:
image

@teknium1 I don't use flash attention and I set tokenizer.padding_side = "right", and then my loss is OK. But if I use flash attention and set tokenizer.padding_side = "left", it causes the loss instability. I think you should check this (maybe the loss instability is due to the flash attention code of the Mistral model combined with tokenizer.padding_side = "left").

Here are some experiments I did with tokenizer.padding_side = "right" and no flash attention:

  1. Qlora (4 bit):
    image

  2. LoRA (no 4 bit, no 8 bit):
    image
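For reference, `padding_side` only changes where the pad tokens go, and therefore where the real tokens sit in each row. Below is a minimal pure-Python sketch of right vs left padding with the matching attention mask; token ids and `pad_id` are placeholders. It does not explain why the left-padded Flash Attention runs diverged, which the thread leaves open.

```python
from typing import List, Tuple


def pad_batch(
    seqs: List[List[int]], pad_id: int, side: str = "right"
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token-id lists to a common length; return (input_ids, attention_mask)."""
    max_len = max(len(s) for s in seqs)
    input_ids, attention_mask = [], []
    for s in seqs:
        pad = [pad_id] * (max_len - len(s))
        mask_pad = [0] * len(pad)  # 0 = padding, ignored by attention
        if side == "right":
            input_ids.append(s + pad)
            attention_mask.append([1] * len(s) + mask_pad)
        elif side == "left":
            input_ids.append(pad + s)
            attention_mask.append(mask_pad + [1] * len(s))
        else:
            raise ValueError("side must be 'left' or 'right'")
    return input_ids, attention_mask


right = pad_batch([[5, 6, 7], [8]], pad_id=0, side="right")
left = pad_batch([[5, 6, 7], [8]], pad_id=0, side="left")
print(right)
print(left)
```

With left padding, the real tokens of shorter rows no longer start at index 0. Any code path that relies on the attention mask to locate them sees a different layout than with right padding.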

Well friends, it seems it is not mistral specific... :(
image

My dataset truly looks pristine to me; I cannot find any systemic or widespread issues in it. I don't know why it fails. You would think that if it were the dataset, QLoRA would likely be unstable as well, but maybe not. Maybe it's a change made in Axolotl, but many people are training models right now with the current main branch, and I've tried their configs/hyperparams. I'm at a loss 🤷‍♂️

That's quite curious. Did you try shuffling the dataset? It looks like there may be some overfitting occurring at the beginning.

The dataset is shuffled automatically with axolotl. I have used the same shuffle seed in all runs though.

I'm willing to grant access to the dataset itself if anyone thinks they may find something me and several others who've looked at it have not, if interested

@teknium1 here is my loss with full fine-tuning. I think the loss is decreasing suitably (although it increases at the beginning, it starts decreasing after a few steps):
image

Well, @winglian ended up running a fine-tune of Mistral on the Hermes 2 dataset (well, it's running at the moment), and this is the loss chart now:
image

It looks... good. Why? I don't know. The only differences between what he is doing and what I have done are that he is using DeepSpeed and that he has set it up for the ChatML format instead of the traditional ShareGPT/Vicuna/FastChat format. As far as I can tell, those are the only differences. I will focus future ablations on those 2 factors.

Regarding the point that the HF version uses 1e-5 whereas Mistral uses 1e-6: I think this is wrong. Mistral uses 1e-5 because it reads params.json, which has 1e-5.

Thanks for the heads up. I realised this was wrong a few days ago and closed the PR accordingly, as you can see here: https://huggingface.co/mistralai/Mistral-7B-v0.1/discussions/35

Did you see my other comment above wrt inference and rotating cache?

If you use the vanilla HF attention, yes, that is the case: we did not implement the rotating buffer cache mechanism, as it requires a significant refactor.

However we tried to mimic the rotating buffer caching mechanism by constraining it only in the case where padding_side=left for FA-2 models by shifting the cache and slicing out the previous tokens when generating the next token. See my benchmarks here for more details: #26464 (comment)

yes exactly

Anyway, it requires some hardware to support seqlen > 4096...

No you can scale to very large sequence length as the cache will be always having 4096 tokens, similarly as the rotating buffer cache from original mistral repository.

Per my understanding (cc @timlacroix, please correct me if I am wrong), since we always use absolute positional embeddings, the model is able to keep the whole context even if we go beyond 4096 tokens.
If one feeds a very large context (>4096) to the model directly on the first iteration, you will indeed need enough compute, but since the FA module uses sliding window attention, it should be quite memory-efficient. Slicing the cache afterwards is not a problem, since the model has already computed attention scores based on the entire context on the first iteration, so the information is not lost.
Batched generation is slightly more complex, since we don't follow exactly the same procedure as Mistral's rotating buffer cache: we slice out the first tokens of the cache after the first iteration. But with BS=1 you should get pretty decent performance. If you have hardware that supports FlashAttention 2, I believe you can generate a very large number of tokens without any major issue.

Hmm, the cache size is not the only limiting factor. You still need to forward the full sequence through the model, and FlashAttention 2 still runs over the full length, even if the mechanism makes the cost linear in length (rather than quadratic).

But that's the case anyway, right? For the first forward pass, if you pass a large context, you need to compute attention scores for all tokens.
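A back-of-the-envelope count helps frame this exchange. The sketch below counts the query/key pairs scored in one forward pass over `n` tokens with plain causal attention versus causal sliding-window attention of width `window` (each query sees at most the last `window` positions, itself included). It assumes a single head and ignores constant factors; it describes how the amount of work scales, not how FlashAttention 2 schedules it.

```python
def causal_pairs(n: int) -> int:
    # Query i attends to keys 0..i, i.e. i + 1 pairs: 1 + 2 + ... + n.
    return n * (n + 1) // 2


def sliding_window_pairs(n: int, window: int) -> int:
    # Query i attends to min(i + 1, window) keys.
    if n <= window:
        return causal_pairs(n)
    # The first `window` queries form a triangle; the rest see exactly `window` keys.
    return causal_pairs(window) + (n - window) * window


if __name__ == "__main__":
    n, window = 32768, 4096
    print(causal_pairs(n))                  # 536887296
    print(sliding_window_pairs(n, window))  # 125831168
```

The causal count grows quadratically in `n`, while the sliding-window count grows linearly once `n` exceeds the window; in both cases the first forward pass still touches every token of the prompt.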

@younesbelkada @bdytx5 @vince62s @arthurmensch

Okay, update on the issue.

[image: training loss, DeepSpeed ZeRO-2 vs FSDP]

The image above compares DeepSpeed ZeRO-2 with FSDP; all other hyperparameters are the same. The ZeRO-2 run has the more stable trajectory. I believe I tested ZeRO-3 in the past and saw the same U-shaped pattern as in the FSDP run, but I am not sure at the moment.

At the moment I don't know whether this is caused by axolotl's interaction with FSDP or by something in transformers/accelerate/who knows what. But this seems like an important development in figuring out what's going on. Not sure how much you can look into it, but I figured I'd post the info here in case it isn't axolotl's code.

Edit: never mind...
[image]

However, it still looks far better than my loss curves from runs with much lower LRs than the one above (which uses 2.5e-5):
[image: loss curves from the lower-LR runs]

OK, I did a new, longer run comparing DeepSpeed ZeRO-2 and FSDP, with everything else the same:
[image: training loss, longer run, DeepSpeed ZeRO-2 vs FSDP]

Something about FSDP makes it converge more slowly (technically, the loss is not moving downward at all, but very, very slightly upward), with LR 4e-6.

ZeRO-3 and ZeRO-2 seem fine; only FSDP shows the problem. I will reference this issue in the axolotl and PyTorch repos.

[image]

[image]

For me, using the transformers Trainer with a custom dataset, batch size 2 and gradient accumulation 6, the training loss drops to 0.0 after a certain point and the eval loss becomes NaN.
I am using torch_dtype=torch.float16.

I've seen someone suggest changing float16 to bfloat16?

Hi @nps798,
Yes, I think bfloat16 is preferable to stay on the safe side. Also, something strange I have noticed: if you train with padding tokens, make sure to set padding_side="right": https://gist.github.com/younesbelkada/9f7f75c94bdc1981c8ca5cc937d4a4da?permalink_comment_id=4636728#gistcomment-4636728
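The reason bfloat16 is the safer choice is dynamic range. float16 overflows above 65504, and its smallest positive (subnormal) value is 2^-24, about 5.96e-08, which is exactly the "abs min" that appears in several rows of the NaN log posted later in this thread. bfloat16 keeps float32's 8 exponent bits (range up to about 3.4e38) at the cost of precision. The stdlib-only sketch below emulates both formats with `struct`; the bfloat16 helper simply truncates the float32 bit pattern, which is a simplifying assumption (hardware conversions normally round to nearest even), and both helper names are hypothetical.

```python
import math
import struct

FP16_MAX = 65504.0           # largest finite float16 value
FP16_MIN_SUBNORMAL = 2**-24  # smallest positive float16 value, ~5.96e-08


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 half precision; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # struct refuses values too large for the 'e' format
        return math.copysign(math.inf, x)


def to_bf16_truncated(x: float) -> float:
    """Approximate bfloat16 by keeping only the top 16 bits of the float32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


if __name__ == "__main__":
    # 70000.0 overflows to inf in float16 but stays finite (69632.0) in bfloat16;
    # 1e-8 underflows to 0.0 in float16 but stays close to 1e-8 in bfloat16.
    for value in (70000.0, 1e-8):
        print(value, to_fp16(value), to_bf16_truncated(value))
```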


Thanks for your reply. I'll give it a try soon.

BTW, I have just hit another issue with my previous setup (QLoRA, float16, padding on the left).
I've checked the input data around the failing batches (yes, I print out every batch at each step), and there is nothing weird or special.
I checked all of the model's parameters with the following code:

import torch

# Scan every parameter tensor of the (already loaded) model for NaN/Inf values.
for name, param in model.named_parameters():
    if torch.isnan(param).any():
        print(f'NaN value detected in model weights: {name}')
    if torch.isinf(param).any():
        print(f'Infinity value detected in model weights: {name}')

Nothing was printed.

So...
Correct me if I am wrong: the following NaN is not coming from a problematic dataset?
Is it related to some of the model's weights being too small or too big, so that the NaN would be produced with any dataset? And is there no way to detect it beforehand?

input[0] has nans
output has nans

Detected inf/nan during batch_number=54681
Last 21 forward frames:
abs min abs max metadata
base_model.model.model.layers.30.mlp.gate_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 1.85e+02 input[0]
5.96e-08 3.02e+01 output
base_model.model.model.layers.30.mlp.act_fn SiLUActivation
5.96e-08 3.02e+01 input[0]
0.00e+00 2.39e+01 output
base_model.model.model.layers.30.mlp.up_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 1.85e+02 input[0]
5.96e-08 2.36e+01 output
base_model.model.model.layers.30.mlp.down_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 3.70e+02 input[0]
0.00e+00 1.38e+02 output
base_model.model.model.layers.30.mlp MistralMLP
0.00e+00 1.85e+02 input[0]
0.00e+00 1.38e+02 output
base_model.model.model.layers.30 MistralDecoderLayer
0.00e+00 3.05e+02 input[0]
0.00e+00 1.67e+02 output[0]
0.00e+00 1.68e+01 output[1][0]
0.00e+00 8.28e+00 output[1][1]
base_model.model.model.layers.31.input_layernorm MistralRMSNorm
8.36e-01 8.75e+00 weight
0.00e+00 1.67e+02 input[0]
0.00e+00 9.58e+01 output
base_model.model.model.layers.31.self_attn.q_proj.lora_dropout.default Dropout
0.00e+00 9.58e+01 input[0]
0.00e+00 1.01e+02 output
base_model.model.model.layers.31.self_attn.q_proj.lora_A.default Linear
9.78e-08 1.07e-01 weight
0.00e+00 1.01e+02 input[0]
2.04e-03 8.38e+01 output
base_model.model.model.layers.31.self_attn.q_proj.lora_B.default Linear
1.98e-07 8.64e-02 weight
2.04e-03 8.38e+01 input[0]
2.06e-07 2.49e+01 output
base_model.model.model.layers.31.self_attn.q_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 9.58e+01 input[0]
0.00e+00 2.62e+01 output
base_model.model.model.layers.31.self_attn.k_proj.lora_dropout.default Dropout
0.00e+00 9.58e+01 input[0]
0.00e+00 1.01e+02 output
base_model.model.model.layers.31.self_attn.k_proj.lora_A.default Linear
2.39e-07 7.29e-02 weight
0.00e+00 1.01e+02 input[0]
6.44e-05 5.60e+01 output
base_model.model.model.layers.31.self_attn.k_proj.lora_B.default Linear
3.00e-07 6.73e-02 weight
6.44e-05 5.60e+01 input[0]
4.96e-07 1.24e+01 output
base_model.model.model.layers.31.self_attn.k_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 9.58e+01 input[0]
0.00e+00 1.85e+01 output
base_model.model.model.layers.31.self_attn.v_proj.lora_dropout.default Dropout
0.00e+00 9.58e+01 input[0]
0.00e+00 1.01e+02 output
base_model.model.model.layers.31.self_attn.v_proj.lora_A.default Linear
1.05e-07 1.07e-01 weight
0.00e+00 1.01e+02 input[0]
1.04e-03 5.54e+01 output
base_model.model.model.layers.31.self_attn.v_proj.lora_B.default Linear
7.20e-07 3.79e-02 weight
1.04e-03 5.54e+01 input[0]
7.59e-07 6.53e+00 output
base_model.model.model.layers.31.self_attn.v_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 9.58e+01 input[0]
0.00e+00 8.99e+00 output
base_model.model.model.layers.31.self_attn.rotary_emb MistralRotaryEmbedding
0.00e+00 8.99e+00 input[0]
5.15e-05 1.00e+00 output[0]
0.00e+00 1.00e+00 output[1]
base_model.model.model.layers.31.self_attn.o_proj Linear4bit
0.00e+00 2.55e+02 weight
nan nan input[0]
nan nan output

[image: training loss curve]
My training loss behaves strangely: it suddenly explodes at a different point in each training run. I tried to resolve this by following the instructions for mistral-7b-instruct, setting padding_side to "right" and pad_token to eos_token, but it didn't solve the problem. I use DeepSpeed stage 3 and bfloat16.
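For context on the padding advice in this thread, here is a stdlib-only sketch of how a batch might be padded when `pad_token` is set to `eos_token`. It assumes token ids are plain ints, that labels equal the input ids (Hugging Face causal-LM models shift labels internally), and that padded positions get the label -100, the default `ignore_index` of PyTorch's `CrossEntropyLoss`. The function name and the id 2 are hypothetical; this is not the collator used by transformers or axolotl. The detail worth noticing is that pads are masked by position rather than by token id: masking every occurrence of the EOS id would also hide the genuine end-of-sequence token from the loss.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_batch(sequences, pad_id, padding_side="right"):
    """Pad token-id lists to equal length and build attention masks and labels."""
    if padding_side not in ("left", "right"):
        raise ValueError("padding_side must be 'left' or 'right'")
    max_len = max(len(seq) for seq in sequences)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for seq in sequences:
        n_pad = max_len - len(seq)
        ids, mask, labels = list(seq), [1] * len(seq), list(seq)
        # Padding is identified by position, so a real EOS keeps its label.
        pad_ids, pad_mask, pad_labels = [pad_id] * n_pad, [0] * n_pad, [IGNORE_INDEX] * n_pad
        if padding_side == "right":
            ids, mask, labels = ids + pad_ids, mask + pad_mask, labels + pad_labels
        else:
            ids, mask, labels = pad_ids + ids, pad_mask + mask, pad_labels + labels
        batch["input_ids"].append(ids)
        batch["attention_mask"].append(mask)
        batch["labels"].append(labels)
    return batch


if __name__ == "__main__":
    EOS_ID = 2  # hypothetical id; the pad token reuses it
    batch = pad_batch([[1, 5, 6, EOS_ID], [1, 7, EOS_ID]], pad_id=EOS_ID)
    print(batch["input_ids"][1])       # [1, 7, 2, 2]
    print(batch["attention_mask"][1])  # [1, 1, 1, 0]
    print(batch["labels"][1])          # [1, 7, 2, -100]: the real EOS is kept
```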

@younesbelkada thank you!
I set the torch dtype to bf16 (while keeping the padding on the left).

QLoRA fine-tuning for 5 epochs now succeeds without exploding loss or zero loss.

I will keep experimenting with other combinations of parameters.

I can now confirm that at least 2 other people have this issue with FSDP. However, I still see the loss go up after the per-epoch drops in my DeepSpeed runs as well, which leaves me concerned, but in a better state than before, when the loss curves were always U-shaped.
[image]
[image]

Hi everyone,
Thanks a lot for the deep investigation. Recently @pacman100 managed to successfully fine-tune Llama (from what I have understood, the issue is fairly agnostic to the architecture) using FSDP, and shared some insights here: huggingface/accelerate#2127 (comment)
It seems the solution is not to load the model in bf16, but instead to enable mixed precision training through TrainingArguments by passing bf16=True. cc @pacman100 in case I missed something
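A rough intuition for why storing the weights themselves in bf16 can hurt, compared with bf16 mixed precision that keeps float32 master weights: bfloat16 has only 8 bits of significand precision, so near 1.0 its spacing is 2^-7, about 0.0078, and an update smaller than half of that is rounded away entirely. The stdlib-only sketch below applies the same small additive update 100 times to a weight stored in bfloat16 and to one stored in float32. It emulates bfloat16 with round-to-nearest-even on the float32 bit pattern (NaN handling omitted) and uses a hypothetical fixed update of 1e-3; it illustrates the numerical effect only and does not model what Trainer, FSDP, or DeepSpeed do internally.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to bfloat16 using round-to-nearest-even on the dropped 16 bits.

    NaN and overflow handling are omitted; fine for the small values used here.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias for round-to-nearest-even
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def apply_updates(cast, weight: float, update: float, steps: int) -> float:
    """Repeatedly add `update`, storing the result in the format `cast` emulates."""
    for _ in range(steps):
        weight = cast(weight + update)
    return weight


if __name__ == "__main__":
    print(apply_updates(to_bf16, 1.0, 1e-3, 100))  # 1.0: every update is rounded away
    print(apply_updates(to_fp32, 1.0, 1e-3, 100))  # approximately 1.1
```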


I think this was a misunderstanding, and it's actually not training successfully. However, @tmabraham did show a workaround in that thread.

Hello,

I ran the experiment below to check whether fine-tuning Mistral with FSDP works as expected. Here are the results:

  1. Codebase: https://github.com/pacman100/DHS-LLM-Workshop/tree/main/chat_assistant/training
  2. Dataset: smangrul/chat-instruct-mixer
  3. Model: mistralai/Mistral-7B-v0.1
  4. Accelerate config after running accelerate config --config_file fsdp_config.yaml and answering the questionnaire:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
  5. Command:
accelerate launch \
    --config_file configs/fsdp_config.yaml \
    train.py \
    --model_name "mistralai/Mistral-7B-v0.1" \
    --dataset_name "smangrul/chat-instruct-mixer" \
    --max_seq_len 4096 \
    --max_steps 5000 \
    --logging_steps 25 \
    --eval_steps 1000 \
    --save_steps 1000 \
    --bf16 True \
    --packing True \
    --output_dir "/fsx/sourab/experiments/full-finetune-mistral-7b-fsdp-chat-asst" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --dataset_text_field "content" \
    --use_gradient_checkpointing False \
    --learning_rate 5e-6 \
    --lr_scheduler_type "cosine" \
    --weight_decay 0.01 \
    --warmup_ratio 0.03 \
    --max_grad_norm 1.0 \
    --use_flash_attn True
  6. Training plots at the end of 1000 steps:
[screenshot, 2023-11-16: training plots after 1000 steps]
  7. Observations:
    a. Loss goes down as expected, and training is successful.
    b. Sensitivity to learning rate: with learning rates of 5e-5 or 2e-5, training did not converge properly; 5e-6 worked best for my dataset. So hyperparameter tuning is important when fully fine-tuning.
    c. Sequence length 4096 with batch size 8 (1 per GPU, gradient accumulation steps 1) gives lower loss than sequence length 2048 with batch size 16 (1 per GPU, gradient accumulation steps 2).
  8. Library versions:
  • Output of transformers-cli env:
- `transformers` version: 4.35.2
- Platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
- Python version: 3.11.4
- Huggingface_hub version: 0.16.4
- Safetensors version: 0.3.2
- Accelerate version: 0.24.1
- Accelerate config: 	not found
- PyTorch version (GPU?): 2.1.0.dev20230809 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
  • Output of accelerate env:
- `Accelerate` version: 0.24.1
- Platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
- Python version: 3.11.4
- Numpy version: 1.24.3
- PyTorch version (GPU?): 2.1.0.dev20230809 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- System RAM: 1121.82 GB
- GPU type: NVIDIA A100-SXM4-80GB
- `Accelerate` default config:
	Not found
  • flash-attn: 2.3.3
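The arithmetic behind observation (c): the effective (global) batch size is per-device batch × gradient-accumulation steps × number of processes, so with the 8 processes from the accelerate config the two settings see 1 × 1 × 8 = 8 and 1 × 2 × 8 = 16 sequences per optimizer step. Assuming `--packing True` fills every sequence to exactly the maximum length (an assumption that ignores any partially filled final sequence), both settings process the same number of tokens per optimizer step. The helper names are hypothetical.

```python
def effective_batch_size(per_device: int, grad_accum: int, num_processes: int) -> int:
    """Number of sequences contributing to one optimizer step."""
    return per_device * grad_accum * num_processes


def tokens_per_step(seq_len: int, per_device: int, grad_accum: int, num_processes: int) -> int:
    """Tokens per optimizer step, assuming every packed sequence is exactly seq_len long."""
    return seq_len * effective_batch_size(per_device, grad_accum, num_processes)


if __name__ == "__main__":
    # Prints:
    # 4096 8 32768
    # 2048 16 32768
    for seq_len, grad_accum in ((4096, 1), (2048, 2)):
        print(seq_len, effective_batch_size(1, grad_accum, 8), tokens_per_step(seq_len, 1, grad_accum, 8))
```

So the two runs differ in context length and in the number of sequences per step, not in tokens per step.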

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Is this solved due to the previous mention?