Question about FlashAttention and KV-cache
KnowingNothing opened this issue
Hi, I noticed that you use a KV-cache together with FlashAttention in CausalSelfAttention. As far as I understand, FlashAttention already implements causal self-attention inside its kernels, which means that for Q [batch, head, seq_len, d_k] x K [batch, head, seq_len, d_k], only the lower-triangular part of the result matrix is computed. But in lit_llama/model.py, CausalSelfAttention uses a KV-cache, so during generation only Q [batch, head, 1, d_k] x K [batch, head, seq_len, d_k] is passed to FlashAttention. If the causal mask is applied to that 1 x seq_len score matrix, the single query would be treated as if it were at position 0 and all later keys would be masked out, so I think FlashAttention may compute only the first element of the result row. I just want to confirm whether this is correct. Please correct me if I am wrong. Could anybody kindly explain the reasoning? Thanks.
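
To illustrate what I mean, here is a minimal sketch using PyTorch's `torch.nn.functional.scaled_dot_product_attention` as a stand-in for the FlashAttention kernel (the shapes and the `is_causal` flag here are my own illustration, not the exact call in lit_llama/model.py). PyTorch aligns its built-in causal mask to the top-left corner, so a length-1 query with `is_causal=True` attends only to the first cached key:

```python
# Sketch of the concern: is_causal=True with a length-1 query vs. the
# full-sequence result. Not lit-llama's actual code, just an illustration.
import torch
import torch.nn.functional as F

B, H, T, d_k = 1, 1, 4, 8  # batch, heads, cached sequence length, head dim
torch.manual_seed(0)

k = torch.randn(B, H, T, d_k)
v = torch.randn(B, H, T, d_k)

# Full-sequence causal attention: the last query row attends to all T keys.
q_full = torch.randn(B, H, T, d_k)
out_full = F.scaled_dot_product_attention(q_full, k, v, is_causal=True)

# KV-cache style call: only the newest query (length 1) is passed in.
q_new = q_full[:, :, -1:, :]
out_causal = F.scaled_dot_product_attention(q_new, k, v, is_causal=True)
out_no_mask = F.scaled_dot_product_attention(q_new, k, v, is_causal=False)

# Without the causal flag, the single query attends to the whole cache and
# matches the last row of the full-sequence result.
print(torch.allclose(out_no_mask[:, :, 0], out_full[:, :, -1], atol=1e-6))

# With is_causal=True, the 1 x T score row is masked as if the query were at
# position 0, so only the first key/value contributes.
print(torch.allclose(out_causal[:, :, 0], v[:, :, 0], atol=1e-6))
```

If my understanding is right, the masked single-query call gives a different result from the unmasked one, which is why I am asking how the KV-cache path avoids this.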