pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.


torch.compile() with flash decoding ops

rayleizhu opened this issue

I'm trying to replace F.scaled_dot_product_attention with a flash decoding kernel for faster inference.

However, while the flash decoding function works well in eager mode, I cannot make it work with torch.compile(). It seems that torch.compile() does not support such third-party operators. How can I overcome this problem?

My code looks like this:

...
from xformers import _C_flashattention as flash_attn_cuda
...

# y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) ->
out, *_ = flash_attn_cuda.fwd(q, k, v, ...)

And the error message with the --compile option is:

...
  File "/home/coder/miniconda3/envs/gpt-fast/lib/python3.8/site-packages/torch/_dynamo/variables/base.py", line 340, in call_method
    raise unimplemented(f"call_method {self} {name} {args} {kwargs}")
  File "/home/coder/miniconda3/envs/gpt-fast/lib/python3.8/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(fwd) __call__ [TensorVariable(), TensorVariable(), TensorVariable(), ConstantVariable(NoneType), ConstantVariable(float), ConstantVariable(float), ConstantVariable(bool), ConstantVariable(int), ConstantVariable(int), ConstantVariable(bool), ConstantVariable(NoneType)] {}

from user code:
   File "benchmark.py", line 64, in decode_one_token
    logits = model(x, input_pos) # [B, 1, vocab_size]
  File "/home/coder/miniconda3/envs/gpt-fast/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/coder/projects/gpt-fast/model.py", line 388, in forward
    x = layer(x, input_pos, freqs_cis, mask)
  File "/home/coder/miniconda3/envs/gpt-fast/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/coder/projects/gpt-fast/model.py", line 407, in forward
    h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  File "/home/coder/miniconda3/envs/gpt-fast/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/coder/projects/gpt-fast/model.py", line 477, in forward
    y, lse = flash_attn_forward(q, k, v)
  File "/home/coder/projects/gpt-fast/model.py", line 34, in flash_attn_forward
    out, *_ = flash_attn_cuda.fwd(

BTW, I noticed that you mentioned in the blog that

And even cooler, these kernels are actually faster than the built in alternatives (CuBLAS and FlashAttention2)!

This is unsurprising, as FlashAttention 1 and 2 are designed for training-phase speed (with large batch sizes). Flash decoding should be a stronger baseline.
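For context, the core idea behind flash decoding, as described by its authors, is to split the keys/values along the sequence dimension, attend to each chunk independently (in parallel on the GPU), and then merge the partial results using each chunk's log-sum-exp. This is also why flash-attention-style forward functions return an `lse` next to the output, as in `y, lse = flash_attn_forward(q, k, v)` in the traceback above. The sketch below shows only that merge math in plain PyTorch on CPU; `split_kv_attention` and the chunk size are illustrative names/values, and this is not the flash decoding kernel itself.

```python
import math

import torch
import torch.nn.functional as F


def split_kv_attention(q, k, v, chunk_size):
    """Attention over k/v computed chunk by chunk, then merged via log-sum-exp.

    Illustrative sketch only. Shapes: q [..., Lq, D], k [..., Lk, D], v [..., Lk, Dv].
    Returns the attention output [..., Lq, Dv] and the global log-sum-exp [..., Lq].
    """
    scale = 1.0 / math.sqrt(q.shape[-1])  # same default scale as SDPA
    partial_outs, partial_lses = [], []
    for start in range(0, k.shape[-2], chunk_size):
        k_c = k[..., start:start + chunk_size, :]
        v_c = v[..., start:start + chunk_size, :]
        scores = (q @ k_c.transpose(-2, -1)) * scale               # [..., Lq, chunk]
        partial_lses.append(torch.logsumexp(scores, dim=-1))       # [..., Lq]
        partial_outs.append(torch.softmax(scores, dim=-1) @ v_c)   # [..., Lq, Dv]

    lses = torch.stack(partial_lses)           # [n_chunks, ..., Lq]
    outs = torch.stack(partial_outs)           # [n_chunks, ..., Lq, Dv]
    global_lse = torch.logsumexp(lses, dim=0)  # [..., Lq]
    # Each chunk's share of the total softmax mass rescales its local result.
    weights = torch.exp(lses - global_lse)
    return (weights.unsqueeze(-1) * outs).sum(dim=0), global_lse


torch.manual_seed(0)
q = torch.randn(1, 8, 1, 64)     # one decode step: [B, H, 1, D]
k = torch.randn(1, 8, 1000, 64)  # long KV cache
v = torch.randn(1, 8, 1000, 64)
out, lse = split_kv_attention(q, k, v, chunk_size=256)
ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out, ref, atol=1e-5))
```

The merge is exact up to floating-point error, so the chunks can be processed in any order or in parallel; that parallelism over the KV length is what helps when a decode step has only one query token.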

@rayleizhu To make torch.compile work with third-party ops, you need to register them. I'll put up an example of how to do this later.

Regarding the point that flash attention 1 & 2 are designed for training-phase speed and that Flash decoding should be a stronger baseline: I certainly agree :), and I would expect FlashDecoding to be faster than the ops generated by torch.compile. But FlashDecoding is not integrated into PyTorch yet.

mark! :)

I found the examples here.

However, I have another question: is the registration required by torch.cuda.CUDAGraph() or torch._dynamo? Do I still need this registration if I want to define the graph manually with torch.cuda.CUDAGraph() instead of capturing it with Dynamo?

I've tried the torch.library approach, and ran into some problems which I've outlined here: pytorch/pytorch#120441

Your repro works for me with pytorch-nightly. With TORCH_COMPILE_DEBUG enabled, I get this graph:

def forward(self, arg0_1: "bf16[1, 2, 2, 4]", arg1_1: "bf16[1, 5, 2, 4]", arg2_1: "bf16[1, 5, 2, 4]", arg3_1: "bf16[1, 1, 2, 4]", arg4_1: "bf16[1, 1, 2, 4]", arg5_1: "i32[1]"):
    # File: /home/yifu/pytorch/torch/_dynamo/external_utils.py:25 in inner, code: return fn(*args, **kwargs)
    auto_functionalized = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.mylib.custom_func.default, q = arg0_1, k_cache = arg1_1, v_cache = arg2_1, k = arg3_1, v = arg4_1, cache_seqlens = arg5_1);  arg0_1 = arg3_1 = arg4_1 = arg5_1 = None
    getitem: "bf16[1, 2, 2, 4]" = auto_functionalized[0]
    getitem_1: "bf16[1, 5, 2, 4]" = auto_functionalized[1]
    getitem_2: "bf16[1, 5, 2, 4]" = auto_functionalized[2];  auto_functionalized = None
    copy_: "bf16[1, 5, 2, 4]" = torch.ops.aten.copy_.default(arg1_1, getitem_1);  arg1_1 = getitem_1 = None
    copy__1: "bf16[1, 5, 2, 4]" = torch.ops.aten.copy_.default(arg2_1, getitem_2);  arg2_1 = getitem_2 = None
    return (getitem,)

Maybe give the latest nightly a shot?

To answer the earlier question of whether the registration is still needed when defining the graph manually with torch.cuda.CUDAGraph() instead of capturing it with Dynamo: no. If you have difficulties using Dynamo, consider capturing a CUDA graph manually with torch.cuda.CUDAGraph. See the blog post here. However, you need to make sure that the graph is static (be careful with if/for statements that could take a different path or produce different shapes from one call to the next).
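A minimal sketch of that manual route, following the warm-up / capture / replay pattern from the PyTorch CUDA graphs documentation. `make_graphed_step` and `step_fn` are hypothetical names. The sketch assumes `step_fn` (which could call the unregistered `flash_attn_cuda.fwd` directly) only launches capture-safe GPU work: no CPU synchronization such as `.item()`, and no control flow or shapes that depend on tensor values. New data must be copied into the same static input tensors before each replay.

```python
import torch


def make_graphed_step(step_fn, *static_inputs, warmup_iters=3):
    """Capture step_fn(*static_inputs) into a CUDA graph and return a replay function.

    static_inputs must be CUDA tensors that stay alive as long as the graph does;
    callers feed new data by copying into them in place.
    """
    # Warm up on a side stream before capture (lazy initialization, workspace allocation, etc.).
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(warmup_iters):
            step_fn(*static_inputs)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = step_fn(*static_inputs)

    def replay(*new_inputs):
        # Inputs omitted from new_inputs keep their current contents (zip stops early).
        for dst, src in zip(static_inputs, new_inputs):
            dst.copy_(src)  # refill the captured input buffers
        graph.replay()      # re-launch every captured kernel
        return static_out   # overwritten on each replay; clone it to keep a result

    return replay


if torch.cuda.is_available():
    x = torch.randn(8, 16, device="cuda")
    w = torch.randn(16, 16, device="cuda")
    step = make_graphed_step(lambda a, b: torch.relu(a @ b), x, w)
    new_x = torch.randn(8, 16, device="cuda")
    out = step(new_x)  # only x is refreshed; w keeps its captured contents
    torch.cuda.synchronize()
```

Because the graph replays fixed kernel launches on fixed memory addresses, anything that would change between calls (a different sequence length, a Python branch on a tensor value) has to be handled outside the captured region, which is the "static graph" caveat above.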