zh217 / torch-dct

DCT (discrete cosine transform) functions for pytorch

rfft warning

pythonmobile opened this issue

X = dct.dct(x)
UserWarning: The function torch.rfft is deprecated and will be removed in a future PyTorch release. Use the new torch.fft module functions, instead, by importing torch.fft and calling torch.fft.fft or torch.fft.rfft. (Triggered internally at /pytorch/aten/src/ATen/native/SpectralOps.cpp:590.)
Vc = torch.rfft(v, 1, onesided=False)
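The warning comes from a layout change. The old `torch.rfft(v, 1, onesided=False)` returned real and imaginary parts stacked in a trailing dimension of size 2. The `torch.fft` functions return complex tensors instead, and `torch.view_as_real` / `torch.view_as_complex` convert between the two layouts. Below is a minimal sketch, assuming PyTorch 1.8 or later. The helper name `full_spectrum_as_pairs` is hypothetical and is not part of torch-dct.

```python
import torch


def full_spectrum_as_pairs(v: torch.Tensor) -> torch.Tensor:
    # Complex FFT along the last dimension; the result is a complex tensor.
    spectrum = torch.fft.fft(v, dim=-1)
    # view_as_real exposes (real, imag) as a trailing dimension of size 2,
    # the layout that torch.rfft(v, 1, onesided=False) used to return.
    return torch.view_as_real(spectrum)


v = torch.randn(3, 8, dtype=torch.float64)
pairs = full_spectrum_as_pairs(v)
print(pairs.shape)  # torch.Size([3, 8, 2])

# view_as_complex undoes the view and recovers the complex spectrum.
back = torch.view_as_complex(pairs)
# irfft with n=8 inverts the transform; it only reads the first n // 2 + 1 bins.
recon = torch.fft.irfft(back, n=v.shape[-1], dim=-1)
print(torch.allclose(recon, v))
```

Because `view_as_real` returns a view rather than a copy, this layout conversion adds no extra memory traffic.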

I ran into the same problem.

@zh217 Yes, I am seeing the same problem. In short, from what I have observed it does not happen with torch <= 1.6, but it does with torch 1.7.0 and later. I tried:

  1. torch 1.6.0 + torchaudio 0.6.0
  2. torch 1.7.0 + torchaudio 0.7.0
  3. torch 1.7.1 + torchaudio 0.7.2

In the first case I saw no warning; in the other two I did.

In my experiments the warning does not seem to affect the overall performance of the system, but the code will need an update eventually.

With PyTorch 1.8.0 the library no longer works:
AttributeError: module 'torch' has no attribute 'rfft'
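To support both the old and new PyTorch versions listed above, one option is to detect the `torch.fft` module directly instead of comparing version strings. The sketch below is only a sketch. The names `HAS_FFT_MODULE` and `fft_full_as_pairs` are hypothetical. The behavior on 1.6 and 1.7 (on 1.7, `torch.fft` is the deprecated function until `import torch.fft` replaces it with the module) reflects my understanding of the 1.7 transition and has not been tested here.

```python
import torch

try:
    # Assumption: on PyTorch 1.7 this import swaps the deprecated torch.fft
    # function for the torch.fft module; on <= 1.6 there is no such module.
    import torch.fft  # noqa: F401
except ImportError:
    pass

# Hypothetical flag: True when torch.fft is the module that provides torch.fft.fft.
HAS_FFT_MODULE = hasattr(torch, "fft") and hasattr(torch.fft, "fft")


def fft_full_as_pairs(v: torch.Tensor) -> torch.Tensor:
    """Full (two-sided) FFT over the last dim as (..., N, 2) real/imag pairs."""
    if HAS_FFT_MODULE:
        return torch.view_as_real(torch.fft.fft(v, dim=-1))
    # Legacy path for old PyTorch only; torch.rfft no longer exists in 1.8.
    return torch.rfft(v, 1, onesided=False)


v = torch.randn(2, 8)
print(HAS_FFT_MODULE, fft_full_as_pairs(v).shape)
```

Feature detection keeps working on versions the check was not written for, whereas a hard-coded version comparison would not.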

@margaritageleta Thanks a lot for pointing this out.

@zh217 Shall I fix it if you don't have time?

You can use the following code on PyTorch 1.8.0 to emulate the PyTorch 1.5.0 torch.rfft / torch.irfft with signal_ndim=1, including onesided=False.
For this file, these two functions are all you need.

import torch

def _rfft(x, signal_ndim=1, onesided=True):
    # b = torch.Tensor([[1,2,3,4,5],[2,3,4,5,6]])
    # b = torch.Tensor([[1,2,3,4,5,6],[2,3,4,5,6,7]])
    # torch 1.8.0 torch.fft.rfft to torch 1.5.0 torch.rfft as signal_ndim=1
    # written by mzero
    # Expects a 2-D input (batch, N); the transform runs over dim 1.
    odd_shape1 = (x.shape[1] % 2 != 0)
    x = torch.fft.rfft(x)
    # Old layout: real and imaginary parts stacked in a trailing dimension of size 2.
    x = torch.cat([x.real.unsqueeze(dim=2), x.imag.unsqueeze(dim=2)], dim=2)
    if not onesided:
        # Rebuild the missing bins from Hermitian symmetry: X[N - k] = conj(X[k]).
        _x = x[:, 1:, :].flip(dims=[1]).clone() if odd_shape1 else x[:, 1:-1, :].flip(dims=[1]).clone()
        _x[:, :, 1] = -1 * _x[:, :, 1]
        x = torch.cat([x, _x], dim=1)
    return x

def _irfft(x, signal_ndim=1, onesided=True):
    # b = torch.Tensor([[1,2,3,4,5],[2,3,4,5,6]])
    # b = torch.Tensor([[1,2,3,4,5,6],[2,3,4,5,6,7]])
    # torch 1.8.0 torch.fft.irfft to torch 1.5.0 torch.irfft as signal_ndim=1
    # written by mzero
    # Expects a (batch, bins, 2) tensor of (real, imag) pairs.
    if not onesided:
        # Full spectrum of length N: keep only the N // 2 + 1 non-redundant bins.
        res_shape1 = x.shape[1]
        x = x[:, :(x.shape[1] // 2 + 1), :]
        # torch.complex keeps the input dtype (float32 or float64).
        x = torch.complex(x[:, :, 0], x[:, :, 1])
        x = torch.fft.irfft(x, n=res_shape1)
    else:
        x = torch.complex(x[:, :, 0], x[:, :, 1])
        x = torch.fft.irfft(x)
    return x
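The onesided=False branch of `_rfft` relies on Hermitian symmetry: for real input, X[N - k] = conj(X[k]). The missing bins are therefore the conjugates of bins 1 through (N - 1) // 2 in reverse order. For even N, the Nyquist bin N // 2 is its own mirror, which is why `_rfft` uses `x[:, 1:-1]` in that case. Below is a sketch of the same reconstruction on complex tensors, checked against `torch.fft.fft` for both parities. The name `onesided_to_full` is hypothetical.

```python
import torch


def onesided_to_full(half: torch.Tensor, n: int) -> torch.Tensor:
    """Rebuild the full length-n spectrum from torch.fft.rfft output.

    For real input X[n - k] == conj(X[k]), so bins n//2+1 .. n-1 are the
    conjugates of bins (n-1)//2 .. 1 in that order. For even n the Nyquist
    bin n//2 is its own mirror and is not repeated.
    """
    mirrored = half[..., 1:(n - 1) // 2 + 1].flip(-1).conj()
    return torch.cat([half, mirrored], dim=-1)


for n in (7, 8):  # odd and even lengths, like the two branches of _rfft
    x = torch.randn(2, n, dtype=torch.float64)
    full = onesided_to_full(torch.fft.rfft(x, dim=-1), n)
    print(n, full.shape[-1], torch.allclose(full, torch.fft.fft(x, dim=-1)))
```

For the DCT itself, calling `torch.fft.fft` directly on the real input produces the full spectrum without this manual mirroring.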

Just sharing my experience.
We can keep most of the existing code and modify only two lines:

  1. Substitute torch.rfft with torch.fft.fft in function dct:
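import numpy as np
import torch
import torch.fft  # needed on PyTorch 1.7 to expose the torch.fft module

def dct(x, norm=None):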
    """
        <Skip comment>
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    # Vc = torch.rfft(v, 1, onesided=False)           # comment this line
    Vc = torch.view_as_real(torch.fft.fft(v, dim=1))  # add this line

    k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)

    return V
  2. Substitute torch.irfft with torch.fft.irfft in function idct:
def idct(X, norm=None):
    """
        <Skip comment>
    """

    x_shape = X.shape
    N = x_shape[-1]

    X_v = X.contiguous().view(-1, x_shape[-1]) / 2

    if norm == 'ortho':
        X_v[:, 0] *= np.sqrt(N) * 2
        X_v[:, 1:] *= np.sqrt(N / 2) * 2

    k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V_t_r = X_v
    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)

    V_r = V_t_r * W_r - V_t_i * W_i
    V_i = V_t_r * W_i + V_t_i * W_r

    V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)

    # v = torch.irfft(V, 1, onesided=False)                             # comment this line
    v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)  # add this line

    x = v.new_zeros(v.shape)
    x[:, ::2] += v[:, :N - (N // 2)]
    x[:, 1::2] += v.flip([1])[:, :N // 2]

    return x.view(*x_shape)
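To check the two patched lines numerically without depending on SciPy, you can compare the FFT path against the direct O(N^2) DCT-II definition, X_k = 2 * sum_n x_n cos(pi k (2n + 1) / (2N)), and confirm the round trip. This is a minimal sketch. `dct_fft` and `idct_fft` are hypothetical condensed copies of `dct` and `idct` above, with the `norm` handling and `device` arguments dropped, and `dct_naive` is a hypothetical reference implementation.

```python
import math

import torch


def dct_naive(x: torch.Tensor) -> torch.Tensor:
    """Direct O(N^2) DCT-II: X_k = 2 * sum_n x_n * cos(pi * k * (2n + 1) / (2N))."""
    N = x.shape[-1]
    n = torch.arange(N, dtype=x.dtype)
    k = torch.arange(N, dtype=x.dtype)[:, None]
    basis = torch.cos(math.pi * k * (2 * n + 1) / (2 * N))  # basis[k, n]
    return 2 * x @ basis.T


def dct_fft(x: torch.Tensor) -> torch.Tensor:
    """Unnormalized DCT-II via one complex FFT, mirroring the patched dct()."""
    N = x.shape[-1]
    # Even-indexed samples, then odd-indexed samples in reverse order.
    v = torch.cat([x[..., ::2], x[..., 1::2].flip([-1])], dim=-1)
    Vc = torch.view_as_real(torch.fft.fft(v, dim=-1))  # replaces torch.rfft(v, 1, onesided=False)
    k = -torch.arange(N, dtype=x.dtype) * math.pi / (2 * N)
    # Real part of Vc[k] * exp(-i * pi * k / (2N)).
    V = Vc[..., 0] * torch.cos(k) - Vc[..., 1] * torch.sin(k)
    return 2 * V


def idct_fft(X: torch.Tensor) -> torch.Tensor:
    """Inverse of dct_fft, mirroring the patched idct() without norm handling."""
    N = X.shape[-1]
    X_v = X / 2
    k = torch.arange(N, dtype=X.dtype) * math.pi / (2 * N)
    W_r, W_i = torch.cos(k), torch.sin(k)
    # Imaginary parts recovered from the mirrored coefficients (bin 0 has none).
    V_t_i = torch.cat([X_v[..., :1] * 0, -X_v.flip([-1])[..., :-1]], dim=-1)
    V_r = X_v * W_r - V_t_i * W_i
    V_i = X_v * W_i + V_t_i * W_r
    V = torch.stack([V_r, V_i], dim=-1)
    # Replaces torch.irfft(V, 1, onesided=False); irfft reads only the first N // 2 + 1 bins.
    v = torch.fft.irfft(torch.view_as_complex(V), n=N, dim=-1)
    # Undo the even/odd reordering done in dct_fft.
    x = v.new_zeros(v.shape)
    x[..., ::2] = v[..., : N - N // 2]
    x[..., 1::2] = v.flip([-1])[..., : N // 2]
    return x


for N in (7, 8):  # odd and even lengths; both checks should print True
    x = torch.randn(3, N, dtype=torch.float64)
    print(N, torch.allclose(dct_fft(x), dct_naive(x)), torch.allclose(idct_fft(dct_fft(x)), x))
```

float64 is used so that `torch.allclose` with its default tolerances compares the algorithms rather than float32 rounding error.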

With this two-line modification, the same code base provides the same functionality under PyTorch 1.7.

I'm still facing the same issue. Have the improvements been merged into the code yet?

@GeJulia Not yet, since I have not heard from @zh217 in a very long time.
You can still try the code from PR #24, although in my view it is still a work in progress; I left some minor suggestions there.

This can be closed now; it was fixed in PR #24.