Performance regression with CUDA after commit 9c67c277
rgerganov opened this issue
I am observing a performance regression with the CUDA backend after commit 9c67c27.
I used to generate 48 t/s with TinyLLama1.1 before this commit:
$ bin/main -m ../models/tinyllama-1b/ggml-model-f16.gguf -p "Hello, my name is" --repeat-penalty 1.0 -n 64 -ngl 99
...
llama_print_timings: eval time = 1291,42 ms / 63 runs ( 20,50 ms per token, 48,78 tokens per second)
...
After commit 9c67c27 I am getting about 36 t/s without flash attention (which is the default):
$ bin/main -m ../models/tinyllama-1b/ggml-model-f16.gguf -p "Hello, my name is" --repeat-penalty 1.0 -n 64 -ngl 99
...
llama_print_timings: eval time = 1742,03 ms / 63 runs ( 27,65 ms per token, 36,16 tokens per second)
...
With FA enabled I am getting 58 t/s, which is great, but we shouldn't have this regression when FA is disabled.
Does using an F32 mask restore the performance? Never mind, the mask is already F32 without FA:
diff --git a/llama.cpp b/llama.cpp
index 18d6297c..bc2d6f25 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6911,7 +6911,7 @@ struct llm_build_context {
}
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
ggml_set_input(lctx.inp_KQ_mask);
- return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
+ return lctx.inp_KQ_mask;
}
struct ggml_tensor * build_inp_KQ_pos(bool causal = true) {
What GPU is this (`nvidia-smi`)?
This change doesn't fix it.
This is the output from `nvidia-smi`:
nvidia-smi
Mon May 13 14:42:56 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04 Driver Version: 535.171.04 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA T1200 Laptop GPU Off | 00000000:01:00.0 Off | N/A |
| N/A 59C P8 7W / 40W | 573MiB / 4096MiB | 22% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
Does this help?
diff --git a/llama.cpp b/llama.cpp
index adbcc07e..e81122db 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -11510,7 +11510,7 @@ static int llama_decode_internal(
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
- kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
+ kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
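For intuition, here is a hedged sketch (hypothetical numbers, with a standalone helper standing in for the GGML_PAD macro) of what the two padding values do to kv_self.n after a short generation: the used-cell count is rounded up to a multiple of the padding, so with 256 the model attends at least 256 KV cells even when only ~70 are in use, while 32 keeps the attended range close to actual usage.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Round x up to the next multiple of n (the round-up behavior GGML_PAD is used for here).
static uint32_t pad(uint32_t x, uint32_t n) { return ((x + n - 1) / n) * n; }

int main() {
    const uint32_t kv_size  = 4096; // total KV cache size (hypothetical)
    const uint32_t cell_max = 70;   // cells actually in use after a short generation (hypothetical)

    // Padding to 256: at least 256 cells are attended, even for tiny contexts.
    const uint32_t n_256 = std::min(kv_size, std::max(256u, pad(cell_max, 256)));
    // Padding to 32: the attended range stays close to real usage.
    const uint32_t n_32  = std::min(kv_size, std::max(32u,  pad(cell_max, 32)));

    printf("pad 256 -> attend %u cells, pad 32 -> attend %u cells\n", n_256, n_32);
    // prints: pad 256 -> attend 256 cells, pad 32 -> attend 96 cells
    return 0;
}
```

The larger minimum means extra mask/attention work over mostly empty cells on every generated token, which is consistent with the benchmarks below, where pp is essentially unchanged and only tg regresses.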
Yes, this fixes the regression; I am getting the same tg speed as before.
Can confirm: I can reproduce the regression, and slaren's patch fixes it.
build: 952d03d (2770), no changes:
| model | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- |
| llama 7B Q4_0 | CUDA | 99 | pp 512 | 4074.65 ± 59.65 |
| llama 7B Q4_0 | CUDA | 99 | tg 128 | 138.66 ± 0.33 |
build: 9c67c27 (2771), no changes:
| model | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- |
| llama 7B Q4_0 | CUDA | 99 | pp 512 | 4060.25 ± 64.17 |
| llama 7B Q4_0 | CUDA | 99 | tg 128 | 135.92 ± 0.62 |
build: 9c67c27 (2771), pad to 32 instead of 256:
| model | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- |
| llama 7B Q4_0 | CUDA | 99 | pp 512 | 4061.76 ± 57.34 |
| llama 7B Q4_0 | CUDA | 99 | tg 128 | 138.32 ± 0.46 |
I think the way to solve this is to condition the padding on whether or not FlashAttention is being used. With a lower padding value you would need to either do runtime checks or compute more exponential functions in the CUDA FA kernels, both of which would be expensive.
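A sketch of what that conditioning could look like (the helper name is hypothetical, not the actual patch; the padding values are taken from the discussion above):

```cpp
#include <cstdint>

// Hypothetical helper: choose the KV cache padding depending on whether
// FlashAttention is enabled. Per the discussion above, the CUDA FA kernels
// are built around a padding of 256, while the non-FA path only needs 32.
static uint32_t kv_cache_padding(bool flash_attn) {
    return flash_attn ? 256u : 32u;
}

// How llama_decode_internal could use it (same expression as in the diff above):
//   const uint32_t pad = kv_cache_padding(cparams.flash_attn);
//   kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
```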
Yeah, I underestimated the effect that this can have - it's less than 1% on RTX 2060 and Metal. Will push a fix.
fixed