jnp.fft.ifft imprecision for GPU
Chih-Kang Huang commented
Description
jnp.fft.ifft and jnp.fft.irfft seem to give significant errors when running on GPU:
```python
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
xs = jnp.linspace(0, 1, 100, endpoint=False)
f = lambda x: jnp.sin(4*jnp.pi*x)
f_tilde = jnp.fft.ifft(jnp.fft.fft(f(xs)))
print(abs(f_tilde - f(xs)).mean())
```
This prints a mean absolute error of 1.42108305e-08 with float64 enabled, and 9e-08 with float32.
On CPU, or with other frameworks (NumPy or PyTorch), the error is around 1e-15.
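For scale: float32 machine epsilon is about 1.2e-07, so an error of 9e-08 is on the order of float32 round-off. Float64 epsilon is about 2.2e-16, so 1.4e-08 is roughly 6e7 times larger than float64 round-off alone would suggest. The sketch below uses NumPy only: it computes the CPU reference error for the same signal and the two ratios. The GPU figures are hard-coded from this report rather than measured.

```python
import numpy as np

# Same test signal as in the reproduction: 100 samples of sin(4*pi*x) on [0, 1).
xs = np.linspace(0, 1, 100, endpoint=False)
signal = np.sin(4 * np.pi * xs)

# Reference round trip on the CPU with NumPy (complex128 arithmetic).
roundtrip = np.fft.ifft(np.fft.fft(signal))
err_numpy = np.abs(roundtrip - signal).mean()

# Machine epsilon for each precision.
eps32 = np.finfo(np.float32).eps  # 2**-23, about 1.19e-07
eps64 = np.finfo(np.float64).eps  # 2**-52, about 2.22e-16

# Errors reported in this issue for the GPU run (copied, not measured here).
gpu_err64 = 1.42108305e-08
gpu_err32 = 9e-08

print(f"NumPy float64 round-trip error: {err_numpy:.2e}")
print(f"GPU float32 error / float32 eps: {gpu_err32 / eps32:.2f}")
print(f"GPU float64 error / float64 eps: {gpu_err64 / eps64:.2e}")
```

A ratio below 1 for float32 is consistent with normal round-off. A ratio in the tens of millions for float64 suggests the GPU path is not delivering full double precision.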
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
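One way to narrow this down is to run the same round trip in one process for both `fft`/`ifft` and `rfft`/`irfft`, with the input committed to each available device. This is a minimal sketch assuming a JAX install with a CUDA backend. Backends that are not available are skipped, and the helper `roundtrip_errors` is a hypothetical name for illustration, not part of JAX.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types before any arrays are created, as in the reproduction.
jax.config.update("jax_enable_x64", True)


def roundtrip_errors(device):
    """Mean absolute round-trip errors with the input committed to `device`.

    Hypothetical helper for illustration; not part of JAX.
    """
    xs = jnp.linspace(0, 1, 100, endpoint=False)
    # Committing the input to `device` makes the FFTs below execute there.
    signal = jax.device_put(jnp.sin(4 * jnp.pi * xs), device)
    complex_rt = jnp.fft.ifft(jnp.fft.fft(signal))
    real_rt = jnp.fft.irfft(jnp.fft.rfft(signal), n=signal.shape[0])
    return {
        "ifft": float(jnp.abs(complex_rt - signal).mean()),
        "irfft": float(jnp.abs(real_rt - signal).mean()),
    }


errors = {}
for backend in ("cpu", "gpu"):
    try:
        devices = jax.devices(backend)
    except RuntimeError:
        continue  # this backend is not available on this machine
    for device in devices:
        key = (device.platform, device.id)
        errors[key] = roundtrip_errors(device)
        print(*key, errors[key])
```

Because the inputs are explicitly committed, the CPU and GPU numbers come from the same script and the same x64 setting. Any gap between them can then be attributed to the backend rather than to the environment.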