HazyResearch / safari

Convolutions for Sequence Modeling

RuntimeError: u must have shape (batch_size, H, L)

kaansancak opened this issue

Hello,
I am trying to run the benchmark here with fused_fft_conv enabled, but I am getting a RuntimeError: u must have shape (batch_size, H, L) error. In this case the shape of u is [1, 1, 768, 1, 2048], while the check expects a 3-dimensional (batch_size, H, L) tensor. Normally fftconv handles the last dimension, but in this case the shape check fails.

Log:

Traceback (most recent call last):
  File "/localscratch/safari/benchmarks/runtime_hyena_flashmha.py", line 77, in <module>
    m, t = benchmark_forward(hyena, x, repeats=10, desc='', verbose=False)
  File "/localscratch/safari/benchmarks/runtime_hyena_flashmha.py", line 23, in benchmark_forward
    m = t.timeit(repeats)
  File "/opt/conda/envs/gps/lib/python3.9/site-packages/torch/utils/benchmark/utils/timer.py", line 266, in timeit
    self._timeit(number=max(int(number // 100), 2))
  File "/opt/conda/envs/gps/lib/python3.9/site-packages/torch/utils/benchmark/utils/timer.py", line 256, in _timeit
    return max(self._timer.timeit(number), 1e-9)
  File "/opt/conda/envs/gps/lib/python3.9/timeit.py", line 177, in timeit
    timing = self.inner(it, self.timer)
  File "<timeit-src>", line 6, in inner
  File "/opt/conda/envs/gps/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/localscratch/safari/src/models/sequence/hyena.py", line 361, in forward
    v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])
  File "/opt/conda/envs/gps/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/localscratch/safari/src/models/sequence/hyena.py", line 218, in forward
    y = fftconv_func(
  File "/localscratch/safari/src/ops/fftconv.py", line 102, in fftconv_func
    return FFTConvFunc.apply(u, k, D, dropout_mask, gelu, force_fp16_output,
  File "/localscratch/safari/src/ops/fftconv.py", line 79, in forward
    out = fftconv_fwd(u, k_f, D, v, head_dim, q, dropout_mask, gelu, False, False, fft_size, force_fp16_output, output_hbl_layout, fftfp16)
RuntimeError: u must have shape (batch_size, H, L)
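
If I trace the shape back, the extra dimensions come from the rearrange right after the short filter in HyenaOperator.forward. A minimal sketch that reproduces the 5D shape (d_model=768 and seq_len=2048 match the error message; order=2, num_heads=1, num_blocks=1 are my assumed defaults):

import torch
from einops import rearrange

b, d_model, l, order = 1, 768, 2048, 2
num_heads, num_blocks = 1, 1
head_dim = d_model // num_heads

uc = torch.randn(b, d_model * (order + 1), l)  # stand-in for the short_filter output
uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
               z=num_blocks, ho=num_heads, v=head_dim * (order + 1))
*x, v = uc.split(d_model, dim=2)
print(v.shape)  # torch.Size([1, 1, 768, 1, 2048]) -- 5D, while the fused kernel wants (batch_size, H, L)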

Hey, I think the fftconv code expects the model to have a single head and a single block, while the model code has already integrated support for multiple heads and blocks (which then breaks fftconv, as you noticed). Also, at one point the code expects a transposed version of the input (there is a quick check of this after the diff). You can patch src/models/sequence/hyena.py like this to get it running for now:

@@ -314,13 +314,13 @@ class HyenaOperator(nn.Module):
         
         uc = self.short_filter(u)[...,:l_filter] 
         
-        uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l', 
-            z=self.num_blocks, 
-            ho=self.num_heads, 
-            v=self.head_dim * (self.order + 1)
-        )
+        # uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
+        #     z=self.num_blocks,
+        #     ho=self.num_heads,
+        #     v=self.head_dim * (self.order + 1)
+        # )
 
-        *x, v = uc.split(self.d_model, dim=2)
+        *x, v = uc.split(self.d_model, dim=1)
         k = self.filter_fn.filter(l_filter)
         
         # `c` is always 1 by default
@@ -339,7 +339,7 @@ class HyenaOperator(nn.Module):
                 v = self.dropout(v * x_i)
 
             # the bias term is broadcasted. Last dimension (l) is handled by fftconv
-            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])
+            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])
             
             if self.post_order_ffn: 
                 w = self.ord_proj_w[o]
@@ -347,7 +347,10 @@ class HyenaOperator(nn.Module):
                     rearrange(w, 'h1 h2 -> 1 h1 h2 1 1 1'), rearrange(v, 'b h v z l -> b h 1 v z l')
                 )
 
-        y = self.activation(rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads))
+        y = self.activation(
+            (v * x[0]).transpose(-2, -1),
+            # rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads)
+        )
         y = self.out_proj(y)
         
         if self.return_state:
@@ -356,4 +359,4 @@ class HyenaOperator(nn.Module):
 
     @property
     def d_output(self):
-        return self.d_model
\ No newline at end of file
+        return self.d_model
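
For what it's worth, the transposed activation in the patch matches what the original rearrange did for num_heads = num_blocks = 1: with the head and block axes of size 1, 'b h v z l -> b (z l) (h v)' is just a transpose of the channel and length dimensions. A quick standalone check (my own sketch, not from the repo):

import torch
from einops import rearrange

t = torch.randn(1, 1, 768, 1, 2048)  # (b, h, v, z, l) with h = z = 1

out_rearranged = rearrange(t, 'b h v z l -> b (z l) (h v)')
out_transposed = t.squeeze(1).squeeze(-2).transpose(-2, -1)  # drop h and z, then swap (v, l)
assert torch.equal(out_rearranged, out_transposed)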

Hi, is there any update on multi-head support for fftconv?

The module already supports multiple heads; you can find an example in the H3 code: https://github.com/HazyResearch/safari/blob/main/src/models/sequence/h3.py#L160

In H3, the three branches (what Hyena calls x[0], x[1], and v) are named q, k, and v.

Passing in head_dim > 1 will trigger multi-head support:

y = fftconv_func(k, ssm_kernel, self.D,
                 dropout_mask, False, torch.is_autocast_enabled(), True,
                 v, self.head_dim, q)
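
For reference, lining that call up with the positional parameters visible in the traceback above (fftconv_fwd(u, k_f, D, v, head_dim, q, dropout_mask, gelu, ...)), my reading of the argument order is the following. Treat the parameter names as an assumption based on src/ops/fftconv.py rather than a documented API:

y = fftconv_func(
    k,                            # u: the main input branch, shape (batch_size, H, L)
    ssm_kernel,                   # k: the long convolution filter
    self.D,                       # D: the bias / skip term
    dropout_mask,                 # dropout_mask
    False,                        # gelu
    torch.is_autocast_enabled(),  # force_fp16_output
    True,                         # output_hbl_layout
    v, self.head_dim, q,          # extra branches; head_dim > 1 switches the kernel to the multi-head path
)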