ml-explore / mlx-examples

Examples in the MLX framework


Seems like when generating, some memory usage cannot be correctly released.

alexC-nonsense4k opened this issue · comments

I am using this code to generate answers with the CodeLlama2-13b LLM:

# model and tokenizer are assumed to have been loaded earlier, e.g. with mlx_lm.load(...)
for i in range(len(json_list)):
    prompt_raw = json_list[i]['prompt']
    messages = [{"role": "user", "content": prompt_raw}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    response = generate(model, tokenizer, prompt=prompt, max_tokens=4096, temp=0.2, verbose=True)

But during generation, I noticed something interesting.
I am using a model with 13B parameters, so when loading is done the model takes up about 24GB of memory, which is correct.
But while generating, the memory usage keeps increasing.
After the first generation, the total memory usage is around 50GB.
After the second generation, the total memory usage is around 90GB.
The memory usage then stays at 90GB.
I know there is a cache used to speed up the decoder computation, but I think that when a generation finishes, the memory occupied by the cache should be released. However, based on my current usage, that release does not seem to happen. I am not sure whether this issue is due to the code itself or to Python settings not correctly freeing up the memory.

This occurs on an M1 Ultra with 128GB of unified memory, Python 3.10, mlx 0.11.1, and mlx_lm 0.11.0.

It should be reusing the memory in the cache. There are two strange things here:

  1. Using an extra 26GB for generation doesn't make sense unless your prompts are very, very long. Is it a quantized model?
  2. Once it finishes the first generation it should reuse memory already in the cache.

Also, how are you measuring memory use?

I am not using a quantized model, and my prompts are indeed very long (about 1800 tokens per prompt on average).
I also think it is strange that the program continues to ask for more memory after the first generation instead of reusing the cache.
I use Activity Monitor to measure the memory usage.

@awni and @alexC-nonsense4k ,

I've actually noticed this issue as well. I made a small program to demonstrate the problem and in the process tracked it down to MLX's memory caching. I'm not sure if there is a memory leak going on where MLX is caching memory and never reusing it, but the high memory utilization in my case is definitely because of MLX's memory caching.

For the details on how I'm tracking memory, please see the code I published at this repo: https://github.com/kerekovskik/mlx-memory-usage-check

I've included the text file I used for the prompt in the linked repo, so you should be able to reproduce this on your machine by following the README and see how I got the numbers that I got.

The program I wrote uses Meta-Llama-3 8B (full precision) to process a 7522-token prompt. Below are the outputs with and without an MLX cache limit set. Notice the high MLX cache memory usage.
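For context, here is a minimal sketch of how numbers like these can be read out of MLX's metal memory APIs (the actual memory_check.py in the linked repo may differ; the model path, prompt handling, and generation parameters here are placeholders):

import mlx.core as mx
from mlx_lm import load, generate

def mb(nbytes):
    # convert bytes to megabytes for printing
    return nbytes / 1024 / 1024

# Optionally cap MLX's buffer cache (in bytes) before generating, e.g.:
# mx.metal.set_cache_limit(1 * 1024**3)

model, tokenizer = load("/path/to/Meta-Llama-3-8B-Instruct-MLX")  # placeholder path
prompt = open("3blue1brown_attention.txt").read()

print(f"MLX Memory Pre-Generation: {mb(mx.metal.get_active_memory()):.2f} MB")
print(f"MLX Cache Memory Pre-Generation: {mb(mx.metal.get_cache_memory()):.2f} MB")

response = generate(model, tokenizer, prompt=prompt, max_tokens=512)

print(f"MLX Memory Post-Generation: {mb(mx.metal.get_active_memory()):.2f} MB")
print(f"MLX Cache Memory Post-Generation: {mb(mx.metal.get_cache_memory()):.2f} MB")
print(f"Peak MLX Memory: {mb(mx.metal.get_peak_memory()):.2f} MB")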

Run LLM Generation WITHOUT Cache Limit

This is the default behavior of MLX. Notice the high cache memory usage.

python3  memory_check.py --input 3blue1brown_attention.txt --output response.txt --model /Users/kerekovskik/hf/Meta-Llama-3-8B-Instruct-MLX --cache-limit 0
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
MLX Memory Pre-Generation: 15316.51 MB
MLX Cache Memory Pre-Generation: 0.00 MB
MLX Total Memory Pre-Generation: 15316.51 MB
MLX Memory Post-Generation: 16150.47 MB
MLX Cache Memory Post-Generation: 53876.89 MB
MLX Total Memory Post-Generation: 70027.36 MB
Input Tokens: 7522
Peak MLX Memory: 29326.617294311523 MB

During the course of generating a response, it reserved ~53GB of memory in the MLX cache.

Run LLM Generation WITH Cache Limit

Memory usage is successfully controlled by setting a cache limit for MLX.

python3  memory_check.py --input 3blue1brown_attention.txt --output response.txt --model /Users/kerekovskik/hf/Meta-Llama-3-8B-Instruct-MLX --cache-limit 1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
MLX Memory Pre-Generation: 15316.51 MB
MLX Cache Memory Pre-Generation: 0.00 MB
MLX Total Memory Pre-Generation: 15316.51 MB
MLX Memory Post-Generation: 16358.98 MB
MLX Cache Memory Post-Generation: 1024.51 MB
MLX Total Memory Post-Generation: 17383.49 MB
Input Tokens: 7522
Peak MLX Memory: 29326.617294311523 MB

My example is from mlx 0.12.0 and mlx_lm 0.12.0 on a MacBook Pro (M2, 96GB).

It's worth noting that I've seen this memory behavior with a lot of models, but my test was specific to Llama 3 8B. I've noticed this high memory usage when using mlx_lm.generate and mlx_lm.server as well.

@kerekovskik Today I did more experiments on this issue. I guess the main reason mlx_lm allocates so much memory from the system is the KV cache, which costs a lot of memory to speed up the attention computation. I suppose that in Python or in mlx_lm memory is not rearranged very frequently, so there may be enough memory fragmentation to force the program to allocate more memory from the system. But with a cache limit set, the program will rearrange memory when it hits the limit instead of allocating more.

@alexC-nonsense4k ,

I'm not sure I follow the memory fragmentation part of your response. For what it's worth, I've noticed this behavior even when almost nothing is running on my machine after a reboot, which is why I don't think memory fragmentation is a factor here. I wouldn't expect MLX to require 53GB of cache memory for a 7K-token prompt. It seems that MLX is not properly cleaning up its cache allocations or reusing already allocated memory in the cache, since it just accrues more and more cache during token generation, which is why forcing it to clean up past a given quota by setting the cache limit mitigates the issue.

@awni , I hope my repo helps to show the issue. https://github.com/kerekovskik/mlx-memory-usage-check

I'd be happy to contribute some code to mlx-examples to provide a setting for the MLX cache on the generate function and in the MLX_LM server code as well - I've seen this behavior in both places. Unless there is some fix for this within mlx core that you're pursuing.

I feel like I may be experiencing a similar issue when trying to fine-tune the Llama 3 8B model (bf16). The peak memory slowly grows up to 300GB and eventually breaks the fine-tuning.
Some LoRA params:

  • maximum sequence length of 8096
  • batch size 1
Trainable parameters: 4.178% (335.544M/8030.261M)
Loading datasets
Training
Starting training..., iters: 113000
Iter 1: Val loss 1.343, Val took 29.293s
Iter 10: Train loss 0.986, Learning Rate 1.000e-06, It/sec 0.381, Tokens/sec 174.114, Trained Tokens 4568, Peak mem 28.349 GB
Iter 20: Train loss 0.490, Learning Rate 1.000e-06, It/sec 0.368, Tokens/sec 144.697, Trained Tokens 8495, Peak mem 29.992 GB
Iter 30: Train loss 0.637, Learning Rate 1.000e-06, It/sec 0.262, Tokens/sec 164.795, Trained Tokens 14794, Peak mem 45.742 GB
Iter 40: Train loss 0.586, Learning Rate 1.000e-06, It/sec 0.218, Tokens/sec 171.634, Trained Tokens 22678, Peak mem 45.742 GB
Iter 50: Train loss 0.554, Learning Rate 1.000e-06, It/sec 0.276, Tokens/sec 157.820, Trained Tokens 28397, Peak mem 45.742 GB
Iter 60: Train loss 0.496, Learning Rate 1.000e-06, It/sec 0.260, Tokens/sec 159.653, Trained Tokens 34543, Peak mem 45.742 GB
Iter 70: Train loss 0.509, Learning Rate 1.000e-06, It/sec 0.275, Tokens/sec 163.460, Trained Tokens 40495, Peak mem 45.742 GB
Iter 80: Train loss 0.454, Learning Rate 1.000e-06, It/sec 0.231, Tokens/sec 166.602, Trained Tokens 47692, Peak mem 46.582 GB
Iter 90: Train loss 0.507, Learning Rate 1.000e-06, It/sec 0.289, Tokens/sec 154.947, Trained Tokens 53045, Peak mem 46.582 GB
Iter 100: Train loss 0.541, Learning Rate 1.000e-06, It/sec 0.240, Tokens/sec 162.060, Trained Tokens 59799, Peak mem 46.582 GB
Iter 110: Train loss 0.498, Learning Rate 1.000e-06, It/sec 0.267, Tokens/sec 162.180, Trained Tokens 65882, Peak mem 46.582 GB
Iter 120: Train loss 0.580, Learning Rate 1.000e-06, It/sec 0.302, Tokens/sec 152.723, Trained Tokens 70939, Peak mem 46.582 GB
Iter 130: Train loss 0.531, Learning Rate 1.000e-06, It/sec 0.327, Tokens/sec 149.257, Trained Tokens 75501, Peak mem 46.582 GB
Iter 140: Train loss 0.569, Learning Rate 1.000e-06, It/sec 0.232, Tokens/sec 164.464, Trained Tokens 82576, Peak mem 46.582 GB
Iter 150: Train loss 0.549, Learning Rate 1.000e-06, It/sec 0.302, Tokens/sec 150.769, Trained Tokens 87566, Peak mem 46.582 GB
Iter 160: Train loss 0.360, Learning Rate 1.000e-06, It/sec 0.308, Tokens/sec 151.939, Trained Tokens 92507, Peak mem 46.582 GB
Iter 170: Train loss 0.418, Learning Rate 1.000e-06, It/sec 0.275, Tokens/sec 151.748, Trained Tokens 98027, Peak mem 47.609 GB
Iter 180: Train loss 0.381, Learning Rate 1.000e-06, It/sec 0.315, Tokens/sec 149.350, Trained Tokens 102763, Peak mem 47.609 GB
Iter 190: Train loss 0.396, Learning Rate 1.000e-06, It/sec 0.344, Tokens/sec 146.310, Trained Tokens 107013, Peak mem 47.609 GB
Iter 200: Train loss 0.775, Learning Rate 1.000e-06, It/sec 0.224, Tokens/sec 166.830, Trained Tokens 114475, Peak mem 47.609 GB
Iter 200: Val loss 0.515, Val took 31.350s
Iter 210: Train loss 0.484, Learning Rate 1.000e-06, It/sec 0.291, Tokens/sec 155.553, Trained Tokens 119827, Peak mem 47.609 GB
Iter 220: Train loss 0.412, Learning Rate 1.000e-06, It/sec 0.294, Tokens/sec 152.979, Trained Tokens 125022, Peak mem 47.609 GB
Iter 230: Train loss 0.349, Learning Rate 1.000e-06, It/sec 0.375, Tokens/sec 136.702, Trained Tokens 128667, Peak mem 47.609 GB
Iter 240: Train loss 0.443, Learning Rate 1.000e-06, It/sec 0.274, Tokens/sec 157.527, Trained Tokens 134412, Peak mem 47.609 GB
Iter 250: Train loss 0.351, Learning Rate 1.000e-06, It/sec 0.301, Tokens/sec 155.320, Trained Tokens 139568, Peak mem 47.609 GB
Iter 260: Train loss 0.422, Learning Rate 1.000e-06, It/sec 0.255, Tokens/sec 155.295, Trained Tokens 145665, Peak mem 47.609 GB
Iter 270: Train loss 0.286, Learning Rate 1.000e-06, It/sec 0.372, Tokens/sec 141.298, Trained Tokens 149466, Peak mem 47.609 GB
Iter 280: Train loss 0.406, Learning Rate 1.000e-06, It/sec 0.246, Tokens/sec 162.882, Trained Tokens 156100, Peak mem 47.803 GB
Iter 290: Train loss 0.453, Learning Rate 1.000e-06, It/sec 0.295, Tokens/sec 157.641, Trained Tokens 161439, Peak mem 47.803 GB
Iter 300: Train loss 0.437, Learning Rate 1.000e-06, It/sec 0.206, Tokens/sec 166.620, Trained Tokens 169515, Peak mem 56.919 GB
Iter 310: Train loss 0.419, Learning Rate 1.000e-06, It/sec 0.284, Tokens/sec 153.968, Trained Tokens 174928, Peak mem 56.919 GB
Iter 320: Train loss 0.418, Learning Rate 1.000e-06, It/sec 0.292, Tokens/sec 154.294, Trained Tokens 180216, Peak mem 56.919 GB
Iter 330: Train loss 0.417, Learning Rate 1.000e-06, It/sec 0.298, Tokens/sec 150.434, Trained Tokens 185269, Peak mem 56.919 GB
Iter 340: Train loss 0.389, Learning Rate 1.000e-06, It/sec 0.288, Tokens/sec 154.552, Trained Tokens 190637, Peak mem 56.919 GB
Iter 350: Train loss 0.441, Learning Rate 1.000e-06, It/sec 0.262, Tokens/sec 157.700, Trained Tokens 196648, Peak mem 56.919 GB
Iter 360: Train loss 0.516, Learning Rate 1.000e-06, It/sec 0.247, Tokens/sec 159.186, Trained Tokens 203092, Peak mem 61.846 GB
Iter 370: Train loss 0.493, Learning Rate 1.000e-06, It/sec 0.267, Tokens/sec 155.597, Trained Tokens 208911, Peak mem 61.846 GB
Iter 380: Train loss 0.262, Learning Rate 1.000e-06, It/sec 0.344, Tokens/sec 144.697, Trained Tokens 213118, Peak mem 61.846 GB
Iter 390: Train loss 0.582, Learning Rate 1.000e-06, It/sec 0.250, Tokens/sec 160.472, Trained Tokens 219547, Peak mem 61.846 GB
Iter 400: Train loss 0.568, Learning Rate 1.000e-06, It/sec 0.245, Tokens/sec 161.165, Trained Tokens 226128, Peak mem 61.846 GB
Iter 400: Val loss 0.498, Val took 34.168s
Iter 410: Train loss 0.336, Learning Rate 1.000e-06, It/sec 0.356, Tokens/sec 142.795, Trained Tokens 230144, Peak mem 61.846 GB
Iter 420: Train loss 0.311, Learning Rate 1.000e-06, It/sec 0.283, Tokens/sec 153.967, Trained Tokens 235587, Peak mem 61.846 GB
Iter 430: Train loss 0.486, Learning Rate 1.000e-06, It/sec 0.251, Tokens/sec 160.670, Trained Tokens 241979, Peak mem 61.846 GB
Iter 440: Train loss 0.470, Learning Rate 1.000e-06, It/sec 0.265, Tokens/sec 158.942, Trained Tokens 247981, Peak mem 61.846 GB
Iter 450: Train loss 0.371, Learning Rate 1.000e-06, It/sec 0.264, Tokens/sec 156.952, Trained Tokens 253926, Peak mem 63.603 GB
Iter 460: Train loss 0.363, Learning Rate 1.000e-06, It/sec 0.281, Tokens/sec 157.530, Trained Tokens 259537, Peak mem 63.603 GB
Iter 470: Train loss 0.547, Learning Rate 1.000e-06, It/sec 0.215, Tokens/sec 165.215, Trained Tokens 267216, Peak mem 73.275 GB
Iter 480: Train loss 0.395, Learning Rate 1.000e-06, It/sec 0.350, Tokens/sec 141.575, Trained Tokens 271261, Peak mem 73.275 GB
Iter 490: Train loss 0.487, Learning Rate 1.000e-06, It/sec 0.257, Tokens/sec 156.983, Trained Tokens 277374, Peak mem 73.275 GB
Iter 500: Train loss 0.350, Learning Rate 1.000e-06, It/sec 0.315, Tokens/sec 148.127, Trained Tokens 282079, Peak mem 73.275 GB
Iter 510: Train loss 0.420, Learning Rate 1.000e-06, It/sec 0.317, Tokens/sec 149.263, Trained Tokens 286788, Peak mem 73.275 GB
Iter 520: Train loss 0.307, Learning Rate 1.000e-06, It/sec 0.352, Tokens/sec 144.492, Trained Tokens 290897, Peak mem 73.275 GB
Iter 530: Train loss 0.584, Learning Rate 1.000e-06, It/sec 0.258, Tokens/sec 159.787, Trained Tokens 297085, Peak mem 73.275 GB
Iter 540: Train loss 0.396, Learning Rate 1.000e-06, It/sec 0.296, Tokens/sec 155.352, Trained Tokens 302340, Peak mem 73.275 GB
Iter 550: Train loss 0.228, Learning Rate 1.000e-06, It/sec 0.429, Tokens/sec 135.418, Trained Tokens 305497, Peak mem 73.275 GB
Iter 560: Train loss 0.264, Learning Rate 1.000e-06, It/sec 0.269, Tokens/sec 151.848, Trained Tokens 311148, Peak mem 76.803 GB
Iter 570: Train loss 0.389, Learning Rate 1.000e-06, It/sec 0.240, Tokens/sec 162.426, Trained Tokens 317925, Peak mem 76.803 GB
Iter 580: Train loss 0.475, Learning Rate 1.000e-06, It/sec 0.292, Tokens/sec 156.218, Trained Tokens 323283, Peak mem 76.803 GB
Iter 590: Train loss 0.357, Learning Rate 1.000e-06, It/sec 0.337, Tokens/sec 148.189, Trained Tokens 327682, Peak mem 76.803 GB
Iter 600: Train loss 0.459, Learning Rate 1.000e-06, It/sec 0.288, Tokens/sec 154.027, Trained Tokens 333033, Peak mem 76.803 GB
Iter 600: Val loss 0.457, Val took 28.372s
Iter 610: Train loss 0.353, Learning Rate 1.000e-06, It/sec 0.342, Tokens/sec 147.739, Trained Tokens 337358, Peak mem 76.803 GB
Iter 620: Train loss 0.458, Learning Rate 1.000e-06, It/sec 0.239, Tokens/sec 162.011, Trained Tokens 344141, Peak mem 76.803 GB
Iter 630: Train loss 0.395, Learning Rate 1.000e-06, It/sec 0.328, Tokens/sec 145.892, Trained Tokens 348593, Peak mem 76.803 GB
Iter 640: Train loss 0.219, Learning Rate 1.000e-06, It/sec 0.304, Tokens/sec 150.470, Trained Tokens 353548, Peak mem 76.803 GB
Iter 650: Train loss 0.405, Learning Rate 1.000e-06, It/sec 0.290, Tokens/sec 154.072, Trained Tokens 358866, Peak mem 76.803 GB
Iter 660: Train loss 0.509, Learning Rate 1.000e-06, It/sec 0.267, Tokens/sec 157.297, Trained Tokens 364755, Peak mem 76.803 GB
Iter 670: Train loss 0.279, Learning Rate 1.000e-06, It/sec 0.336, Tokens/sec 146.709, Trained Tokens 369123, Peak mem 76.803 GB
Iter 680: Train loss 0.369, Learning Rate 1.000e-06, It/sec 0.326, Tokens/sec 150.740, Trained Tokens 373752, Peak mem 76.803 GB
Iter 690: Train loss 0.336, Learning Rate 1.000e-06, It/sec 0.329, Tokens/sec 147.995, Trained Tokens 378248, Peak mem 76.803 GB
Iter 700: Train loss 0.423, Learning Rate 1.000e-06, It/sec 0.342, Tokens/sec 146.086, Trained Tokens 382521, Peak mem 76.803 GB
Iter 710: Train loss 0.286, Learning Rate 1.000e-06, It/sec 0.318, Tokens/sec 146.715, Trained Tokens 387129, Peak mem 76.803 GB
Iter 720: Train loss 0.541, Learning Rate 1.000e-06, It/sec 0.195, Tokens/sec 167.106, Trained Tokens 395711, Peak mem 84.082 GB
Iter 730: Train loss 0.301, Learning Rate 1.000e-06, It/sec 0.259, Tokens/sec 155.923, Trained Tokens 401735, Peak mem 84.082 GB
Iter 740: Train loss 0.499, Learning Rate 1.000e-06, It/sec 0.271, Tokens/sec 156.532, Trained Tokens 407513, Peak mem 84.082 GB
Iter 750: Train loss 0.385, Learning Rate 1.000e-06, It/sec 0.330, Tokens/sec 145.641, Trained Tokens 411925, Peak mem 84.082 GB
Iter 760: Train loss 0.401, Learning Rate 1.000e-06, It/sec 0.319, Tokens/sec 150.321, Trained Tokens 416631, Peak mem 84.082 GB
Iter 770: Train loss 0.406, Learning Rate 1.000e-06, It/sec 0.294, Tokens/sec 153.838, Trained Tokens 421858, Peak mem 84.082 GB
Iter 780: Train loss 0.417, Learning Rate 1.000e-06, It/sec 0.222, Tokens/sec 167.364, Trained Tokens 429393, Peak mem 84.082 GB
Iter 790: Train loss 0.453, Learning Rate 1.000e-06, It/sec 0.266, Tokens/sec 157.032, Trained Tokens 435294, Peak mem 84.082 GB
Iter 800: Train loss 0.355, Learning Rate 1.000e-06, It/sec 0.321, Tokens/sec 148.496, Trained Tokens 439921, Peak mem 84.082 GB
Iter 800: Val loss 0.538, Val took 38.868s
Iter 810: Train loss 0.306, Learning Rate 1.000e-06, It/sec 0.310, Tokens/sec 147.321, Trained Tokens 444674, Peak mem 84.082 GB
Iter 820: Train loss 0.303, Learning Rate 1.000e-06, It/sec 0.232, Tokens/sec 155.179, Trained Tokens 451365, Peak mem 100.759 GB
Iter 830: Train loss 0.341, Learning Rate 1.000e-06, It/sec 0.296, Tokens/sec 153.921, Trained Tokens 456563, Peak mem 100.759 GB
Iter 840: Train loss 0.161, Learning Rate 1.000e-06, It/sec 0.353, Tokens/sec 141.079, Trained Tokens 460565, Peak mem 100.759 GB
Iter 850: Train loss 0.349, Learning Rate 1.000e-06, It/sec 0.258, Tokens/sec 163.112, Trained Tokens 466896, Peak mem 100.759 GB
Iter 860: Train loss 0.413, Learning Rate 1.000e-06, It/sec 0.247, Tokens/sec 159.343, Trained Tokens 473342, Peak mem 100.759 GB
Iter 870: Train loss 0.271, Learning Rate 1.000e-06, It/sec 0.319, Tokens/sec 144.595, Trained Tokens 477870, Peak mem 100.759 GB
Iter 880: Train loss 0.320, Learning Rate 1.000e-06, It/sec 0.358, Tokens/sec 138.773, Trained Tokens 481743, Peak mem 100.759 GB
Iter 890: Train loss 0.413, Learning Rate 1.000e-06, It/sec 0.403, Tokens/sec 136.836, Trained Tokens 485136, Peak mem 100.759 GB
Iter 900: Train loss 0.236, Learning Rate 1.000e-06, It/sec 0.382, Tokens/sec 134.347, Trained Tokens 488652, Peak mem 100.759 GB
Iter 910: Train loss 0.456, Learning Rate 1.000e-06, It/sec 0.273, Tokens/sec 157.334, Trained Tokens 494405, Peak mem 100.759 GB
Iter 920: Train loss 0.269, Learning Rate 1.000e-06, It/sec 0.340, Tokens/sec 141.056, Trained Tokens 498553, Peak mem 100.759 GB
Iter 930: Train loss 0.411, Learning Rate 1.000e-06, It/sec 0.243, Tokens/sec 160.938, Trained Tokens 505173, Peak mem 100.759 GB
Iter 940: Train loss 0.299, Learning Rate 1.000e-06, It/sec 0.310, Tokens/sec 151.622, Trained Tokens 510059, Peak mem 100.759 GB
Iter 950: Train loss 0.222, Learning Rate 1.000e-06, It/sec 0.402, Tokens/sec 137.520, Trained Tokens 513478, Peak mem 100.759 GB
Iter 960: Train loss 0.370, Learning Rate 1.000e-06, It/sec 0.233, Tokens/sec 161.763, Trained Tokens 520416, Peak mem 100.759 GB
Iter 970: Train loss 0.357, Learning Rate 1.000e-06, It/sec 0.270, Tokens/sec 156.157, Trained Tokens 526200, Peak mem 100.759 GB
Iter 980: Train loss 0.511, Learning Rate 1.000e-06, It/sec 0.250, Tokens/sec 162.201, Trained Tokens 532697, Peak mem 100.759 GB
Iter 990: Train loss 0.474, Learning Rate 1.000e-06, It/sec 0.256, Tokens/sec 158.522, Trained Tokens 538879, Peak mem 100.759 GB
Iter 1000: Train loss 0.431, Learning Rate 1.000e-06, It/sec 0.279, Tokens/sec 156.410, Trained Tokens 544488, Peak mem 100.759 GB
Iter 1000: Val loss 0.448, Val took 25.538s
Iter 1000: Saved adapter weights to adapters/adapters.safetensors and adapters/0001000_adapters.safetensors.
Iter 1010: Train loss 0.369, Learning Rate 1.000e-06, It/sec 0.248, Tokens/sec 143.717, Trained Tokens 550285, Peak mem 100.759 GB
Iter 1020: Train loss 0.321, Learning Rate 1.000e-06, It/sec 0.322, Tokens/sec 147.624, Trained Tokens 554867, Peak mem 100.759 GB
Iter 1030: Train loss 0.492, Learning Rate 1.000e-06, It/sec 0.375, Tokens/sec 141.280, Trained Tokens 558630, Peak mem 100.759 GB
Iter 1040: Train loss 0.455, Learning Rate 1.000e-06, It/sec 0.234, Tokens/sec 162.958, Trained Tokens 565587, Peak mem 100.759 GB
Iter 1050: Train loss 0.494, Learning Rate 1.000e-06, It/sec 0.203, Tokens/sec 168.376, Trained Tokens 573869, Peak mem 100.759 GB
Iter 1060: Train loss 0.477, Learning Rate 1.000e-06, It/sec 0.239, Tokens/sec 159.817, Trained Tokens 580546, Peak mem 100.759 GB
Iter 1070: Train loss 0.248, Learning Rate 1.000e-06, It/sec 0.373, Tokens/sec 134.322, Trained Tokens 584149, Peak mem 100.759 GB
Iter 1080: Train loss 0.273, Learning Rate 1.000e-06, It/sec 0.351, Tokens/sec 143.903, Trained Tokens 588254, Peak mem 100.759 GB
Iter 1090: Train loss 0.405, Learning Rate 1.000e-06, It/sec 0.291, Tokens/sec 153.531, Trained Tokens 593531, Peak mem 100.759 GB
Iter 1100: Train loss 0.278, Learning Rate 1.000e-06, It/sec 0.296, Tokens/sec 153.472, Trained Tokens 598722, Peak mem 100.759 GB
Iter 1110: Train loss 0.490, Learning Rate 1.000e-06, It/sec 0.271, Tokens/sec 154.934, Trained Tokens 604444, Peak mem 100.759 GB
Iter 1120: Train loss 0.400, Learning Rate 1.000e-06, It/sec 0.325, Tokens/sec 146.586, Trained Tokens 608949, Peak mem 100.759 GB
Iter 1130: Train loss 0.365, Learning Rate 1.000e-06, It/sec 0.322, Tokens/sec 148.295, Trained Tokens 613550, Peak mem 100.759 GB
Iter 1140: Train loss 0.440, Learning Rate 1.000e-06, It/sec 0.177, Tokens/sec 168.750, Trained Tokens 623092, Peak mem 105.648 GB
Iter 1150: Train loss 0.303, Learning Rate 1.000e-06, It/sec 0.219, Tokens/sec 163.935, Trained Tokens 630565, Peak mem 105.648 GB
Iter 1160: Train loss 0.218, Learning Rate 1.000e-06, It/sec 0.321, Tokens/sec 145.202, Trained Tokens 635086, Peak mem 105.648 GB
Iter 1170: Train loss 0.388, Learning Rate 1.000e-06, It/sec 0.274, Tokens/sec 158.218, Trained Tokens 640867, Peak mem 105.648 GB
Iter 1180: Train loss 0.432, Learning Rate 1.000e-06, It/sec 0.323, Tokens/sec 149.960, Trained Tokens 645513, Peak mem 105.648 GB
Iter 1190: Train loss 0.612, Learning Rate 1.000e-06, It/sec 0.219, Tokens/sec 165.295, Trained Tokens 653060, Peak mem 105.648 GB
Iter 1200: Train loss 0.207, Learning Rate 1.000e-06, It/sec 0.315, Tokens/sec 146.334, Trained Tokens 657700, Peak mem 105.648 GB
Iter 1200: Val loss 0.502, Val took 25.981s
Iter 1210: Train loss 0.411, Learning Rate 1.000e-06, It/sec 0.319, Tokens/sec 148.936, Trained Tokens 662373, Peak mem 105.648 GB
Iter 1220: Train loss 0.363, Learning Rate 1.000e-06, It/sec 0.327, Tokens/sec 147.686, Trained Tokens 666892, Peak mem 105.648 GB
Iter 1230: Train loss 0.364, Learning Rate 1.000e-06, It/sec 0.289, Tokens/sec 157.080, Trained Tokens 672319, Peak mem 105.648 GB
Iter 1240: Train loss 0.516, Learning Rate 1.000e-06, It/sec 0.290, Tokens/sec 151.550, Trained Tokens 677544, Peak mem 105.648 GB
Iter 1250: Train loss 0.451, Learning Rate 1.000e-06, It/sec 0.241, Tokens/sec 163.533, Trained Tokens 684339, Peak mem 105.648 GB
Iter 1260: Train loss 0.280, Learning Rate 1.000e-06, It/sec 0.308, Tokens/sec 152.698, Trained Tokens 689292, Peak mem 105.648 GB
Iter 1270: Train loss 0.407, Learning Rate 1.000e-06, It/sec 0.259, Tokens/sec 157.758, Trained Tokens 695372, Peak mem 105.648 GB
Iter 1280: Train loss 0.374, Learning Rate 1.000e-06, It/sec 0.301, Tokens/sec 150.997, Trained Tokens 700389, Peak mem 105.648 GB
Iter 1290: Train loss 0.225, Learning Rate 1.000e-06, It/sec 0.369, Tokens/sec 137.591, Trained Tokens 704122, Peak mem 105.648 GB
Iter 1300: Train loss 0.272, Learning Rate 1.000e-06, It/sec 0.294, Tokens/sec 152.828, Trained Tokens 709323, Peak mem 105.648 GB
Iter 1310: Train loss 0.439, Learning Rate 1.000e-06, It/sec 0.286, Tokens/sec 152.748, Trained Tokens 714657, Peak mem 105.648 GB
Iter 1320: Train loss 0.523, Learning Rate 1.000e-06, It/sec 0.245, Tokens/sec 162.473, Trained Tokens 721302, Peak mem 105.648 GB
Iter 1330: Train loss 0.279, Learning Rate 1.000e-06, It/sec 0.322, Tokens/sec 143.763, Trained Tokens 725770, Peak mem 105.648 GB
Iter 1340: Train loss 0.421, Learning Rate 1.000e-06, It/sec 0.320, Tokens/sec 148.320, Trained Tokens 730399, Peak mem 105.648 GB
Iter 1350: Train loss 0.458, Learning Rate 1.000e-06, It/sec 0.214, Tokens/sec 165.950, Trained Tokens 738151, Peak mem 105.648 GB
Iter 1360: Train loss 0.410, Learning Rate 1.000e-06, It/sec 0.275, Tokens/sec 153.886, Trained Tokens 743738, Peak mem 105.648 GB
Iter 1370: Train loss 0.376, Learning Rate 1.000e-06, It/sec 0.300, Tokens/sec 149.599, Trained Tokens 748730, Peak mem 105.648 GB
Iter 1380: Train loss 0.382, Learning Rate 1.000e-06, It/sec 0.316, Tokens/sec 150.644, Trained Tokens 753494, Peak mem 105.648 GB
Iter 1390: Train loss 0.464, Learning Rate 1.000e-06, It/sec 0.289, Tokens/sec 155.151, Trained Tokens 758870, Peak mem 105.648 GB
Iter 1400: Train loss 0.379, Learning Rate 1.000e-06, It/sec 0.331, Tokens/sec 146.173, Trained Tokens 763284, Peak mem 105.648 GB
Iter 1400: Val loss 0.471, Val took 32.143s
Iter 1410: Train loss 0.317, Learning Rate 1.000e-06, It/sec 0.340, Tokens/sec 148.325, Trained Tokens 767644, Peak mem 105.648 GB
Iter 1420: Train loss 0.623, Learning Rate 1.000e-06, It/sec 0.217, Tokens/sec 166.303, Trained Tokens 775300, Peak mem 105.648 GB
Iter 1430: Train loss 0.289, Learning Rate 1.000e-06, It/sec 0.334, Tokens/sec 147.963, Trained Tokens 779731, Peak mem 105.648 GB
Iter 1440: Train loss 0.386, Learning Rate 1.000e-06, It/sec 0.249, Tokens/sec 161.376, Trained Tokens 786205, Peak mem 105.648 GB
Iter 1450: Train loss 0.304, Learning Rate 1.000e-06, It/sec 0.255, Tokens/sec 159.659, Trained Tokens 792464, Peak mem 105.648 GB
Iter 1460: Train loss 0.212, Learning Rate 1.000e-06, It/sec 0.324, Tokens/sec 147.236, Trained Tokens 797008, Peak mem 105.648 GB
Iter 1470: Train loss 0.360, Learning Rate 1.000e-06, It/sec 0.271, Tokens/sec 160.529, Trained Tokens 802931, Peak mem 105.648 GB
Iter 1480: Train loss 0.310, Learning Rate 1.000e-06, It/sec 0.314, Tokens/sec 148.951, Trained Tokens 807673, Peak mem 105.648 GB
Iter 1490: Train loss 0.449, Learning Rate 1.000e-06, It/sec 0.307, Tokens/sec 148.810, Trained Tokens 812519, Peak mem 105.648 GB
Iter 1500: Train loss 0.529, Learning Rate 1.000e-06, It/sec 0.269, Tokens/sec 154.623, Trained Tokens 818263, Peak mem 105.648 GB
Iter 1510: Train loss 0.425, Learning Rate 1.000e-06, It/sec 0.290, Tokens/sec 156.487, Trained Tokens 823654, Peak mem 105.648 GB
Iter 1520: Train loss 0.493, Learning Rate 1.000e-06, It/sec 0.231, Tokens/sec 164.658, Trained Tokens 830796, Peak mem 105.648 GB
Iter 1530: Train loss 0.297, Learning Rate 1.000e-06, It/sec 0.415, Tokens/sec 131.330, Trained Tokens 833961, Peak mem 105.648 GB
Iter 1540: Train loss 0.429, Learning Rate 1.000e-06, It/sec 0.169, Tokens/sec 169.992, Trained Tokens 844027, Peak mem 110.028 GB
Iter 1550: Train loss 0.527, Learning Rate 1.000e-06, It/sec 0.227, Tokens/sec 161.510, Trained Tokens 851145, Peak mem 110.028 GB
Iter 1560: Train loss 0.425, Learning Rate 1.000e-06, It/sec 0.308, Tokens/sec 148.905, Trained Tokens 855980, Peak mem 110.028 GB
Iter 1570: Train loss 0.633, Learning Rate 1.000e-06, It/sec 0.190, Tokens/sec 169.923, Trained Tokens 864905, Peak mem 110.365 GB
Iter 1580: Train loss 0.177, Learning Rate 1.000e-06, It/sec 0.418, Tokens/sec 131.500, Trained Tokens 868052, Peak mem 110.365 GB
Iter 1590: Train loss 0.491, Learning Rate 1.000e-06, It/sec 0.291, Tokens/sec 152.670, Trained Tokens 873294, Peak mem 110.365 GB
Iter 1600: Train loss 0.508, Learning Rate 1.000e-06, It/sec 0.249, Tokens/sec 158.610, Trained Tokens 879667, Peak mem 110.365 GB
Iter 1600: Val loss 0.540, Val took 32.750s
Iter 1610: Train loss 0.505, Learning Rate 1.000e-06, It/sec 0.282, Tokens/sec 158.663, Trained Tokens 885293, Peak mem 110.365 GB

The Peak mem gets to 300GB? That's very unexpected unless the sequence length is long. The peak memory is not related to cache behavior at all; it's really just the largest working set size needed to execute the graph evaluation. So this doesn't sound like the same issue as this thread.

@mzbac would you mind opening a separate issue and include details on the model / command you are using? We should check if 300GB makes sense for the model size and sequence length you have.

@alexC-nonsense4k and @kerekovskik thanks for the info and sorry for the delay in investigating this. It's definitely something we plan to look into.

For now, one possible workaround is to clear the cache after each generation with mx.metal.clear_cache() (needs mlx 0.12). Hopefully we can find a more elegant solution.
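Applied to the loop from the first comment, that looks roughly like this (a sketch, assuming model, tokenizer, and json_list are set up as in that snippet):

import mlx.core as mx
from mlx_lm import generate

for item in json_list:
    messages = [{"role": "user", "content": item["prompt"]}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    response = generate(model, tokenizer, prompt=prompt, max_tokens=4096, temp=0.2, verbose=True)
    # Release the buffers MLX is holding in its memory cache before the next prompt.
    mx.metal.clear_cache()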

@awni ,

No problem on the delay, I wasn't actually expecting a response on the weekend, I appreciate you responding.

Doing a clear_cache() would help clear the 50+GB of cache memory that a single generation accrued, but it doesn't resolve my issue of a single mlx_lm.generate call taking up an inordinate amount of memory. The only solution I've found so far has been to explicitly set a finite cache limit prior to calling mlx_lm.generate(). Please check the GitHub repo I made that demos this workaround.

Is the problem perhaps that MLX is not setting a good cache limit by default, without user intervention? I first noticed this issue back in MLX 0.10.0, when the MLX cache limit defaulted to 0, which signified an unlimited cache size in that version of MLX.

When I check the docs on mlx cache size for 0.12.0 at https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.metal.set_cache_limit.html I see that it states "The cache limit defaults to the memory limit. See set_memory_limit() for more details."

and when I go further and check the set_memory_limit docs they say "The memory limit defaults to 1.5 times the maximum recommended working set size reported by the device."

How does one check the 'working set size' and determine what that memory limit would be at runtime, and by extension what the default cache limit would be? I don't see any way to get the cache limit settings, only ways to set them. I am suspicious that this default value is extremely high, or that, alternatively, despite what the documentation states, the default behavior is to not limit the cache size like it used to be in MLX 0.10.0.

I personally would like to have the ability to provide a cache limit to the generate function, and to be able to provide it as a CLI flag to mlx_lm.server, as a workaround for this issue and also just in general to control memory usage. Please let me know if you'd be open to that functionality being put into a PR; I'd be happy to code it up and submit one.

@awni Thanks for your reply. I hope this problem can be solved soon.

How does one check the 'working set size' and determine what that memory limit would be at runtime, and by extension what the default cache limit would be?

You can check it like this:

import mlx.core as mx

# set_memory_limit returns the previously configured limit (in bytes)
old_limit = mx.metal.set_memory_limit(0)
print(old_limit)

The default may be too high, but that's one reason we made those limits configurable.

I personally would like to have the ability to provide a cache limit to the generate function, and to be able to provide it as a CLI flag to mlx_lm.server, as a workaround for this issue and also just in general to control memory usage. Please let me know if you'd be open to that functionality being put into a PR; I'd be happy to code it up and submit one.

I'm open to it. Though you can always set the cache limit in the calling code before you call generate. For the CLI, maybe it is more useful to control the cache size in MLX LM with a flag.
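For example, setting it in the calling code would look something like this (a sketch; the 4GB figure and model path are placeholders):

import mlx.core as mx
from mlx_lm import load, generate

# Cap MLX's buffer cache at 4GB before generating; set_cache_limit takes bytes
# and returns the previously configured limit.
previous_limit = mx.metal.set_cache_limit(4 * 1024**3)

model, tokenizer = load("/path/to/model")  # placeholder path
response = generate(model, tokenizer, prompt="Hello", max_tokens=64)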

Thanks, I will check what the working set size is on my device and see whether it defaults to too high a value.

I will work on a PR to add a CLI flag to mlx_lm.generate.py and mlx_lm.server.py for setting a cache size limit. I think it will be especially useful in the mlx_lm.server.py code. I will drop a reference to that PR once I've made it.

Thanks again!

Looking at the 7k prompt you shared @kerekovskik, some observations:

  • The cache starts out at about 20GB, just from the prompt.
  • The cache size grows every ~5-10 tokens. The growth in the cache size is consistent with the KV cache size (about 1GB for this model with a 7K prompt; see the quick arithmetic check below).
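As a quick sanity check on that figure, assuming Llama-3-8B's standard configuration (32 layers, 8 KV heads with GQA, head dimension 128, 16-bit cache entries):

layers, kv_heads, head_dim = 32, 8, 128
tokens, bytes_per_elem = 7522, 2   # 7522-token prompt, bf16 cache entries
kv_bytes = 2 * layers * kv_heads * head_dim * tokens * bytes_per_elem  # keys and values
print(kv_bytes / 1024**3)          # ~0.92 GiB, i.e. about 1GB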

In terms of improvements:

  • The prompt memory use should be reduced a lot by flash attention (ml-explore/mlx#964), which avoids materializing the full N^2 score matrices.
  • Generation is somewhat adversarial for our memory cache, since we grow the memory requirement for the KV cache just enough every few tokens that nothing in the existing cache fits anymore. One way to improve this is with a pre-allocated KV cache: #643. I haven't landed that PR yet because it is slightly slower for short prompts, but I think we should push on it, as it helps a lot with long prompts.

In the meantime/in addition I think adding a cache limit option to MLX LM is a good idea.

@awni , I've submitted PR #744 for this functionality. It adds a --cache-limit-gb CLI flag to mlx_lm.generate and mlx_lm.server. It defaults to None, which doesn't alter existing behavior. Per your earlier comment, for any programmatic use of mlx_lm.generate, the calling code can set an MLX cache limit if needed prior to calling the generate function.
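Conceptually, the flag just converts gigabytes to bytes and hands them to MLX before generation starts, along these lines (a sketch of the idea, not the PR's exact code):

import mlx.core as mx

def maybe_set_cache_limit(cache_limit_gb):
    # None leaves MLX's default cache behavior untouched.
    if cache_limit_gb is not None:
        mx.metal.set_cache_limit(int(cache_limit_gb * 1024**3))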

@awni ,

Regarding the working set size: thank you for that code snippet. I checked by running the code you gave me, converting to GB, and dividing by 1.5, since the docs at https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.metal.set_memory_limit.html#mlx.core.metal.set_memory_limit state that the memory limit is set to 1.5 times the maximum working set size.

>>> import mlx.core as mx
>>> old_limit = mx.metal.set_memory_limit(0)
>>> print(old_limit/1024/1024/1024) 
131.8359375
>>> print(old_limit/1024/1024/1024/1.5)
87.890625

So that shows I have a maximum recommended working set size of ~87.89GB, and the memory limit is being set to ~131.83GB. As a result, the default cache limit is also ~131.83GB, since its default matches the memory limit. I'm on a MacBook Pro with 96GB of memory and an M2 processor. Does it make sense for the default cache size to be set higher than the amount of memory on the system? With the default settings, it basically guarantees that MLX can attempt to use all of the memory for caching and, in practice, operates as if there were no limit on the cache size.

It doesn't make sense at all. The problem is that we tuned that multiplier on a machine with less RAM to be close to the limit. We need to find a better way to set it. I will look into it.

A few updates:

  • The default memory limits have been changed in MLX core to always be less than the device's total RAM
  • The pre-allocated KV cache has landed and really improves cache memory growth for longer generations: see #643
  • Landed #744 to give users the ability to set the cache limit

I think we are good to close this. We'll continue to improve memory use with things like flash attention, but I think this is no longer as acute an issue.

Thanks for the updates and support, @awni .