HazyResearch / safari

Convolutions for Sequence Modeling

Questions about bidirectional-version of H3

alstonlo opened this issue · comments

Thank you for the amazing code and work! I am interested in using H3 bidirectionally and have some questions:

  1. Would I instantiate H3 with bidirectional shift and diagonal SSMs, similar to S4 and S4D?
  2. If so, can this be achieved by passing a kernel (to be applied in the reverse direction) as the k_rev argument of fftconv()?
  3. Below, I have written a script comparing bidirectional convolutions under various implementations. The S4 & S4D convolution implementation seems to give results that differ from the H3 and naive implementations. Assuming I have not made an error in my code, is this intended?
import numpy as np
import scipy.signal
import torch
import torch.nn.functional as F

L = 100
n_fft = L * 2


def conv_direct(u, k, k_rev):
    fwd = scipy.signal.convolve(u.numpy(), k.numpy(), method="direct")[:L]
    bwd = scipy.signal.convolve(u.flip(-1).numpy(), k_rev.numpy(), method="direct")[:L]
    return fwd + np.flip(bwd, -1)


def conv_fft_s4(u, k, k_rev):
    k = F.pad(k, (0, L)) + F.pad(k_rev.flip(-1), (L, 0))
    u_f = torch.fft.rfft(u, n=n_fft)
    k_f = torch.fft.rfft(k, n=n_fft)
    return torch.fft.irfft(u_f * k_f, n=n_fft)[..., :L].numpy()


def conv_fft_h3(u, k, k_rev):
    u_f = torch.fft.rfft(u, n=n_fft, norm="backward")
    k_f = torch.fft.rfft(k, n=n_fft, norm="forward") + torch.fft.rfft(k_rev, n=n_fft, norm="forward").conj()
    return torch.fft.irfft(u_f * k_f, n=n_fft, norm="forward")[..., :L].numpy()


def conv_fft_s4_v2(u, k, k_rev):
    k = F.pad(k, (0, L)) + torch.roll(F.pad(k_rev.flip(-1), (L, 0)), 1, -1)
    u_f = torch.fft.rfft(u, n=n_fft)
    k_f = torch.fft.rfft(k, n=n_fft)
    return torch.fft.irfft(u_f * k_f, n=n_fft)[..., :L].numpy()


def compare():
    u = torch.randn(L)
    k = torch.randn(L)
    k_rev = torch.randn(L)

    direct = conv_direct(u=u, k=k, k_rev=k_rev)
    fft_s4 = conv_fft_s4(u=u, k=k, k_rev=k_rev)
    fft_h3 = conv_fft_h3(u=u, k=k, k_rev=k_rev)
    fft_s4v2 = conv_fft_s4_v2(u=u, k=k, k_rev=k_rev)

    print("Direct:  ", direct[:5])
    print("FFT S4:  ", fft_s4[:5])
    print("FFT H3:  ", fft_h3[:5])
    print("FFT S4v2:", fft_s4v2[:5])

    assert np.abs(direct - fft_h3).max() <= 1e-5
    assert np.abs(fft_s4v2 - fft_h3).max() <= 1e-5


if __name__ == "__main__":
    compare()

Output:

Direct:   [-0.8028737 -2.0864778 -3.4721029 11.840934  12.045782 ]
FFT S4:   [-2.326698  -3.9480963 12.992389  12.143847  10.611053 ]
FFT H3:   [-0.80287194 -2.0864816  -3.4721045  11.840934   12.045784  ]
FFT S4v2: [-0.8028715 -2.0864806 -3.4721053 11.840935  12.045782 ]

Thanks in advance!

commented

We don't have bidirectional support in H3 right now, but you can implement it the same way as in S4.

Here's an example of how you would do it (you can view the long conv kernel as analogous to the H3 shift and diagonal kernels): https://github.com/HazyResearch/safari/blob/main/src/models/sequence/long_conv.py#L135

Just like S4, you can double the number of channels to get one kernel that goes forward, and another that goes backward: https://github.com/HazyResearch/safari/blob/main/src/models/sequence/long_conv.py#L75
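To make the suggestion above concrete, here is a minimal standalone sketch (the function name and shapes are my own, not the repo's API) of the pad-and-sum bidirectional convolution used by S4/LongConv: the channel dimension is doubled, the first half of the kernel runs forward, and the second half (flipped) runs backward, stacked back-to-back into one length-2L kernel:

```python
import torch
import torch.nn.functional as F


def bidirectional_fftconv(u, k2, L):
    """Bidirectional FFT convolution via the pad-and-sum trick.

    u:  (..., L) real input
    k2: (2, L) kernels; k2[0] runs forward, k2[1] backward
    """
    k_fwd, k_rev = k2[0], k2[1]
    # stack the two kernels back-to-back into a single length-2L kernel
    k = F.pad(k_fwd, (0, L)) + F.pad(k_rev.flip(-1), (L, 0))
    n_fft = 2 * L
    u_f = torch.fft.rfft(u, n=n_fft)
    k_f = torch.fft.rfft(k, n=n_fft)
    return torch.fft.irfft(u_f * k_f, n=n_fft)[..., :L]
```

With `k2[1]` set to zero this reduces to an ordinary causal FFT convolution; note that, as discussed below, this back-to-back layout carries an intentional off-by-one in the reverse direction.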

Thanks!

Sorry, I am still confused about the potential discrepancy raised in my third question. The bidirectional convolution implemented in S4, S4D, and LongConv (concatenating the two kernels) seems to differ from the naive implementation (directly computing convolutions in both directions and adding the results).

The bidirectional version of S4 intentionally has an off-by-one in the reverse kernel, for efficiency reasons. One can make it "correct" by replacing

k = F.pad(k, (0, L)) + F.pad(k_rev.flip(-1), (L, 0))

with

k = F.pad(k, (0, L)) + F.pad(k_rev[1:].flip(-1), (L+1, 0)) + F.pad(k_rev[:1], (0, 2*L-1))

(I didn't check this but it should be something like this. The point is that adding the forward and reverse kernels will overlap by one position, while the S4 kernel makes them disjoint for simplicity by stacking them back-to-back.)
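As a sanity check (a standalone sketch, not repo code), the corrected kernel above can be verified numerically against the direct bidirectional convolution:

```python
import numpy as np
import scipy.signal
import torch
import torch.nn.functional as F

torch.manual_seed(0)
L = 100
n_fft = 2 * L
u, k, k_rev = torch.randn(L), torch.randn(L), torch.randn(L)

# direct reference: causal conv forward, plus causal conv on the flipped input, flipped back
fwd = scipy.signal.convolve(u.numpy(), k.numpy(), method="direct")[:L]
bwd = scipy.signal.convolve(u.flip(-1).numpy(), k_rev.numpy(), method="direct")[:L]
ref = fwd + np.flip(bwd, -1)

# corrected pad-and-sum kernel: k_rev[0] overlaps the forward kernel at lag 0,
# and the remaining reverse taps occupy lags L+1 .. 2L-1
k_both = (
    F.pad(k, (0, L))
    + F.pad(k_rev[1:].flip(-1), (L + 1, 0))
    + F.pad(k_rev[:1], (0, 2 * L - 1))
)
out = torch.fft.irfft(
    torch.fft.rfft(u, n=n_fft) * torch.fft.rfft(k_both, n=n_fft), n=n_fft
)[:L].numpy()
assert np.abs(ref - out).max() <= 1e-4
```

This passes, so the formula is indeed "something like this": the forward and reverse kernels share lag 0, where both k[0] and k_rev[0] multiply the current position.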

I haven't actually seen the version you call conv_fft_h3 before, and it's not immediately obvious to me why it works, but I can believe it. Note that the pad-and-sum versions are used because they should be faster: they do fewer FFTs.
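One possible explanation for why conv_fft_h3 works (my own reasoning, not from the repo): for a real signal, conjugating its rFFT is equivalent to reversing it circularly in time. After zero-padding k_rev to length 2L, that circular reversal puts k_rev[0] at lag 0 and k_rev[1:] (flipped) at lags L+1 .. 2L-1, which is exactly the "corrected" kernel layout rather than the off-by-one one. A small check, with hypothetical variable names:

```python
import torch

torch.manual_seed(0)
L = 8
n = 2 * L
k = torch.randn(L)
# zero-pad to length n, then reverse circularly: y[m] = k_pad[(-m) % n]
k_pad = torch.nn.functional.pad(k, (0, n - L))
k_circ_rev = torch.roll(k_pad.flip(-1), 1, -1)
# conjugating the rFFT of a real signal time-reverses it (mod n)
assert torch.allclose(torch.fft.rfft(k_pad).conj(), torch.fft.rfft(k_circ_rev), atol=1e-5)
# k[0] stays at position 0; the rest lands at positions L+1 .. n-1
assert torch.allclose(k_circ_rev[0], k[0])
assert torch.allclose(k_circ_rev[L + 1 :], k[1:].flip(-1))
```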

Thanks!