Samsung / ONE

On-device Neural Engine

[one-optimize] Optimize part of the transformer's attention-head

BalyshevArtem opened this issue · comments

What

Let's introduce two new optimization passes to simplify and accelerate part of the transformer's attention head.
Originally it has the following pattern that we can optimize:

Screenshot from 2024-04-24 17-53-02

1. First, we can fuse the

StridedSlice --- Concatenation
StridedSlice --- Neg /

pattern into a single Mul operation whose constant consists of 1s, with -1 where the Neg operation was.

As a result we will have:
Screenshot from 2024-04-24 17-57-50
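To make the equivalence concrete, here is a minimal numpy sketch (the shapes and the split point 40 are illustrative, and it assumes the concatenation keeps the slices in their original order; the reordered case is discussed later in this thread):

import numpy as np

x = np.random.randn(1, 80)                  # output of the preceding FullyConnected

# Original pattern: slice, negate the second half, concatenate
a = x[:, :40]                               # StridedSlice A (begin:0, end:40)
b = x[:, 40:]                               # StridedSlice B (begin:40, end:80)
pattern_out = np.concatenate([a, -b], axis=-1)

# Fused pattern: a single Mul with a constant of 1s and -1s
c = np.concatenate([np.ones(40), -np.ones(40)])
fused_out = x * c

assert np.allclose(pattern_out, fused_out)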

2. Then we can twice fuse a Mul with a FullyConnected node and get:

Screenshot from 2024-04-24 17-59-33
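A sketch of this Mul-into-FullyConnected fusion, assuming the FC computes x @ W.T + bias with W of shape [out, in] (as in TFLite/circle); the shapes are illustrative:

import numpy as np

x = np.random.randn(1, 240)
W = np.random.randn(80, 240)                # FC weights, shape [out, in]
bias = np.random.randn(80)
c = np.concatenate([np.ones(40), -np.ones(40)])   # constant from step 1

# Original: FullyConnected followed by Mul with a constant
original = (x @ W.T + bias) * c

# Fused: scale each output row of W (and the bias) by the constant
fused = x @ (W * c[:, None]).T + bias * c

assert np.allclose(original, fused)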

3. And finally, fusing the horizontal FC layers, we will get a single FC node:

Screenshot from 2024-04-24 18-01-03
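A sketch of the horizontal fusion, under the assumption that the two FullyConnected outputs are combined by an Add (so the weights and biases can simply be summed); shapes are again illustrative:

import numpy as np

x = np.random.randn(1, 240)
W1, W2 = np.random.randn(80, 240), np.random.randn(80, 240)
b1, b2 = np.random.randn(80), np.random.randn(80)

# Original: two FCs on the same input, outputs added
original = (x @ W1.T + b1) + (x @ W2.T + b2)

# Fused: a single FC with summed weights and biases
fused = x @ (W1 + W2).T + (b1 + b2)

assert np.allclose(original, fused)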

Why

To speed up and simplify attention-based models.

How

  • Introduce a pass to fuse the StridedSlice/Neg/Concatenation pattern into a Mul.
  • Introduce a pass to fuse a Mul with a FullyConnected node.

@BalyshevArtem, this is awesome!
I've resized the images a little bit smaller for better readability :)

Would you let me know which model you used?

In the model I used, only one FullyConnected layer was created in the corresponding part, so it seems that the structure varies slightly depending on the model.

Would you let me know which model you used?

I used a model generated in one of the internal repos: a modified Llama2 (split head).
It is the decoder part.

@BalyshevArtem Thanks for a good idea :) As @periannath mentioned, the original pattern seems to have duplicate FCs, i.e., the two FCs are in fact the same. So the baseline would be the pattern with a single FC layer.

For the second fusion, the second MUL is for applying rotary embedding, which would be a user input (not constant) if the model supports dynamic behavior.

If a model only supports fixed positions (all input tokens' position is fixed, which means that the number of previously cached tokens is also fixed), this would be an effective optimization.

Introduce a pass to fuse the StridedSlice/Neg/Concatenation pattern into a Mul.

This fusion looks good to me. One minor concern is that it will reduce the operator count but create a new constant tensor. Care must be taken not to increase the model size too much.

We can fuse this pattern only if we can then fuse the Mul with const into the FullyConnected operation located above it. As I understand it, whether the rotary embedding is dynamic or static does not affect this fusion, since there will always be a FullyConnected layer in front of this pattern, right?

the original pattern seems to have duplicate FCs, i.e., the two FCs are in fact the same. So the baseline would be the pattern with a single FC layer.

I'm not sure I got it right :) In the example that I used:

Screenshot from 2024-04-24 17-53-02

these two FC operations have different constants.
Or do you mean some other pattern?

whether the rotary embedding is dynamic or static does not affect this fusion, since there will always be a FullyConnected layer in front of this pattern, right?

Yes :)

these two FC operations have different constants.

Ah, your model seems to be the one whose attention heads are split. I thought about the pattern without head split. Below is the original pattern of rotary embedding whose heads are not split.

[image: rotary embedding pattern without head split]

After the heads are split, it seems that a new FC is created as the FC is fused with the Mul (the left Mul in the above graph).

I think that kind of fusion should be applied carefully (or suppressed), because it will increase the model size quite a lot (model size is a performance bottleneck as of now). Thanks for finding this.

@BalyshevArtem Could you share any preliminary result after this optimization, e.g., impacts on cycles/traffic? If there is some sensitive information, please use our internal repo.

Sure, I will post the results in the internal repo :)

Below is the original pattern of rotary embedding whose heads are not split.

In this example, we can also apply some optimizations:

  1. Fuse the StridedSlice-Neg-Concatenation pattern into a Mul operation whose constant consists of 1s, with -1 where the Neg operation was.
  2. Fuse the two Muls with constant values.
  3. Fuse the pattern Add(Mul(Input, Const1), Mul(Input, Const2)) into Mul(Input, Const3), where Const3 = Const1 + Const2 (see the check after this list).
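The third fusion is just distributivity; a minimal numpy check (shapes illustrative):

import numpy as np

x = np.random.randn(1, 80)
c1, c2 = np.random.randn(80), np.random.randn(80)

# Add(Mul(Input, Const1), Mul(Input, Const2)) == Mul(Input, Const1 + Const2)
assert np.allclose(x * c1 + x * c2, x * (c1 + c2))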

It seems that the first fusion is invalid. Please check the begin/end of StridedSlice.

StridedSlice A(begin:0, end:40)  --- Concatenation (B+A)
StridedSlice B(begin:40, end:80) --- Neg /

The order of the two sliced tensors is changed, so it is impossible to convert the pattern to a simple Mul.
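To make this concrete: with the reordered concatenation the pattern is the rotate-half operation used by rotary embedding, which moves elements across positions, so no elementwise constant can reproduce it. A small numpy check:

import numpy as np

x = np.arange(1.0, 81.0).reshape(1, 80)     # any nonzero values work here
a, b = x[:, :40], x[:, 40:]

# Reordered pattern: Concatenation(Neg(B), A), i.e. rotate-half
pattern_out = np.concatenate([-b, a], axis=-1)

# For an elementwise Mul, pattern_out / x would have to be a fixed constant,
# but it differs from element to element
would_be_c = pattern_out / x
assert not np.allclose(would_be_c, would_be_c[0, 0])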

Yes, you're right, thank you! Indeed, the tensor is split in half and the halves are swapped.

Such a pattern can still be optimized, but it gets more complicated. Let's expand the pattern in question by adding the FullyConnected:

  ---- Weight_Const
 |
FullyConnected ----> StridedSlice A(begin:0, end:40)  --- Concatenation (B+A)
              \----> StridedSlice B(begin:40, end:80) --- Neg /

So the idea is to first split and rotate the weights in the same way as the StridedSlices -> Concatenation does. In the example from #12917 (comment), we need to change the weights of the FullyConnected (with shape 80 x 240): split them into two parts by rows, a 40 x 240 first_part and a 40 x 240 second_part, and reverse their order, so that second_part comes first and first_part comes second. After that, introduce a Mul with negative values (for the first part), and then fuse it into the FC and so on (as in #12917 (comment)). A sketch of this weight transformation follows below.
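A minimal numpy sketch of the idea, under the same assumptions as above (FC computes x @ W.T + bias with W of shape 80 x 240, split point 40; here the Neg is folded directly into the rotated rows instead of going through an intermediate Mul):

import numpy as np

x = np.random.randn(1, 240)
W = np.random.randn(80, 240)                # FullyConnected weights, 80 x 240
bias = np.random.randn(80)

# Original pattern: FC, slice, Neg on B, Concatenation (B+A)
y = x @ W.T + bias
a, b = y[:, :40], y[:, 40:]
pattern_out = np.concatenate([-b, a], axis=-1)

# Optimized: reorder the weight (and bias) rows so second_part comes first,
# negating the rows that feed the Neg branch; the result is a single FC
first_part, second_part = W[:40], W[40:]
W_rot = np.concatenate([-second_part, first_part], axis=0)
bias_rot = np.concatenate([-bias[40:], bias[:40]])
fused_out = x @ W_rot.T + bias_rot

assert np.allclose(pattern_out, fused_out)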

It turns out to be a highly specialized optimization pattern, but at the same time it allows us to greatly reduce unnecessary computation and even reduce the binary size, thanks to fusing constants and weights.
@jinevening,
The question is: does this pattern occur in our target models? If you find it helpful, I will implement such an optimization, but if you think this pattern is too rare to be useful to us, then it is better to postpone this task. What do you think? :)

@BalyshevArtem I've answered the question in the internal repo.