zhuzilin / ring-flash-attention

Ring attention implementation with flash attention

[Feature Request] Balancing computation with zigzag blocking

zhuzilin opened this issue · comments

Currently the implementation splits the input sequence into n blocks, e.g. with 4 GPUs the split is:

b0 | b1 | b2 | b3

However, this results in uneven computation: because of the causal attention mask, the GPU that holds b3 will do roughly 4 times more computation than the GPU that holds b0.

If we instead split the input sequence into 2n blocks, e.g. with 4 GPUs:

b0,b7 | b1,b6 | b2,b5 | b3,b4

then all GPUs will have the same amount of computation, and theoretically the latency should be roughly halved.
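A minimal sketch of the proposed chunk-to-GPU assignment (the helper name here is just for illustration, not part of this repo):

```python
# Cut the sequence into 2n chunks and give GPU i chunks i and 2n - 1 - i,
# so every GPU ends up with the same causal-attention workload.
def zigzag_assignment(num_gpus: int) -> list[tuple[int, int]]:
    n = num_gpus
    return [(i, 2 * n - 1 - i) for i in range(n)]

print(zigzag_assignment(4))  # [(0, 7), (1, 6), (2, 5), (3, 4)]
```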

Do you mean Striped Attention when you say zigzag blocking? Or is it something simpler that still gives you much better utilization (lower latency) of the GPUs?

@andreaskoepf Oh... I hadn't read the Striped Attention paper before... (from the name I thought it was some sparse-attention-mask variant of ring attention, like window attention, my bad...)

But from a quick look, it seems that striped attention does something like:

0,4,8,12 | 1,5,9,13 | 2,6,10,14 | 3,7,11,15

I was thinking about doing something like:

0,1,14,15 | 2,3,12,13 | 4,5,10,11 | 6,7,8,9

which may be able to fold the causal mask from:

x
xx
xxx
xxxx
xxxxx
xxxxxx
xxxxxxx
xxxxxxxx

into

x xxxxxxxx
xx xxxxxxx
xxx xxxxxx
xxxx xxxxx

I'm not sure which would give better performance...

@zhuzilin this is great work! Your zig zag pattern looks to be the optimal sharding for Ring Attention: it homogeneously spreads the computation across all ranks. Wanted to share my analysis in case it's helpful for others.

Setup: the queries, keys, and values (q, k, v) all have sequence length S in their unsharded form. Use ring attention across R ranks of GPUs. Ignore the batch, hidden, and head dimensions in the following for simplicity.

Ring attention divides the sequence length into R chunks of size S/R each (assume S is divisible by R). The three strategies for doing so are:

  • Naive: give sequential sequence positions to sequential ranks. This is basically row-major sharding. In einops notation, this would be like logically sharding the full queries via q_shards = einops.rearrange(q, "(r t) -> r t", r=R, t=S // R), which has shape (R, S/R), and giving rank r the (S/R,)-shaped q_shards[r] slice. Equivalently, q_shards = q.reshape(R, S // R). Similarly for keys and values.
  • Striped: more like column-major sharding. The logical sharding pattern is q_shards = einops.rearrange(q, "(t r) -> r t", r=R, t=S // R) (r and t are swapped in the first part of the pattern), giving q_shards[r] to rank r. Equivalently, q_shards = q.reshape(S // R, R).swapdims(0, 1).
  • Zig Zag: chunk the sequence into 2*R sequential pieces and give rank r chunk indices r and 2*R - r - 1. Continuing with einops notation, this is like chunking as q_shards = einops.rearrange(q, "(r t) -> r t", r=2 * R, t=S // (2 * R)) and giving rank r the tensor torch.cat((q_shards[r], q_shards[2 * R - r - 1])). We assume S is also divisible by 2 * R here. All three patterns are sketched in code just below.
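Here is an illustrative sketch of the three logical sharding patterns, applied to a 1D tensor of sequence positions (batch, head, and hidden dimensions omitted). The helper names are mine, not part of this repo:

```python
import torch


def naive_shard(q: torch.Tensor, R: int) -> torch.Tensor:
    # Row-major: rank r gets the r-th contiguous block of length S // R.
    return q.reshape(R, -1)


def striped_shard(q: torch.Tensor, R: int) -> torch.Tensor:
    # Column-major: rank r gets positions r, r + R, r + 2R, ...
    return q.reshape(-1, R).swapdims(0, 1)


def zigzag_shard(q: torch.Tensor, R: int) -> torch.Tensor:
    # Chunk into 2R pieces; rank r gets chunks r and 2R - r - 1 concatenated.
    chunks = q.reshape(2 * R, -1)
    return torch.stack(
        [torch.cat((chunks[r], chunks[2 * R - r - 1])) for r in range(R)]
    )


S, R = 16, 4
q = torch.arange(S)  # stand-in for the sequence dimension of q, k, or v
print(naive_shard(q, R))    # rows: [0..3], [4..7], [8..11], [12..15]
print(striped_shard(q, R))  # rows: [0, 4, 8, 12], ..., [3, 7, 11, 15]
print(zigzag_shard(q, R))   # rows: [0, 1, 14, 15], ..., [6, 7, 8, 9]
```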

Efficiency

The first iteration of ring attention is the same for all sharding strategies. Because the queries and keys come from the same rank in this step, the S/R query positions coincide with the S/R key positions, and so a little analysis shows that the queries attend to $\frac{S}{2R}\left ( \frac{ S }{ R }+1 \right )$ positions in total.
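Spelled out, this is just a triangular sum: since the query and key positions coincide, the j-th earliest query position can attend to exactly j of the rank's own key positions, so the total is

$$\sum_{j=1}^{S/R} j \;=\; \frac{1}{2}\,\frac{S}{R}\left(\frac{S}{R}+1\right) \;=\; \frac{S}{2R}\left(\frac{S}{R}+1\right).$$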

The strategies all differ in how long the subsequent steps take, in which the queries attend to the keys from different ranks. We use the maximum number of positions that any rank attends to on a given iteration as a rough proxy for how much time that iteration takes.

Naive

For naive ring attention, the computation is always bottlenecked by (at least) the last rank, r = R - 1, which owns all of the future-most query positions: {S - S / R, ..., S - 1}. Its queries will attend to every position of every other rank's keys: $\frac{ S ^{ 2 } }{ R ^{ 2 } }$ attention operations. In contrast, the r = 0 rank, which owns the oldest positions, {0, ..., S / R - 1}, cannot attend to any other rank's keys. All other ranks behave like one of the two cases above, depending on the iteration.

The iteration time is determined by the slowest rank, so the time is set by $\frac{ S ^{ 2 } }{ R ^{ 2 } }$, i.e. naive sharding is bottlenecked by rank R - 1.

Striped Attention

For striped ring attention, once again rank R - 1 always performs the most non-trivial attention operations and rank 0 the fewest, but the difference is not as drastic as in the naive case.

Rank r owns indices {r, r + R, r + 2R, ..., S - R + r}. A little analysis shows that rank R - 1 attends to $\frac{S}{2R}\left ( \frac{ S }{ R }+1 \right )$ positions on every iteration, the same as the initial iteration above. Rank 0, owning the oldest positions, only attends to $\frac{S}{2R}\left ( \frac{ S }{ R }-1 \right)$ positions on every iteration. The story for other ranks falls into one of these two buckets, depending on the iteration. This all follows from the interleaved indexing that striped attention provides.

The iteration time is now set by $\frac{S}{2R}\left ( \frac{ S }{ R }+1 \right )$, which is nearly a factor-of-two improvement over naive sharding, but the systematic mismatch in operations across ranks is still suboptimal. (In practice, it looks like you are finding a ~25% improvement, while the original paper finds ~50%.)
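For reference, the ratio of the striped bottleneck to the naive bottleneck is

$$\frac{\frac{S}{2R}\left(\frac{S}{R}+1\right)}{\frac{S^{2}}{R^{2}}} \;=\; \frac{1}{2}\left(1+\frac{R}{S}\right) \;\approx\; \frac{1}{2} \quad \text{for } S \gg R,$$

which is where the near factor-of-two comes from.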

Zig Zag Attention

For zig zag ring attention, every rank performs the same amount of attention operations on every iteration: $\frac{S ^{ 2 }}{2R ^{ 2 }}$. This is the optimal setup where no individual rank is a bottleneck.

There are two possible scenarios for every rank:

  1. All of the query's positions fall between those owned by the keys. Every query position can attend to half of the key positions, resulting in $\frac{ S ^{ 2 } }{ 2 R ^{ 2 } }$ operations.
  2. All of the key's positions fall between those owned by the queries. The future-most half of the query's positions can attend to every key position, again resulting in $\frac{ S ^{ 2 } }{ 2 R ^{ 2 } }$ operations.

The compute is thus homogeneous across ranks. This improves upon striped attention by reducing the maximum operations per iteration by a factor of approximately $1 - \frac{ R }{ S }$. (In practice, it looks like you're getting a ~5% improvement on 8x{A100,H100} nodes for 8k sequence lengths, which is much larger than this simple analysis would suggest.)
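A small counting sketch (not from this repo; all names are illustrative) that tallies how many causal (query, key) pairs the busiest rank processes on each ring iteration, assuming rank r receives rank (r - i) % R's keys on iteration i:

```python
from typing import List


def shard_positions(strategy: str, S: int, R: int) -> List[List[int]]:
    """Return the global sequence positions owned by each rank."""
    if strategy == "naive":
        # Row-major: rank r owns the r-th contiguous block of length S // R.
        return [list(range(r * (S // R), (r + 1) * (S // R))) for r in range(R)]
    if strategy == "striped":
        # Column-major: rank r owns positions r, r + R, r + 2R, ...
        return [list(range(r, S, R)) for r in range(R)]
    if strategy == "zigzag":
        # Chunk into 2R pieces; rank r owns chunks r and 2R - r - 1.
        c = S // (2 * R)
        return [
            list(range(r * c, (r + 1) * c))
            + list(range((2 * R - r - 1) * c, (2 * R - r) * c))
            for r in range(R)
        ]
    raise ValueError(f"unknown strategy: {strategy}")


def count_pairs(q_pos: List[int], k_pos: List[int]) -> int:
    """Number of (query, key) pairs allowed by the causal mask (key <= query)."""
    return sum(1 for q in q_pos for k in k_pos if k <= q)


if __name__ == "__main__":
    S, R = 16, 4
    for strategy in ("naive", "striped", "zigzag"):
        shards = shard_positions(strategy, S, R)
        # On iteration i, rank r's queries attend to rank (r - i) % R's keys.
        per_iter_max = [
            max(count_pairs(shards[r], shards[(r - i) % R]) for r in range(R))
            for i in range(R)
        ]
        print(f"{strategy:>8}: max causal pairs per iteration = {per_iter_max}")
```

For S=16 and R=4 this prints [10, 16, 16, 16] for naive, [10, 10, 10, 10] for striped, and [10, 8, 8, 8] for zig zag, matching $\frac{S^2}{R^2}=16$, $\frac{S}{2R}(\frac{S}{R}+1)=10$, and $\frac{S^2}{2R^2}=8$ respectively (the first entry is the shared first iteration).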

Minimal Example: R=2, S=4

The minimal possible example, to make things concrete. The positions owned by each rank under each strategy are:

| Rank | Naive  | Striped | Zig Zag |
|------|--------|---------|---------|
| 0    | {0, 1} | {0, 2}  | {0, 3}  |
| 1    | {2, 3} | {1, 3}  | {1, 2}  |

On the second iteration, rank 0's queries attend to rank 1's keys and rank 1's queries attend to rank 0's keys. The number of key positions each rank's queries attend to on this iteration, under each strategy, is:

| Rank | Naive | Striped | Zig Zag |
|------|-------|---------|---------|
| 0    | 0     | 1       | 2       |
| 1    | 4     | 3       | 2       |

The naive strategy is maximally imbalanced, the striped strategy is somewhat imbalanced, and zig zag is perfectly balanced.
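Running the counting sketch from the Efficiency section for this case reproduces the table (again, the helper names are illustrative):

```python
# Second ring iteration only: each rank's queries attend to the other rank's keys.
S, R = 4, 2
for strategy in ("naive", "striped", "zigzag"):
    shards = shard_positions(strategy, S, R)
    counts = [count_pairs(shards[r], shards[(r - 1) % R]) for r in range(R)]
    print(strategy, counts)  # naive [0, 4], striped [1, 3], zigzag [2, 2]
```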