Lightning-AI / lit-llama

Implementation of the LLaMA language model based on nanoGPT. Supports flash attention, Int8 and GPTQ 4-bit quantization, LoRA and LLaMA-Adapter fine-tuning, and pre-training. Apache 2.0-licensed.

Question about FlashAttention and KV-cache

KnowingNothing opened this issue · comments

commented

Hi, I notice that you use the KV-cache together with FlashAttention in CausalSelfAttention. As far as I understand, FlashAttention already implements causal self-attention inside its kernels, meaning that for Q [batch, head, seq_len, d_k] x K [batch, head, seq_len, d_k], only the lower-triangular half of the score matrix is computed. But in lit_llama/model.py, CausalSelfAttention uses the KV-cache, so only Q [batch, head, 1, d_k] x K [batch, head, seq_len, d_k] is passed to FlashAttention. I think this may cause FlashAttention to mask out everything except the first key position, i.e. compute only the first element of that row of the result matrix. I just want to confirm whether this is correct. Please correct me if I am wrong. Could anybody kindly explain the reason? Thanks.
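
To make the concern concrete, here is a minimal sketch (not lit-llama code) assuming PyTorch's `F.scaled_dot_product_attention` as the FlashAttention entry point: with `is_causal=True` and a single-token query, the implicit causal mask is a top-left-aligned lower-triangular (1 x S) matrix, so the new token would only see key position 0, unlike an explicit mask that exposes every cached position.

```python
import torch
import torch.nn.functional as F

B, H, S, D = 1, 2, 5, 8          # batch, heads, cached sequence length, head dim
q = torch.randn(B, H, 1, D)      # single new query token (KV-cache decode step)
k = torch.randn(B, H, S, D)      # full cached keys
v = torch.randn(B, H, S, D)      # full cached values

# With is_causal=True the implicit mask is a lower-triangular (1 x S) matrix
# aligned top-left, so the lone query row only "sees" key position 0.
causal_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# An explicit mask that lets the new token attend to every cached position
# gives the result one actually wants at decode time.
full_mask = torch.ones(1, S, dtype=torch.bool)
masked_out = F.scaled_dot_product_attention(q, k, v, attn_mask=full_mask)

print(torch.allclose(causal_out, masked_out))  # False: the two results differ
```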