google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/


jnp.fft.ifft imprecision for GPU

chih-kang-huang opened this issue · comments

Description

jnp.fft.ifft and jnp.fft.irfft seem to give significant errors when using a GPU:

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
xs = jnp.linspace(0, 1, 100, endpoint=False)
f = lambda x: jnp.sin(4 * jnp.pi * x)

f_tilde = jnp.fft.ifft(jnp.fft.fft(f(xs)))
print(abs(f_tilde - f(xs)).mean())

gives 1.42108305e-08 with float64 enabled, and 9e-08 with float32.

Whereas on CPU, or with other frameworks (NumPy or PyTorch), the errors are around 1e-15.
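For reference, here is what the same round trip looks like in plain NumPy; this is a minimal sketch of the CPU baseline being compared against, not part of the original report, and the exact error magnitude will vary slightly by platform:

import numpy as np

xs = np.linspace(0, 1, 100, endpoint=False)
f = lambda x: np.sin(4 * np.pi * x)

# Round-trip the signal through FFT and inverse FFT in float64.
# On CPU this should reproduce the input to near machine epsilon.
f_tilde = np.fft.ifft(np.fft.fft(f(xs)))
err = np.abs(f_tilde - f(xs)).mean()
print(err)  # on the order of 1e-16 to 1e-15

The GPU numbers quoted above (~1e-8) are many orders of magnitude larger than this baseline, which is what makes the discrepancy look like a precision issue in the GPU FFT path rather than ordinary float64 rounding.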

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]