OpenNLPLab / lightning-attention

Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models


Does using lightning-attention need retraining?

ranpin opened this issue · comments

Hello, I have replaced the standard self-attention computation in my own model with lightning attention, without any other changes, but I found that the model's inference and test results are poor.

Therefore, I would like to ask: does simply replacing the standard self-attention computation with lightning attention have any effect on the model's accuracy? Do I need to retrain my model? Thank you very much!

I also found another issue (https://github.com/OpenNLPLab/lightning-attention/issues/10#issuecomment-1986779377) where you said lightning attention has no parameters, so perhaps it should not affect model accuracy, just like flash attention, and no retraining would be needed?

Here is the code I originally used to compute the attention forward pass:

def forward(self, x: torch.Tensor) -> torch.Tensor:
    B, H, W, _ = x.shape

    # qkv with shape (3, B, nHead, H * W, C)
    qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
    # q, k, v with shape (B * nHead, H * W, C)
    q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

    attn = (q * self.scale) @ k.transpose(-2, -1)

    if self.use_rel_pos:
        attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

    attn = attn.softmax(dim=-1)
    x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)

    x = self.proj(x)

    return x

Here is my modified code that computes the attention forward pass using Lightning Attention:

def forward(self, x: torch.Tensor) -> torch.Tensor:
    B, H, W, _ = x.shape

    # qkv with shape (3, B, num_heads, H * W, head_dim) after the permute
    qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)  # each with shape (B, num_heads, H * W, head_dim)

    # build the per-head slope (decay) tensor for Lightning Attention
    # (assumes: from lightning_attn.ops import lightning_attn_func
    #           from lightning_attn.utils import _build_slope_tensor)
    slope_tensor = _build_slope_tensor(self.num_heads).to(x.device).to(torch.float32)

    # compute attention with Lightning Attention; output shape (B, num_heads, H * W, head_dim)
    attn = lightning_attn_func(q, k, v, slope_tensor)

    # move the head dim next to the channel dim before merging, then restore the spatial layout
    attn = attn.transpose(1, 2).reshape(B, H, W, -1)

    # final projection
    x = self.proj(attn)

    return x

Hello, Lightning Attention is an acceleration of Linear Attention and is not suitable for models that use Softmax Attention (the form that FlashAttention accelerates). A direct replacement will cause a significant drop in performance; at the very least, some fine-tuning steps are necessary. You can refer to this paper for more information.
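
Here is a minimal PyTorch sketch (illustrative only, not the repository's kernel; shapes are arbitrary) of why the direct swap changes the model's outputs: softmax attention and linear attention are different functions of (q, k, v), so weights trained for one do not transfer for free.

import torch

torch.manual_seed(0)
b, h, n, d = 1, 2, 16, 8          # batch, heads, sequence length, head dim
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# standard softmax attention (this is what FlashAttention accelerates exactly)
softmax_out = torch.softmax(q @ k.transpose(-2, -1) / d**0.5, dim=-1) @ v

# a plain (non-causal, un-normalized) linear-attention form: (q k^T) v is
# re-associated as q (k^T v), which is what makes linear attention fast
linear_out = q @ (k.transpose(-2, -1) @ v)

# the two outputs differ substantially, so a model trained with softmax
# attention will not behave the same after the replacement
print((softmax_out - linear_out).abs().max())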

Thanks, so if I understand correctly: if my model originally used the standard self-attention mechanism (which includes a softmax) and I directly replace it with Lightning Attention, performance will be significantly degraded and the model needs to be fine-tuned or retrained, right? Whereas if the model originally used Linear Attention, the impact would not be as significant?

Yes, you can understand it that way. In fact, when using Linear Attention, we directly employ Lightning Attention for acceleration during the training phase.
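
And a minimal sketch (non-causal toy case, not the actual kernel, which also handles the causal, decayed case block by block) of why that acceleration is lossless within Linear Attention: the left-product and right-product forms compute the same result, only at different cost, so Lightning Attention changes the speed, not the output.

import torch

torch.manual_seed(0)
b, h, n, d = 1, 2, 16, 8
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

left_product = (q @ k.transpose(-2, -1)) @ v    # materializes the n x n attention matrix
right_product = q @ (k.transpose(-2, -1) @ v)   # kernel-trick order, linear in n

# identical up to floating-point error
print(torch.allclose(left_product, right_product, atol=1e-4))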

Fine, I got it. Thank you very much!