ml-explore / mlx-swift-examples

Examples using MLX Swift

Fine-tuned Qwen2 model inference error

madroidmaq opened this issue

I fine-tuned Qwen/Qwen1.5-0.5B-Chat and then fused the result into a single model. When I run inference with mlx-lm, the output is as expected (a specific URL link), roughly as follows:

python -m mlx_lm.generate --model qwen1.5-0.5B-4bit --max-tokens 50 --temp 0 --colorize --prompt "Bring down the brightness."
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
==========
Prompt: <|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Bring down the brightness.<|im_end|>
<|im_start|>assistant

flex://system/set?brightness=down
==========
Prompt: 579.851 tokens-per-sec
Generation: 149.535 tokens-per-sec

Running mlx-community/Qwen1.5-0.5B-Chat-4bit with the llm-tool command line works fine. When I load my own fine-tuned, fused model, however, inference goes wrong: it does not produce the expected continuation. It looks roughly like this:

[Screenshot: incorrect llm-tool output]

I'm not sure what's wrong. Could you point me in a direction for further troubleshooting? I'll give it a try. Many thanks.

Yeah, here are some ideas I wrote up for somebody else working on LLMs:

(An example of running the model in Python was removed here; you already have this part, so you know the model works.)

Notice the augmentation of the prompt: it is done by Python code driven by the tokenizer configuration. We can't run that in Swift, so you may need some configuration to help with this, as in the example repo. You already have this part and can see the augmented prompt.
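The augmentation itself is plain string wrapping. A minimal Python sketch of such a prompt-preparation step, assuming the ChatML-style layout that mlx-lm printed above; the function name `prepare_prompt` is hypothetical, and the trailing newline after `assistant` is an assumption (it appears to match the `198` that ends the mlx-lm token array further down):

```python
# Hypothetical helper: reproduces the augmented prompt printed by mlx-lm.
DEFAULT_SYSTEM = "You are a helpful assistant"


def prepare_prompt(user_message: str, system: str = DEFAULT_SYSTEM) -> str:
    """Wrap a user message in the <|im_start|>/<|im_end|> markers.

    The result ends with the assistant header (plus an assumed newline),
    so the model continues with the assistant's reply.
    """
    return (
        f"<|im_start|>system\n{system}<|im_end|>\n"
        f"<|im_start|>user\n{user_message}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


print(prepare_prompt("Bring down the brightness."))
```

Whatever the Swift side does, the string it hands to the tokenizer should match this byte for byte before comparing token ids.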

With a working Python version, you can do a few things:

  • the tokenizer produces an array of integers

    • print out the tokens the Python code generates; see utils.py: prompt_tokens = mx.array(tokenizer.encode(prompt))
    • hard-code the Swift code to take this same array
    • if this array works, then you can suspect something in the tokenizer
  • the tokenizer can decode the tokens it prepares

    • make sure it can decode both the tokens the Swift tokenizer makes
    • and the tokens the Python code makes
  • set the random seed

    • use --seed in the command line tool, MLXRandom.seed() in Swift, and mx.random.seed() in Python
    • maybe set the temperature to 0
    • generate a small number of tokens
    • are they the same? The code that turns the logits into tokens may differ slightly between the two, but I found the first token is usually the same with the same seed
  • if the tokens differ, compare the execution of the models

    • I found that printing something like print("\(name) \(array.shape) \(array.sum())") in Swift, with equivalent code in Python at the same points, helps spot differences without looking at the whole tensor
    • I had typos in the Attention layer a couple of times: incorrectly placed parentheses, etc.
  • make sure your weights are loaded correctly

    • load them with try model.update(parameters: parameters, verify: [.all]) so the keys are verified

This is already done in the example code, but without that check a parameter could silently remain random data and cause issues like this.
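The "print name, shape and sum" idea from the list above can be turned into a small comparison. A minimal sketch in plain Python, assuming you have recorded a (name, shape, sum) fingerprint at the same points in both the Swift and the Python forward passes; `fingerprint`, `first_mismatch`, the checkpoint names and the tolerances are all hypothetical, and the values in the usage part are made up for illustration:

```python
import math


def fingerprint(name, shape, values):
    """Cheap summary of a tensor: name, shape and the sum of its elements.

    `values` is a flat list of floats here; with MLX you would record
    array.shape and array.sum() instead.
    """
    return (name, tuple(shape), math.fsum(values))


def first_mismatch(trace_a, trace_b, rel_tol=1e-3, abs_tol=1e-5):
    """Return the name of the first checkpoint where two traces disagree.

    Sums are compared with a tolerance because float32 results from two
    implementations rarely match bit for bit. zip() stops at the shorter
    trace. Returns None when every compared checkpoint agrees.
    """
    for (name_a, shape_a, sum_a), (name_b, shape_b, sum_b) in zip(trace_a, trace_b):
        if name_a != name_b:
            raise ValueError(f"traces are out of step: {name_a!r} vs {name_b!r}")
        if shape_a != shape_b:
            return name_a
        if not math.isclose(sum_a, sum_b, rel_tol=rel_tol, abs_tol=abs_tol):
            return name_a
    return None


# Made-up traces: the attention output differs in one element.
swift_trace = [
    fingerprint("embed", [1, 4], [0.1, 0.2, 0.3, 0.4]),
    fingerprint("attn.0", [1, 4], [1.0, 2.0, 3.0, 4.0]),
]
python_trace = [
    fingerprint("embed", [1, 4], [0.1, 0.2, 0.3, 0.4]),
    fingerprint("attn.0", [1, 4], [1.0, 2.0, 3.0, 5.0]),
]
print(first_mismatch(swift_trace, python_trace))  # attn.0
```

The first checkpoint that disagrees narrows the search to the layer just before it, which is where typos like a misplaced parenthesis tend to show up.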

Good luck and ask if you have questions!

@davidkoski Thank you for your suggestions. I ran the same prompt through my own fused model, printed the encoded tokens, and did find differences.

The prompt is:

<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
hello<|im_end|>
<|im_start|>assistant

Here is the comparison; the first line is the Swift output and the second is the mlx-lm output:

[27, 91, 318, 4906, 91, 29, 8948, 198, 2610, 525, 264, 10950, 17847, 27, 91, 318, 6213, 91, 397, 27, 91, 318, 4906, 91, 29, 872, 198, 14990, 27, 91, 318, 6213, 91, 397, 27, 91, 318, 4906, 91, 29, 77091]
[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 14990, 151645, 198, 151644, 77091, 198]

I ran the same test with the mlx-community/Qwen1.5-0.5B-Chat-4bit model:

Again, the first line is the Swift output and the second is the mlx-lm output:

[27, 91, 318, 4906, 91, 29, 8948, 198, 2610, 525, 264, 10950, 17847, 27, 91, 318, 6213, 91, 397, 27, 91, 318, 4906, 91, 29, 872, 198, 14990, 27, 91, 318, 6213, 91, 397, 27, 91, 318, 4906, 91, 29, 77091]
[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 14990, 151645, 198, 151644, 77091, 198]
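A quick way to see where the two encodings part ways is to find the first differing position and the ids that only one side produces. A minimal sketch using the two arrays copied from the comparison above; `first_divergence` is a hypothetical helper name. The two ids that only mlx-lm produces, 151644 and 151645, are the `<|im_start|>` and `<|im_end|>` entries in the tokenizer's added_tokens:

```python
# Token arrays copied from the comparison above.
swift_tokens = [27, 91, 318, 4906, 91, 29, 8948, 198, 2610, 525, 264, 10950,
                17847, 27, 91, 318, 6213, 91, 397, 27, 91, 318, 4906, 91, 29,
                872, 198, 14990, 27, 91, 318, 6213, 91, 397, 27, 91, 318,
                4906, 91, 29, 77091]
python_tokens = [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198,
                 151644, 872, 198, 14990, 151645, 198, 151644, 77091, 198]


def first_divergence(a, b):
    """Index of the first position where two token lists differ, or None if equal."""
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    # One list is a prefix of the other: they diverge where the shorter one ends.
    if len(a) != len(b):
        return min(len(a), len(b))
    return None


print("lengths:", len(swift_tokens), len(python_tokens))  # lengths: 41 19
print("first divergence at index:", first_divergence(swift_tokens, python_tokens))  # 0
print("ids only in the mlx-lm encoding:",
      sorted(set(python_tokens) - set(swift_tokens)))  # [151644, 151645]
```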

Using mlx-lm, I decoded each of the two arrays separately and found that both decode to exactly the same text: the original input.
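Identical decoded text does not mean identical token ids: decoding is many-to-one, so a marker split into ordinary pieces and the same marker encoded as one added token print the same. A toy illustration; the vocabulary and its ids are entirely hypothetical (not Qwen's), and real byte-level decoding also maps bytes back to characters, which this toy skips:

```python
# Toy vocabulary with made-up ids: ordinary pieces plus one added token
# that covers the whole marker.
TOY_VOCAB = {
    0: "<", 1: "|", 2: "im", 3: "_start", 4: ">",
    100: "<|im_start|>",
}


def decode(ids):
    """Concatenate the string for each id (no byte-level unmapping in this toy)."""
    return "".join(TOY_VOCAB[i] for i in ids)


split_ids = [0, 1, 2, 3, 1, 4]  # marker broken into ordinary pieces
added_ids = [100]               # marker as a single added token
print(decode(split_ids) == decode(added_ids))  # True
```

The model, however, only ever sees the ids, so the split version looks like literal text rather than a turn marker.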

From this data, the problem appears to be in the tokenizer's encoding step. Looking at the code, the current logic substitutes PreTrainedTokenizer for Qwen2Tokenizer. I suspect Qwen2Tokenizer has some special handling that PreTrainedTokenizer lacks, so the substitute misbehaves in some cases:

"Qwen2Tokenizer": "PreTrainedTokenizer",

I'm not familiar with this part and hope someone can add full support for Qwen2Tokenizer.

If you feed the Python tokens into the Swift model, does it produce the expected output?

Yeah, it may be that there is more to the Qwen tokenizer. There must be more than a hundred tokenizers out there: https://github.com/huggingface/tokenizers

The PreTrainedTokenizer is pretty generic and it seems to handle quite a bit but maybe not everything.

If you look at tokenizer.json, you can see what some of those tokens are:

  "added_tokens": [
    {
      "id": 151643,
      "content": "<|endoftext|>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 151644,
      "content": "<|im_start|>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 151645,
      "content": "<|im_end|>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    }
  ],
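To handle these markers, a tokenizer first needs the content-to-id map from that section. A minimal Python sketch that reads it, assuming a trimmed tokenizer.json that contains only the three entries above (the single_word, lstrip, rstrip and normalized fields are omitted for brevity); `load_added_tokens` is a hypothetical helper, not a library function:

```python
import json

# Trimmed tokenizer.json holding only the added_tokens entries shown above.
TOKENIZER_JSON = """
{
  "added_tokens": [
    {"id": 151643, "content": "<|endoftext|>", "special": true},
    {"id": 151644, "content": "<|im_start|>", "special": true},
    {"id": 151645, "content": "<|im_end|>", "special": true}
  ]
}
"""


def load_added_tokens(tokenizer_json: str) -> dict:
    """Map each added token's content string to its id."""
    data = json.loads(tokenizer_json)
    return {entry["content"]: entry["id"] for entry in data.get("added_tokens", [])}


added = load_added_tokens(TOKENIZER_JSON)
print(added["<|im_start|>"])  # 151644
```

With a real model directory you would pass the file contents instead, e.g. `load_added_tokens(open("tokenizer.json").read())`.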

Inside the Swift tokenizer, the text is split up like this:

(lldb) p preTokenize(normalize(text))
([String]) 18 values {
  [0] = "<|"
  [1] = "im"
  [2] = "_start"
  [3] = "|>"
  [4] = "user"
  [5] = "Ġcompare"
  [6] = "Ġpython"
  [7] = "Ġand"
  [8] = "Ġswift"
  [9] = "<|"
  [10] = "im"
  [11] = "_end"
  [12] = "|><|"
  [13] = "im"
  [14] = "_start"
  [15] = "|>"
  [16] = "assistant"
  [17] = ""
}

I'm not familiar with how the tokenizers work internally, but it looks like the markers aren't being treated as single tokens, even though I can see the addedTokens being passed in.
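The split above shows `<|im_start|>` broken into `<|`, `im`, `_start`, `|>`. As far as I understand the Hugging Face tokenizers, added tokens are matched in the raw text first, and only the text between them goes through normalization, pre-tokenization and BPE. A minimal Python sketch of that idea, assuming exact matching (the three entries in tokenizer.json have lstrip, rstrip and normalized all false, so the sketch ignores those flags). `split_on_added_tokens`, `encode` and the `fake_bpe` stand-in are hypothetical names, not swift-transformers or tokenizers APIs:

```python
import re


def split_on_added_tokens(text, added_tokens):
    """Split text so that every added token becomes its own segment.

    Longer tokens are tried first so a token that is a prefix of another
    cannot shadow it. Empty segments are dropped.
    """
    tokens = sorted(added_tokens, key=len, reverse=True)
    if not tokens:
        return [text] if text else []
    # A capturing group makes re.split keep the matched tokens in the result.
    pattern = re.compile("(" + "|".join(re.escape(t) for t in tokens) + ")")
    return [piece for piece in pattern.split(text) if piece]


def encode(text, added_ids, encode_plain):
    """Emit added-token ids directly; hand every other segment to encode_plain."""
    ids = []
    for piece in split_on_added_tokens(text, added_ids):
        if piece in added_ids:
            ids.append(added_ids[piece])
        else:
            ids.extend(encode_plain(piece))
    return ids


ADDED_IDS = {"<|endoftext|>": 151643, "<|im_start|>": 151644, "<|im_end|>": 151645}


def fake_bpe(segment):
    """Hypothetical stand-in for the normal BPE pipeline: one id per character."""
    return [ord(ch) for ch in segment]


text = "<|im_start|>user\nhello<|im_end|>"
print(split_on_added_tokens(text, ADDED_IDS))
# ['<|im_start|>', 'user\nhello', '<|im_end|>']
print(encode(text, ADDED_IDS, fake_bpe))
```

With a real BPE in place of `fake_bpe`, `<|im_start|>` would come out as the single id 151644, matching the mlx-lm array, instead of six ordinary pieces.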

It looks like this sort of covers it: huggingface/swift-transformers#4

Not everything is implemented, just enough to cover the cases they tried, so this could potentially be contributed back to the project. I looked at the JavaScript implementation and can see how the added tokens are managed there.

In the PreTrainedTokenizer you can see that some of these are just not wired up:

        // TODO: specialTokens are stored but never used
        self.specialTokens = specialTokens
        self.addedTokens = Set(addedTokens.keys)


To answer the earlier question: I passed the array produced by mlx-lm's encode into the Swift model, and the output is as expected.