Flash attention implementations do not handle the case where value vectors have a different dimension from query vectors
fairydreaming opened this issue
For example, in ggml.c the implementations of the flash-attention ops declare a variable D and use it as both the dimension of the value vector and the dimension of the key/query vector. This fails for models where query and value vectors have different lengths (for example, DeepSeek-V2).
Below are selected fragments of the GGML_OP_FLASH_ATTN_EXT op implementation that illustrate the problem.
Creation of result tensor:
Lines 6792 to 6793 in 51e9d02
(Note that the query tensor dimensions are used everywhere, while in reality ne[0] should be equal to ne[0] of the value tensor, because the attention output is a linear combination of value vectors.)
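For illustration, the linked snippet looks roughly like this (a reconstruction from memory, not a verbatim copy):

```c
// result tensor is shaped from the *query* dimensions only;
// ne[0] should come from the value tensor (v->ne[0]) instead
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
```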
Definition of variable D:
Line 15879 in 51e9d02
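Roughly (neq0 is the local alias for q->ne[0] set up by the GGML_TENSOR_LOCALS macros):

```c
// a single head size D is derived from the query tensor alone
// and then reused for Q, K, and V
const int64_t D = neq0;
```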
Assertions all expecting the same length:
Lines 15889 to 15891 in 51e9d02
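These look approximately like this (sketch, not verbatim):

```c
GGML_ASSERT(ne0  == D); // output row size
GGML_ASSERT(nek0 == D); // key head size
GGML_ASSERT(nev0 == D); // value head size -- fires when Dv != Dq
```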
Usage of D as a dimension of a value vector:
Line 15958 in 51e9d02
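The V accumulation has roughly this shape (VKQ16, v_row, and vs are placeholder names here for the output accumulator, the current V row, and the softmax weight):

```c
// y += x*v over D elements; D is used as the length of a V row,
// so it should really be nev0 (i.e. Dv), not the query head size
ggml_vec_mad_f16(D, VKQ16, v_row, vs);
```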
Usage of D as a dimension of a query vector:
Lines 15985 to 15987 in 51e9d02
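Here the same D counts elements of a Q row (sketch; Q16 and pq are placeholder names for the converted query buffer and the source query row):

```c
// D is used as the length of a Q row, i.e. it plays the role of Dq
for (int64_t d = 0; d < D; ++d) {
    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
}
```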
Suggested solution: introduce two variables, Dq (length of the query/key vectors) and Dv (length of the value vectors), and use each one wherever the corresponding vector length is needed. I fixed ggml_compute_forward_flash_attn_ext_f16() this way and it produces correct results (confirmed by running DeepSeek-V2 with the -fa option).
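A sketch of the idea (not the exact patch):

```c
const int64_t Dq = neq0; // query/key vector length
const int64_t Dv = nev0; // value vector length

GGML_ASSERT(ne0  == Dv); // output rows are value-sized
GGML_ASSERT(nek0 == Dq); // keys must match queries
// ... then use Dq in the Q*K dot products and Dv in the V
// accumulation and in the size of the output row
```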
I'm not 100% sure whether the CUDA and Metal implementations are also affected, but it's likely: I found the same variable D used in that code, along with comments like "K and V have same shape".
Thanks for reporting that - the CUDA and Metal kernels should also be affected. We should fix that, but maybe after DS2 support is merged so we have something to test with.
Possibly relevant - #2445 (comment)