ggerganov / llama.cpp

LLM inference in C/C++

Flash attention implementations do not handle case where value vectors have different dimension from query vectors

fairydreaming opened this issue

For example, in ggml.c the implementations of the ops related to flash attention declare a variable D and use it both as the dimension of the value vectors and as the dimension of the key/query vectors. This will fail for models where query and value vectors have different lengths (for example DeepSeek-V2).
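To make the mismatch concrete, here is a minimal illustration; the head sizes below are assumptions based on DeepSeek-V2-style attention, where the query/key head dimension is larger than the value head dimension:

// Illustrative head dimensions (assumed, DeepSeek-V2-style attention):
// query/key heads and value heads do not have the same length.
const int64_t neq0 = 192;  // q->ne[0], query head dimension
const int64_t nek0 = 192;  // k->ne[0], key head dimension
const int64_t nev0 = 128;  // v->ne[0], value head dimension

// The op currently derives a single D from the query tensor and uses it
// for the value vectors as well, which is wrong whenever neq0 != nev0.
const int64_t D = neq0;    // 192 here, but value vectors have length 128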

Below are selected fragments of the GGML_OP_FLASH_ATTN_EXT op implementation to illustrate the problem.

Creation of result tensor:

llama.cpp/ggml.c

Lines 6792 to 6793 in 51e9d02

int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

(Note that the query tensor dimensions are used everywhere, while in reality ne[0] should be equal to ne[0] of the value tensor, because the attention output is a linear combination of value vectors.)
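A minimal sketch of how the allocation could look instead, assuming the value tensor is available in this scope as v (only q and ctx appear in the fragment above): the first dimension of the result is taken from the value tensor rather than the query tensor.

// sketch: each output row has the length of a value vector,
// so ne[0] comes from v instead of q
int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);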

Definition of variable D:

llama.cpp/ggml.c

Line 15879 in 51e9d02

const int64_t D = neq0;
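For context, neq0, nek0 and nev0 denote the first dimensions of the q, k and v tensors; in ggml.c they are typically introduced via the GGML_TENSOR_LOCALS macro, roughly equivalent to the following sketch:

// rough equivalent of what GGML_TENSOR_LOCALS produces for q, k and v
const int64_t neq0 = q->ne[0];  // query head dimension
const int64_t nek0 = k->ne[0];  // key head dimension
const int64_t nev0 = v->ne[0];  // value head dimension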

Assertions expecting all three vectors to have the same length:

llama.cpp/ggml.c

Lines 15889 to 15891 in 51e9d02

GGML_ASSERT(neq0 == D);
GGML_ASSERT(nek0 == D);
GGML_ASSERT(nev0 == D);

Usage of D as a dimension of a value vector:

llama.cpp/ggml.c

Line 15958 in 51e9d02

memset(V16, 0, D*sizeof(ggml_fp16_t));

Usage of D as a dimension of a query vector:

llama.cpp/ggml.c

Lines 15985 to 15987 in 51e9d02

for (int64_t d = 0; d < D; ++d) {
Q16[d] = GGML_FP32_TO_FP16(pq[d]);
}

Suggested solution: create two variables, Dq (length of the query vector) and Dv (length of the value vector), and use Dq as the query/key vector length and Dv as the value vector length. I fixed ggml_compute_forward_flash_attn_ext_f16() this way and it produces correct results (confirmed by running DeepSeek-V2 with the -fa option).
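A rough sketch of this change, using only the fragments quoted above (everything else, including the surrounding loop structure, is omitted):

// two separate head dimensions instead of a single D
const int64_t Dq = neq0;   // query/key vector length
const int64_t Dv = nev0;   // value vector length

// the Q*K dot product requires matching query/key lengths,
// but the value length is allowed to differ
GGML_ASSERT(neq0 == Dq);
GGML_ASSERT(nek0 == Dq);
GGML_ASSERT(nev0 == Dv);

// the output accumulator has the length of a value vector
memset(V16, 0, Dv*sizeof(ggml_fp16_t));

// the query is converted with its own length
for (int64_t d = 0; d < Dq; ++d) {
    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
}

With this split, the first dimension of the result tensor also follows Dv, matching the allocation sketched earlier.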

I'm not 100% sure whether the CUDA and Metal implementations are also affected, but it's likely - I found the same variable D used in their code, along with comments like "K and V have same shape".

Thanks for reporting that - CUDA and Metal kernels should also be affected. We should fix that, but maybe after DS2 support is merged so we have something to test with.

Possibly relevant - #2445 (comment)