HazyResearch / safari

Convolutions for Sequence Modeling


Question concerning FFT operation.

veritas9872 opened this issue · comments

k_f = torch.fft.rfft(k, n=fft_size) / fft_size
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]

Hello. Thank you for the great work in this paper. I only have a minor question concerning the code.

When performing the FFT, it is my understanding that the input should be shifted before the transform and the output shifted after it for the result to be equivalent to the DFT.

Therefore, fftshift(fft(ifftshift(x))) and fftshift(ifft(ifftshift(X))) are the correct methods.

Because rfft returns only the non-negative half of the spectrum, I believe the correct transformations should be rfft(ifftshift(x)) and fftshift(irfft(X)) for the conversions to and from the frequency domain. This may not affect model performance, and the difference in the outputs may be small, but I thought it was worth noting.
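For concreteness, here is a minimal self-contained sketch of the snippet above (the shapes of u and k are my own assumptions for illustration, not taken from the repo):

import torch

# Minimal sketch of the quoted FFT convolution; shapes are assumed for illustration.
batch, d_model, seqlen = 2, 4, 16
u = torch.randn(batch, d_model, seqlen)   # input sequences
k = torch.randn(d_model, seqlen)          # one convolution kernel per channel
fft_size = 2 * seqlen                     # zero-pad so the circular convolution becomes linear

k_f = torch.fft.rfft(k, n=fft_size) / fft_size
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
print(y.shape)  # torch.Size([2, 4, 16])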

I have included the following links for reference.

https://groups.google.com/g/comp.soft-sys.matlab/c/rUcc0bRRZf4?pli=1

https://dsp.stackexchange.com/questions/66716/why-do-we-have-to-rearrange-a-vector-and-shift-the-zero-point-to-the-first-index

commented

Thank you for the quick response!
I think that my question is slightly different.
The FFTShift and IFFTShift operations move the low-frequency components to the center of the spectrum.
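As a quick illustration (a toy example of my own, not from the repo), fftshift only reorders the frequency bins so that the zero-frequency bin sits in the center, and ifftshift undoes that reordering:

import torch

# fftshift centers the zero-frequency bin; ifftshift restores the standard ordering.
freqs = torch.fft.fftfreq(8)
print(freqs)                      # tensor([ 0.0000,  0.1250,  0.2500,  0.3750, -0.5000, -0.3750, -0.2500, -0.1250])
print(torch.fft.fftshift(freqs))  # tensor([-0.5000, -0.3750, -0.2500, -0.1250,  0.0000,  0.1250,  0.2500,  0.3750])
print(torch.fft.ifftshift(torch.fft.fftshift(freqs)))  # back to the original ordering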

Because of how the FFT is implemented (the zero-frequency/origin sample is placed at the first index), I thought the FFT and IFFT required this center shifting to reproduce the DFT exactly.
While the shifts may cancel out, I was curious whether this might affect the result.

This discussion may also be helpful. pytorch/pytorch#51022

commented

AFAIK, this comes from MATLAB arrays being 1-indexed, which led many communities working in MATLAB to adopt the fftshift + centered-DFT convention. You don't need fftshift in PyTorch for the DFT result to be correct.
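A quick sanity check (an illustrative snippet of my own, not from the repo) is to compare torch.fft.fft against the textbook DFT matrix directly:

import math
import torch

# torch.fft.fft matches the textbook DFT X[k] = sum_n x[n] * exp(-2j*pi*k*n/N),
# with no fftshift/ifftshift involved.
N = 7
x = torch.randn(N, dtype=torch.complex128)
n = torch.arange(N, dtype=torch.float64)
angle = -2.0 * math.pi * n[:, None] * n[None, :] / N  # phase matrix of the DFT
dft_matrix = torch.exp(1j * angle)                    # complex128 DFT matrix
print(torch.allclose(dft_matrix @ x, torch.fft.fft(x)))  # True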

I have tested the functions, and I believe this is indeed the explanation.

The following code does indeed show that shifting is unnecessary for FFT in PyTorch.

Thank you for your help!

from scipy import signal
import torch
import numpy as np


@torch.inference_mode()
def test1():
    # FFT convolution with zero padding and no shifting: matches the direct linear convolution.
    seq_len = 13
    a = np.random.rand(seq_len)
    b = np.random.rand(seq_len)
    c = signal.convolve(a, b, mode='full', method='direct')
    d = torch.fft.rfft(torch.from_numpy(a), n=2 * seq_len) / (2 * seq_len)
    e = torch.fft.rfft(torch.from_numpy(b), n=2 * seq_len)
    f = torch.fft.irfft(d * e, n=2 * seq_len, norm='forward').numpy()[:-1]
    print(np.allclose(c, f))  # True


@torch.inference_mode()
def test2():
    # Same computation with ifftshift before and fftshift after: no longer matches.
    seq_len = 13
    a = np.random.rand(seq_len)
    b = np.random.rand(seq_len)
    c = signal.convolve(a, b, mode='full', method='direct')
    d = torch.fft.rfft(torch.fft.ifftshift(torch.from_numpy(a)), n=2 * seq_len) / (2 * seq_len)
    e = torch.fft.rfft(torch.fft.ifftshift(torch.from_numpy(b)), n=2 * seq_len)
    f = torch.fft.fftshift(torch.fft.irfft(d * e, n=2 * seq_len, norm='forward')).numpy()[:-1]
    print(np.allclose(c, f))  # False


test1()
test2()

The PyTorch FFT convolution and the SciPy direct convolution produce identical results without any shifting, so the MATLAB convention does indeed seem to have been the source of my confusion.

One more question, though: is taking the front of the resulting convolved sequence the desired behavior? I believe the middle part, corresponding to scipy.signal.convolve(..., mode='same'), may be more desirable.

The resulting code would be as follows.

seqlen = u.shape[-1]
fft_size = 2 * seqlen

k_f = torch.fft.rfft(k, n=fft_size, norm='forward')
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size, norm='backward')  # explicit norm mode for readability

if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., seqlen // 2 : seqlen // 2 + seqlen]

Thanks for verifying! Could you elaborate on why that would be more desirable? If you don't take the first seqlen elements, your convolution is no longer causal. Padding is just an artifact to turn a circular convolution (for which the FFTConv method holds) into a linear convolution (which is what we want to compute); at the output, you need to select the first seqlen elements for the result to be correct.
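As a small illustration (a toy example of my own, not the repo code), the first seqlen outputs of the zero-padded FFT convolution reproduce the causal convolution y[t] = sum_{s <= t} k[s] * u[t - s], which a centered ('same'-style) slice would not:

import torch

# The first seqlen outputs of the zero-padded FFT convolution equal the causal
# convolution y[t] = sum_{s <= t} k[s] * u[t - s].
seqlen = 6
u = torch.randn(seqlen, dtype=torch.float64)
k = torch.randn(seqlen, dtype=torch.float64)
fft_size = 2 * seqlen

u_f = torch.fft.rfft(u, n=fft_size)
k_f = torch.fft.rfft(k, n=fft_size)
y_fft = torch.fft.irfft(u_f * k_f, n=fft_size)[:seqlen]

# Direct causal convolution for comparison.
y_direct = torch.stack([sum(k[s] * u[t - s] for s in range(t + 1)) for t in range(seqlen)])
print(torch.allclose(y_fft, y_direct))  # True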

I see that the desired result is to take only the first part of the output sequence, rather than the region of maximum overlap. Thank you for the explanation!