OpenNLPLab / lightning-attention

Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models


assert d in supports_dim and e in supports_dim ?

XintianHan opened this issue · comments

Thank you for the nice implementation! It seems that dim=192 is not in supports_dim. Why is that the case? Could you add support for dim=192?

I tried this script

import torch
from lightning_attn.ops import lightning_attn_func
from lightning_attn.utils import _build_slope_tensor

dtype = torch.bfloat16
device = torch.device("cuda")
b, h, n, d, e = 2, 12, 2048, 192, 192

q = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
k = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
v = torch.randn((b, h, n, e), dtype=dtype, device=device).requires_grad_()
s = _build_slope_tensor(h).to(q.device).to(torch.float32)

o = lightning_attn_func(q, k, v, s)

print(o.shape)

and got this error

    o = lightning_attn_func(q, k, v, s)
  File "/opt/tiger/mariana/lightning-attention-main/lightning_attn/ops/lightning_attn_interface.py", line 10, in lightning_attn_func
    assert d in supports_dim and e in supports_dim
AssertionError

Hello, I have just updated the code, and it now supports head_dim=192. Can you try the example again?

Hi. Thanks for the quick reply. I think I still have a problem with dimensions that are not a power of 2.

Here is what I ran

import torch
from lightning_attn.ops import lightning_attn_func

dtype = torch.bfloat16
device = torch.device("cuda")
b, h, n, d, e = 1, 16, 2, 192, 96

q = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
k = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
v = torch.randn((b, h, n, e), dtype=dtype, device=device).requires_grad_()
s = torch.randn(h, 1, 1).to(q)

o = torch.sum(lightning_attn_func(q, k, v, s))
o.backward()

Then the error happened during the backward pass:
loc("/opt/tiger/mariana/lightning-attention-main/lightning_attn/ops/triton/lightning_attn2.py":123:64): error: Number of elements must be power-of-two, but %49 = "tt.make_range"() <{end = 96 : i32, start = 0 : i32}> : () -> tensor<96xi32> doesn't follow the rule (96) elements

Any thoughts here? Thank you so much!

Working on this.

This problem has been temporarily solved. The current solution is to use F.pad. I will provide a more efficient solution in the future.
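
For anyone hitting the same error in the meantime, here is a minimal sketch of that padding workaround, assuming the same lightning_attn_func interface as in the examples above (the next_pow2 helper and the padded sizes are illustrative, not part of the library):

import torch
import torch.nn.functional as F
from lightning_attn.ops import lightning_attn_func
from lightning_attn.utils import _build_slope_tensor

def next_pow2(x: int) -> int:
    # Smallest power of two >= x (Triton's tl.arange needs power-of-two sizes).
    return 1 << (x - 1).bit_length()

dtype = torch.bfloat16
device = torch.device("cuda")
b, h, n, d, e = 1, 16, 2, 192, 96

q = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
k = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
v = torch.randn((b, h, n, e), dtype=dtype, device=device).requires_grad_()
s = _build_slope_tensor(h).to(q.device).to(torch.float32)

# Zero-pad the head dimensions up to the next power of two (192 -> 256, 96 -> 128).
# Padding q and k with zeros leaves q @ k^T unchanged, and the extra zero
# columns of v only produce zero output columns, which are sliced off below.
d_pad, e_pad = next_pow2(d), next_pow2(e)
q_p = F.pad(q, (0, d_pad - d))
k_p = F.pad(k, (0, d_pad - d))
v_p = F.pad(v, (0, e_pad - e))

o = lightning_attn_func(q_p, k_p, v_p, s)[..., :e]

torch.sum(o).backward()
print(o.shape, q.grad.shape)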