HazyResearch / safari

Convolutions for Sequence Modeling

Question about long sequence lengths with Hyena

VivekPanyam opened this issue

Hello!

In the Hyena paper, section 4.4 says "Hyena speedups reach 100x at sequence length 64K."

The figure referenced by that section (figure 4.3) stops at a sequence length short of 10k and the optimized implementation in this repo appears to be limited to an 8k sequence length.

There are a few other references to a 100x speedup over FlashAttention in the paper (and in blog posts). Are these measured speedups or extrapolated from smaller sequence lengths?

I've experimented with the implementation in standalone_hyena.py but it appears to be ~3x slower than FlashAttention at sequence lengths > 32k tokens.

Do you have an estimate for when the fftconv implementation in this repo will support longer sequence lengths (or a pointer to another Hyena codebase if the speedups in the paper were measured)?

Thanks for the great work!

The runtime numbers in the paper do not use the optimized fftconv kernel precisely because of the temporary 8k limitation.
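For context on what "fftconv" computes: a Hyena long convolution is a causal convolution of each channel with a filter as long as the sequence, done in O(L log L) through the FFT instead of O(L²) directly. The following is a minimal NumPy sketch, assuming real inputs of shape (..., L) and one filter per channel. The function name `fft_causal_conv` and the NumPy port are illustrative and not the repo's code. It only loosely mirrors the PyTorch FFT convolution in standalone_hyena.py, whose `k_f = torch.fft.rfft(k, n=fft_size) / fft_size` line shows up in the warnings later in this thread. Normalization here comes from NumPy's default `irfft` scaling, and the rest of the Hyena operator (projections, gating, filter parameterization) is omitted.

```python
import numpy as np

def fft_causal_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Causal convolution y[t] = sum_{s<=t} k[s] * u[t - s], computed via FFT.

    u: (..., L) inputs; k: (..., L) filters (broadcast against u).
    Zero-padding to 2L turns the FFT's circular convolution into a linear one,
    so no output position wraps around and sees "future" inputs.
    """
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    u_f = np.fft.rfft(u, n=fft_size)
    k_f = np.fft.rfft(k, n=fft_size)
    # Pointwise product in frequency space equals convolution in time.
    y = np.fft.irfft(u_f * k_f, n=fft_size)
    return y[..., :seqlen]  # keep only the first L (causal) outputs

rng = np.random.default_rng(0)
u = rng.standard_normal((2, 8, 1024))  # (batch, channels, seqlen)
k = rng.standard_normal((8, 1024))     # one sequence-length filter per channel
y = fft_causal_conv(u, k)

# Reference: direct O(L^2) convolution for one (batch, channel) pair.
ref = np.convolve(u[0, 3], k[3])[:1024]
print(np.allclose(y[0, 3], ref))  # True
```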

The figure referenced by that section (figure 4.3) stops at a sequence length short of 10k

Figure 4.3 (left) goes up to 100k, I think you're referencing the one on the right (which is only a zoomed-in portion of the left figure)?

I've experimented with the implementation in standalone_hyena.py but it appears to be ~3x slower than FlashAttention at sequence lengths > 32k tokens.

Can you give more details on your benchmarking workload? Hyena should already be much faster at 32k tokens, I suspect there might be other factors at play.

Figure 4.3 (left) goes up to 100k, I think you're referencing the one on the right (which is only a zoomed-in portion of the left figure)?

Yeah my bad. I saw 10^5 and thought 10000 for some reason 🤦

Can you give more details on your benchmarking workload?

My benchmarking baseline was a model with:

  • Sequence length of 36032 tokens (36k rounded up to a multiple of 64)
  • Batch size of 1
  • FlashAttention with 16 heads and an embedding dim of 1024 (from the PyPI package for this repo)
  • bfloat16
  • A100

I compared it to a model replacing FlashAttention with a HyenaOperator (d_model of 512, l_max of the sequence length above and everything else at the default values).

I ran into memory issues using the HyenaOperator with an embedding dim of 1024 so I had to drop to 512.

Even with the much smaller embedding dim, the network with the HyenaOperator was ~3x slower than the one with FlashAttention.

Do you think the memory usage issues are just an artifact of the standalone implementation?

Hyena should already be much faster at 32k tokens, I suspect there might be other factors at play.

Are you saying that the code in standalone_hyena.py should be much faster than FlashAttention at that sequence length?

The benchmark above was just a quick test to get some rough numbers so there were some other differences between the baseline and the Hyena test:

  • Smaller embedding dim, but AFAIK, that should work in favor of Hyena
  • torch.compile didn't work with the HyenaOperator (it complained about complex-to-float conversions). I haven't had a chance to dig into the issue yet, but it seems unlikely to explain the performance gap if you expect Hyena to be significantly faster (i.e., 10x to 50x). The overall model may just be slow without torch.compile, but if I remember correctly, that accounts for less than a 2x difference. I don't think skipping torch.compile would undo a significant improvement from Hyena, especially since Hyena should be improving the network's bottleneck. Of course, this is straightforward to test, so I can do that.

Do you have any suggestions for things to try?

Thanks!

Thanks for all the info! I pushed a small benchmarking script here for both forward and backward passes. What numbers do you see when you run it? On my end (on a single A100) I see Hyena as 5x/6x faster at batch size 1 and seqlen 32k. If you use the same script and benchmark at batch size 64, you should get the FlashAttention runtime numbers of the original paper.

Regarding memory: yes, at the moment, without the custom kernel, the memory scaling is slightly worse for Hyena than for FlashAttention, though both are linear. Doing a bit more recomputation on the backward pass helps; we're working on these optimizations.

I was just about to run a benchmark I wrote when you posted your comment :)

I modified your script to import from standalone_hyena and I can roughly reproduce your results on an A100. FlashAttention (fwd + bwd) takes ~3.8x longer than Hyena (fwd + bwd) at a seq len of 32k and batch size of 1.

Full output:

2048
/home/ubuntu/standalone_hyena.py:17: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31.)
  k_f = torch.fft.rfft(k, n=fft_size) / fft_size
4096
8192
16384
32768
65536
131072
---
{2048: 0.0002955092999854969, 4096: 0.0005243146999873716, 8192: 0.0016268614999717101, 16384: 0.0051274498999646315, 32768: 0.017562666200001333, 65536: 0.06848136959997646, 131072: 0.2707951788000173}
{2048: 0.000688433300001634, 4096: 0.001220289100001537, 8192: 0.003820151900026758, 16384: 0.011008315900016895, 32768: 0.041841238200004224, 65536: 0.16533355219999066, 131072: 0.6602845148999676}
---
{2048: 0.0008685043000241421, 4096: 0.0009082306999971479, 8192: 0.0017019433000314166, 16384: 0.002847423499997603, 32768: 0.00549896880002052, 65536: 0.01111297390002619, 131072: 0.02508695080000507}
{2048: 0.0017592959000012343, 4096: 0.0017505167999843252, 8192: 0.003214542400019127, 16384: 0.005148207100000945, 32768: 0.009867695800039655, 65536: 0.02085177230001136, 131072: 0.044432145000018866}
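The log does not label its four dicts. Assuming the first pair is FlashAttention and the second is Hyena, and that within each pair the first dict is forward time and the second backward time (in seconds), summing forward and backward reproduces the ~3.8x at 32k and the 7.3x at 64k quoted here. That labeling is an inference from the numbers, not something the output states; the small script below just does the arithmetic on the copied values.

```python
# Timings (seconds) copied from the batch-size-1 output above.
# Assumed layout: first pair = FlashAttention, second pair = Hyena;
# within each pair, first dict = forward, second dict = backward.
flash_fwd = {32768: 0.017562666200001333, 65536: 0.06848136959997646, 131072: 0.2707951788000173}
flash_bwd = {32768: 0.041841238200004224, 65536: 0.16533355219999066, 131072: 0.6602845148999676}
hyena_fwd = {32768: 0.00549896880002052, 65536: 0.01111297390002619, 131072: 0.02508695080000507}
hyena_bwd = {32768: 0.009867695800039655, 65536: 0.02085177230001136, 131072: 0.044432145000018866}

def speedup(seqlen: int) -> float:
    """(FlashAttention fwd + bwd) / (Hyena fwd + bwd) at one sequence length."""
    flash = flash_fwd[seqlen] + flash_bwd[seqlen]
    hyena = hyena_fwd[seqlen] + hyena_bwd[seqlen]
    return flash / hyena

for seqlen in sorted(flash_fwd):
    print(f"{seqlen:>6}: {speedup(seqlen):.2f}x")
```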

That said, the speed difference at 64k is still "only" 7.3x vs the 100x from the paper. Any thoughts on what could be causing that?

Thanks again!

Awesome! It's all a game of batch sizes; try running at batch sizes 16, 32, and 64 and you should see the speedup get larger.

Hmm that doesn't seem to work.

At a batch size of 16 and sequence length of 32k, FlashAttention takes 3.48 times longer than Hyena (see details below).

Thoughts:

  1. Currently, a single HyenaOperator with a batch size of 16 and a seq len of 32k uses almost all GPU memory (> 35 GB on a 40 GB A100) when running the benchmark. Did you have to use checkpointing to test larger configurations?

  2. At that sequence length, Hyena with a batch size of 16 is ~18.7x slower than Hyena with a batch size of one. This seems to imply that batching is worse than just serially processing each sequence.

As far as I can tell, there isn't really a good reason to use a large batch size here rather than a batch size of 1 plus gradient accumulation.
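The batch-size-1 alternative relies on a standard identity: for a loss averaged over the batch, accumulating the gradients of B micro-batches of size 1, each scaled by 1/B, gives the same gradient as one step at batch size B. In PyTorch terms that means calling `backward()` on each micro-batch loss divided by the number of micro-batches, then taking a single optimizer step. Below is a toy pure-Python check on a scalar linear model; the model, data, and the `grad_mse` helper are hypothetical and only illustrate the arithmetic.

```python
# Toy model: prediction = w * x, loss = mean over the batch of (w*x - y)^2.
def grad_mse(w: float, xs: list[float], ys: list[float]) -> float:
    """d/dw of mean((w*x - y)^2) over the given (micro-)batch."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n

xs = [0.5, -1.0, 2.0, 3.5]
ys = [1.0, 0.0, 4.5, 7.0]
w = 0.3

# One gradient computation at batch size 4.
full_batch = grad_mse(w, xs, ys)

# Batch size 1 + gradient accumulation: sum per-sample grads, each scaled by 1/B.
B = len(xs)
accumulated = 0.0
for x, y in zip(xs, ys):
    accumulated += grad_mse(w, [x], [y]) / B

print(abs(full_batch - accumulated) < 1e-12)  # True
```

The equivalence holds for losses that decompose per sample; layers that couple samples within a batch (batch norm, for example) break it.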

Any ideas on what's going on?


Details:

GPU: A100 40GB SXM4
Versions:

Torch 2.0.0
flash-attn 0.2.8
einops 0.6.0

Hyena implementation permalink

Benchmark permalink

Batch size of 1

Benchmark script with batch size of 1:

2048
/home/ubuntu/standalone_hyena.py:17: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31.)
  k_f = torch.fft.rfft(k, n=fft_size) / fft_size
4096
8192
16384
32768
65536
131072
---
{2048: 0.0003016873999968084, 4096: 0.0004113602999950672, 8192: 0.0012788771000032284, 16384: 0.0046683007999945405, 32768: 0.017532283400009875, 65536: 0.06853519320000032, 131072: 0.2708130353000115}
{2048: 0.0005378262000022005, 4096: 0.0009546344999989742, 8192: 0.0030730292000043847, 16384: 0.010984528699987095, 32768: 0.041858487700005755, 65536: 0.16538351169999715, 131072: 0.6592029120000007}
---
{2048: 0.0008691948000205229, 4096: 0.0009022319999985485, 8192: 0.0014434583000138446, 16384: 0.002845392300014282, 32768: 0.005500168399998983, 65536: 0.011115612699995836, 131072: 0.02511233400000492}
{2048: 0.001425099199991564, 4096: 0.0014490745000102835, 8192: 0.0027972210999905656, 16384: 0.005150951700011319, 32768: 0.009863483899994207, 65536: 0.0208421756000007, 131072: 0.044432989299980366}

Batch size of 16

Benchmark script with batch size of 16 (limited to seq len of 32k because Hyena runs out of memory at larger seq lengths):

2048
/home/ubuntu/standalone_hyena.py:17: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31.)
  k_f = torch.fft.rfft(k, n=fft_size) / fft_size
4096
8192
16384
32768
---
{2048: 0.0020986096000342514, 4096: 0.0062191390999942085, 8192: 0.02071911389998604, 16384: 0.0747385351000048, 32768: 0.2817839461999938}
{2048: 0.004406064900013007, 4096: 0.014036858499957816, 8192: 0.049613316399972976, 16384: 0.1854088593999677, 32768: 0.7167293173999951}
---
{2048: 0.004067870600010792, 4096: 0.008555072500030292, 8192: 0.017019115599987346, 16384: 0.0345663336999678, 32768: 0.069258173299977}
{2048: 0.009594092299994372, 4096: 0.019340128099975117, 8192: 0.04256208110000444, 16384: 0.09568457360001048, 32768: 0.21756787399999666}
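Both numbers in this comment can be recomputed from the two logs at seq len 32768. This assumes, as the ratios quoted in the thread imply, that the first pair of dicts is FlashAttention and the second Hyena, and that within each pair the first dict is forward time and the second backward time. Under that reading, the FlashAttention/Hyena ratio at batch size 16 is about 3.48x. Going from batch size 1 to 16 multiplies Hyena's fwd + bwd time by about 18.7x, while FlashAttention's grows by about 16.8x, closer to the 16x that linear scaling would give.

```python
# Timings (seconds) at seq len 32768, copied from the two runs above.
# Assumed layout: first dict of each pair = forward, second = backward.
runs = {
    # (model, batch_size): (forward, backward)
    ("flash", 1): (0.017532283400009875, 0.041858487700005755),
    ("hyena", 1): (0.005500168399998983, 0.009863483899994207),
    ("flash", 16): (0.2817839461999938, 0.7167293173999951),
    ("hyena", 16): (0.069258173299977, 0.21756787399999666),
}

def total(model: str, batch_size: int) -> float:
    """Forward + backward time for one benchmark configuration."""
    fwd, bwd = runs[(model, batch_size)]
    return fwd + bwd

# Speedup of Hyena over FlashAttention at batch size 16.
speedup_bs16 = total("flash", 16) / total("hyena", 16)
# Runtime growth from batch 1 to batch 16; 16.0 would be perfectly linear.
scaling = {m: total(m, 16) / total(m, 1) for m in ("flash", "hyena")}

print(f"flash/hyena @ bs=16: {speedup_bs16:.2f}x")
for model, factor in scaling.items():
    print(f"{model}: bs16 / bs1 = {factor:.1f}")
```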

Something else I noticed is that the paper says "Hyena speedups reach 100x at sequence length 64K" and references Figure 4.3, but if you look at the LaTeX for Figure 4.3, it's actually only an 11.4x difference.

I know the paper is still a draft, so is the figure (or the text) outdated? Or are we interpreting the meaning of "speedup" differently?

Thanks!


Figure 4.3 from the paper:


\addplot [line width=1pt, indianred]
table {%
1024 0.9
2048 1.16
4096 1.47
8192 1.5
16384 2.84
32768 5.41
65536 11.32
};
\addplot [line width=1pt, cornflowerblue]
table {%
1024 0.4
2048 1.25
4096 2.16
8192 6.17
16384 21.74
32768 90.71
};
\addplot [line width=1pt, lightseagreen, dashed]
table {%
1024 0.29
2048 0.3
4096 0.63
8192 2.1
16384 8.33
32768 32.85
65536 129.07
};

(FlashAttention at 64k) / (Hyena at 64k) = 129.07/11.32 = ~11.4

Section 4.4 says (emphasis mine):

We benchmark runtime of an order 2 Hyena operator compared to attention and FlashAttention layers (Dao
et al., 2022b). Hyena uses a fused CUDA kernel to perform FFTConv (Dao et al., 2022c). We set batch
size to 64 and measure runtime (in milliseconds). Results are provided in Figure 4.3. Hyena speedups reach
100× at sequence length 64K.
Crossover points for Hyena and attention is at length 2048, and for Hyena and
FlashAttention is between 4096 and 8196. Despite the absolute reduction in FLOPs, speedups are achieved
only on longer sequences when the gap grows sufficiently large. This occurs because hardware utilization of
Hyena is lower than FlashAttention. We expect the gap between theoretical maximum speedup to shrink
with improved implementations of FFTConv and specialized hardware.
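The quoted claims can be checked directly against the pgfplots tables above. The sketch below assumes the series mapping used in this thread (indianred = Hyena, cornflowerblue = attention, lightseagreen dashed = FlashAttention), with runtimes in milliseconds at batch size 64. It computes baseline/Hyena speedups at every plotted length and the first length at which Hyena wins. The crossovers match the text (2048 for attention, between 4096 and 8192 for FlashAttention). Within the plotted range, however, the largest ratios are about 16.8x (attention at 32k) and about 11.4x (FlashAttention at 64k). The helper names are illustrative.

```python
from typing import Optional

# Runtime (ms, batch size 64) copied from the Figure 4.3 pgfplots tables above.
# Assumed series mapping: indianred = Hyena, cornflowerblue = attention,
# lightseagreen (dashed) = FlashAttention.
hyena = {1024: 0.9, 2048: 1.16, 4096: 1.47, 8192: 1.5, 16384: 2.84, 32768: 5.41, 65536: 11.32}
attention = {1024: 0.4, 2048: 1.25, 4096: 2.16, 8192: 6.17, 16384: 21.74, 32768: 90.71}
flash = {1024: 0.29, 2048: 0.3, 4096: 0.63, 8192: 2.1, 16384: 8.33, 32768: 32.85, 65536: 129.07}

def speedups(baseline: dict[int, float]) -> dict[int, float]:
    """baseline runtime / Hyena runtime at every length both series share."""
    return {L: baseline[L] / hyena[L] for L in sorted(baseline.keys() & hyena.keys())}

def first_win(baseline: dict[int, float]) -> Optional[int]:
    """Smallest plotted length at which Hyena is faster than the baseline."""
    return next((L for L, s in speedups(baseline).items() if s > 1.0), None)

print("vs FlashAttention:", {L: round(s, 1) for L, s in speedups(flash).items()})
print("vs attention:     ", {L: round(s, 1) for L, s in speedups(attention).items()})
print("crossovers:", first_win(attention), first_win(flash))
```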

Interesting finds, a few things here:

  • I wouldn't read too much into the superlinear runtime scaling with batch size that you mentioned; it's all round-off measurement error. You can increase the number of repeats in the benchmark, and at DIM=768 the scaling will be almost perfectly linear. At smaller model widths, the scaling is also much better for Hyena:
FlashAttention: {(1, 32768): 0.011, (4, 32768): 0.048, (8, 32768): 0.098, (16, 32768): 0.1977, (32, 32768): 0.4041, (64, 32768): 0.8306}
Hyena: {(1, 32768): 0.0012, (4, 32768): 0.0024, (8, 32768): 0.0043, (16, 32768): 0.0082, (32, 32768): 0.0160, (64, 32768): 0.0346}
  • If the model is smaller (i.e., the quadratic cost dominates), the speedups get even larger at shorter sequence lengths. An example is given below at DIM=96, batch_size=64:
FlashAttention: {(64, 32768): 0.8301, (64, 65536): 3.3059, (64, 131072): 13.2732}     
Hyena: {(64, 32768): 0.0346, (64, 65536): 0.0725, (64, 131072): 0.1762} 
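One way to reduce the round-off error mentioned above is to time with warm-up iterations and many repeats, then report a robust statistic such as the median. A generic pure-Python sketch follows; `bench` and its `warmup`, `repeats`, and `sync` parameters are hypothetical names, not part of the benchmark script. For GPU work you would pass something like `torch.cuda.synchronize` as `sync`. CUDA kernels launch asynchronously, so without a sync the clock measures mostly launch overhead.

```python
import statistics
import time
from typing import Callable, Optional

def bench(fn: Callable[[], object], warmup: int = 3, repeats: int = 30,
          sync: Optional[Callable[[], None]] = None) -> float:
    """Median wall-clock time of fn() in seconds, measured after warm-up runs.

    sync, if given, is called before each clock read so asynchronous work
    (e.g. queued GPU kernels) is included in the measurement.
    """
    for _ in range(warmup):  # exclude one-time costs (allocation, caching, autotuning)
        fn()
    if sync is not None:
        sync()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)

t = bench(lambda: sum(i * i for i in range(10_000)))
print(f"{t * 1e3:.3f} ms")
```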

At saturation (width large enough), take the numbers in the figure and what you see when running this benchmark as ground truth, and expect a further few-x improvement in speed and memory as we figure out how to optimize various operations. We'll update the paper when we do so!

If you plan to run models at DIM=768 and SEQ_LEN=64k, here are the numbers:

FlashAttention: {(1, 65536): 0.0745, (4, 65536): 0.2999, (8, 65536): 0.6023, (16, 65536): 1.2074}
Hyena: {(1, 65536): 0.0118, (4, 65536): 0.0406, (8, 65536): 0.0807, (16, 65536): 0.1806}
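Dividing the runtimes posted above gives the implied speedups (FlashAttention time / Hyena time); the script only does the arithmetic and adds no new measurements. At DIM=96 and batch size 64, the ratio grows with sequence length from about 24x at 32k to about 75x at 128k. At DIM=768 and 64k, it stays in the 6x to 7.5x range across batch sizes. The `ratios` helper is illustrative.

```python
# Runtimes copied from the comment above, keyed by (batch_size, seq_len).
flash_768 = {(1, 65536): 0.0745, (4, 65536): 0.2999, (8, 65536): 0.6023, (16, 65536): 1.2074}
hyena_768 = {(1, 65536): 0.0118, (4, 65536): 0.0406, (8, 65536): 0.0807, (16, 65536): 0.1806}
flash_96 = {(64, 32768): 0.8301, (64, 65536): 3.3059, (64, 131072): 13.2732}
hyena_96 = {(64, 32768): 0.0346, (64, 65536): 0.0725, (64, 131072): 0.1762}

def ratios(flash: dict, hyena: dict) -> dict:
    """FlashAttention time / Hyena time for each (batch_size, seq_len), rounded to 0.1."""
    return {key: round(flash[key] / hyena[key], 1) for key in flash}

r768 = ratios(flash_768, hyena_768)
r96 = ratios(flash_96, hyena_96)
print("DIM=768:", r768)
print("DIM=96: ", r96)
```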

Thanks! That makes sense. I think it would be super useful to have a sweep over (dim, batch_size, seq_len) comparing FlashAttention and Hyena runtimes (for both forward and backward passes), but I don't think I'll be able to get to that anytime soon. Do you think you'll have time to run that sweep? It might even be worth committing to the repo or adding to a wiki so there's a place for people to quickly see potential speedups.

Thanks again!