pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

Help explain "Actually better for Inductor to codegen attention here"

huntzhan opened this issue

Can you help explain this comment? What is the best setup for torch.compile?

gpt-fast/generate.py

Lines 64 to 67 in db7b273

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
    next_token, next_prob = decode_one_token(
        model, cur_token, input_pos, **sampling_kwargs
    )

F.scaled_dot_product_attention automatically decides which backend to dispatch to. For example, it can dispatch to the FlashAttention2 kernel. On platforms where FlashAttention2 is not supported, it can instead dispatch to the "math" implementation, which is attention implemented using more primitive PyTorch operators.

In the case of decoding, however, the FlashAttention algorithm is not beneficial. In fact, it's actively detrimental. So in this case, it's better to dispatch to more primitive operators, where torch.compile can codegen the kernels from scratch.
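To make "attention implemented using more primitive operators" concrete for the decoding case, here is a minimal pure-Python sketch of attention for a single query token against a cache of keys and values. It is only an illustration of the computation (scores, softmax, weighted sum); it is not PyTorch's "math" backend and not the kernel Inductor generates. The function name decode_attention and the toy inputs are made up for this example.

```python
import math


def decode_attention(q, keys, values):
    """Attention for ONE query vector against all cached positions.

    q:      list of floats, length d (the single decode-step query)
    keys:   list of rows, one per cached position, each of length d
    values: list of rows, one per cached position
    """
    d = len(q)
    # 1) Scaled dot-product scores: one score per cached position.
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    # 2) Numerically stable softmax over the scores.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # 3) Weighted sum of the value rows.
    out = [
        sum(w * v[j] for w, v in zip(weights, values))
        for j in range(len(values[0]))
    ]
    return out, weights


# Toy inputs (hypothetical): two cached positions with identical keys,
# so both positions receive the same attention weight.
q = [1.0, 0.0]
keys = [[1.0, 0.0], [1.0, 0.0]]
values = [[2.0, 0.0], [4.0, 2.0]]
out, weights = decode_attention(q, keys, values)
```

Note that with a single query the intermediate result is one row of scores whose length equals the number of cached positions, rather than a full square attention matrix.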

Thanks a lot!

Hi @Chillee - Why would you say that FlashAttention is actively detrimental for decoding?

During the prefill stage, it avoids materializing the attention matrix of shape [2, num_head, bs, input_seq_len, input_seq_len] in global memory, which saves time by avoiding global memory IO.

During the decoding phase, it avoids materializing an attention matrix of shape ~[2, num_head, bs, input_seq_len] in global memory for each decoded token. The savings from avoiding global memory IO would scale to ~[2, num_head, bs, input_seq_len, output_seq_len] over the entire output length.

These are huge matrices that take up a lot of space. FlashAttention lets the model use less GPU memory and also gives a speedup by avoiding global memory IO (even without reducing FLOPs). The prefill stage is usually FLOP-bound, and the reduction in memory IO still gives a good speedup. The decoding phase is usually memory-bandwidth-bound, so shouldn't the reduction in memory IO be useful there too?

The big issue is the work partitioning structure. FlashAttention parallelizes across heads, BS, and output_seq_len (i.e. seq_query). In this case, BS and output_seq_len are both 1, so the only parallelism is across heads. An A100 GPU has 108 SMs, so unless the model has at least that many heads, the kernel can't keep the entire GPU busy.
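The partitioning argument can be put into rough numbers. The sketch below counts how many independent work items there are if one thread block is launched per (head, batch element, tile of query rows). That formula is a simplification of the scheme described above, and the tile size block_m=128 and the head count of 32 are assumptions chosen for illustration, not values from gpt-fast or the FlashAttention source.

```python
import math


def flash_attention_blocks(num_heads, batch_size, seq_q, block_m=128):
    """Approximate number of independent thread blocks.

    Assumes one block per (head, batch element, tile of query rows).
    block_m is a hypothetical tile size along the query dimension.
    """
    return num_heads * batch_size * math.ceil(seq_q / block_m)


NUM_SMS_A100 = 108  # SM count quoted in the thread
NUM_HEADS = 32      # assumed head count, for illustration only

# Decoding: batch size 1, one query token per step.
decode_blocks = flash_attention_blocks(NUM_HEADS, batch_size=1, seq_q=1)
# Prefill: batch size 1, a 2048-token prompt (assumed length).
prefill_blocks = flash_attention_blocks(NUM_HEADS, batch_size=1, seq_q=2048)

# Fraction of SMs that can receive a block at all during decoding.
decode_sm_fraction = min(1.0, decode_blocks / NUM_SMS_A100)
```

Under these assumptions decoding yields 32 blocks for 108 SMs, so fewer than a third of the SMs have any work, while prefill yields 512 blocks and can fill the GPU.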

The saving for avoiding doing IO to global memory would scale to ~[2, num_head, bs, input_seq_len, output_seq_len] over the entire output length.

output_seq_len is only 1 in this case. And for the low-latency setting, bs is also 1. So your intermediate matrix has size [2, num_head, 1, input_seq_len]. That's not nothing, but it's not a large enough advantage to outweigh the work-partitioning issue above. I would expect FlashDecoding to perform better than Inductor's generated kernel.
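For a sense of scale, the shapes quoted in this exchange can be turned into element counts. The sketch below takes the leading factor of 2 from the shapes as written in the thread, and assumes 32 heads, a 2048-token context, and 2 bytes per element (fp16/bf16); all three are illustrative assumptions, and the helper name is made up.

```python
def attention_intermediate_elems(num_head, bs, seq_q, seq_kv, leading=2):
    """Element count of an intermediate of shape
    [leading, num_head, bs, seq_q, seq_kv], as quoted in the thread."""
    return leading * num_head * bs * seq_q * seq_kv


NUM_HEAD = 32        # assumed
INPUT_SEQ_LEN = 2048  # assumed
BYTES_PER_ELEM = 2    # assumed fp16/bf16 storage

# Prefill: every prompt token attends to every prompt token.
prefill_elems = attention_intermediate_elems(NUM_HEAD, 1, INPUT_SEQ_LEN, INPUT_SEQ_LEN)
# Decode: a single new token attends to the cached context.
decode_elems = attention_intermediate_elems(NUM_HEAD, 1, 1, INPUT_SEQ_LEN)

prefill_mib = prefill_elems * BYTES_PER_ELEM / 2**20
decode_kib = decode_elems * BYTES_PER_ELEM / 2**10
```

With these numbers the prefill intermediate is 512 MiB while the per-step decode intermediate is 256 KiB, a factor of input_seq_len smaller, which is why avoiding its materialization matters much less during decoding.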

Hi @Chillee, it looks like torch.compile cannot handle the torch.backends.cuda.sdp_kernel context manager. The function decode_n_tokens, in which the torch.backends.cuda.sdp_kernel context manager is used, is not compiled. Does that mean the aforementioned behavior is not applied?

  File "/home/wden/.local/lib/python3.8/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor _GeneratorContextManager call_function <function sdp_kernel at 0x7fd9c5584a60>

from user code:
   File "/data/xxx", line 451, in model_decode_one_token
    with torch.backends.cuda.sdp_kernel(

By the way, I replaced torch.backends.cuda.sdp_kernel with the following statements:

            torch.backends.cuda.enable_flash_sdp(False)
            torch.backends.cuda.enable_mem_efficient_sdp(False)
            torch.backends.cuda.enable_math_sdp(True)

Inspired by your project, I've successfully applied the optimization strategy to baichuan 13b:
https://github.com/armed-gpt/gpt-blazing
Could I submit a PR to add a reference case to the README?

The function decode_n_tokens, in which the torch.backends.cuda.sdp_kernel context manager is used, is not compiled. Does that mean the aforementioned behavior is not applied?

No, decode_n_tokens calls decode_one_token, which does have the decorator. And the annotation is still respected there.

I see...
So because decode_one_token is called from within decode_n_tokens rather than directly, and compilation is deferred until that call, the annotation is still respected.

Inspired by your project, I've successfully applied the optimization strategy to baichuan 13b: https://github.com/armed-gpt/gpt-blazing Could I submit a PR to add a reference case to the README?

Hi @Chillee, I've submitted PR #48. Could you take a look and comment on it? Thanks.