linalg.solve produces NaNs on GPU, but not on CPU
fkaab opened this issue
Description
This may be related to #20047 or lineax/#79.
When using fmmax to simulate a system with a vector RCWA formulation, jax.numpy.linalg.solve produces an array of NaNs on the GPU backend but not on the CPU backend, which is why I suspect the error lies with JAX. This is the error message produced with the NaN debugging flag enabled:
FloatingPointError: invalid value (nan) encountered in jit(triangular_solve). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`.
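The error above comes from JAX's NaN checker, which is switched on through the `jax_debug_nans` config option. A minimal sketch of enabling it, using a deliberately NaN-producing `jnp.log` call as a stand-in for the fmmax solve (the helper name `nan_check_triggers` is hypothetical):

```python
import jax
import jax.numpy as jnp


def nan_check_triggers() -> bool:
    """Return True if JAX's NaN checker raises for a jitted function that yields NaN.

    `nan_check_triggers` is a hypothetical helper; log of a negative number
    stands in for the NaN-producing solve from the report.
    """
    jax.config.update("jax_debug_nans", True)
    try:
        # log(-1.0) is NaN in real floating point, so the checker should fire.
        jax.jit(jnp.log)(-1.0)
    except FloatingPointError as err:
        # Print only the first line; the full message can be long.
        print("caught:", str(err).splitlines()[0])
        return True
    finally:
        # Restore the default so later computations run without the checker.
        jax.config.update("jax_debug_nans", False)
    return False


if __name__ == "__main__":
    print(nan_check_triggers())
```

Setting the environment variable `JAX_DEBUG_NANS=True` before starting Python should have the same effect as the `jax.config.update` call.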
In my use case, the error occurred here. The input arrays had shapes (490, 490) and (490,), and the dtype was complex64.
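A minimal sketch of how the CPU/GPU comparison could be set up outside fmmax, assuming a random complex64 system with the reported shapes stands in for the actual RCWA matrices (which are not included here). The helpers `make_system` and `compare_backends` are hypothetical; for each available backend the sketch reports whether the solution contains NaNs and its relative residual.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_system(n: int = 490, seed: int = 0):
    """Random complex64 system A x = b with the shapes reported in the issue.

    Hypothetical stand-in: only the shapes and dtype match the fmmax inputs.
    """
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
    b = rng.standard_normal(n) + 1j * rng.standard_normal(n)
    return a.astype(np.complex64), b.astype(np.complex64)


def compare_backends(a: np.ndarray, b: np.ndarray) -> dict:
    """Solve on every available backend; return {platform: (has_nan, rel_residual)}."""
    results = {}
    solve = jax.jit(jnp.linalg.solve)
    for platform in ("cpu", "gpu"):
        try:
            device = jax.devices(platform)[0]
        except (RuntimeError, ValueError):
            continue  # backend not installed or not visible on this machine
        # Committing both operands to `device` makes the jitted solve run there.
        x = np.asarray(solve(jax.device_put(a, device), jax.device_put(b, device)))
        # Residual computed on the host in double precision.
        r = a.astype(np.complex128) @ x.astype(np.complex128) - b
        rel_residual = float(np.linalg.norm(r) / np.linalg.norm(b))
        results[platform] = (bool(np.isnan(x).any()), rel_residual)
    return results


if __name__ == "__main__":
    for platform, (has_nan, res) in compare_backends(*make_system()).items():
        print(f"{platform}: nan={has_nan} relative residual={res:.2e}")
```

On a CPU-only install only the `cpu` entry appears. A random Gaussian matrix may well not reproduce the failure, since the fmmax matrices can be structured or poorly conditioned; the sketch only shows how to run the same solve on both backends and compare the results.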
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.27
jaxlib: 0.4.27
numpy: 1.26.4
python: 3.11.5 (main, Oct 25 2023, 16:19:59) [GCC 8.5.0 20210514 (Red Hat 8.5.0-20)]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', release='4.18.0-513.24.1.el8_9.x86_64', version='#1 SMP Thu Apr 4 18:13:02 UTC 2024', machine='x86_64')
$ nvidia-smi
Wed May 8 00:34:18 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100                    On  |   00000000:9D:00.0 Off |                    0 |
| N/A   36C    P0             82W /  700W |     534MiB /  95830MiB |       0%     Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    157152      C   python                                        524MiB |
+-----------------------------------------------------------------------------------------+
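The system-info block above appears to be in the format printed by `jax.print_environment_info()`, which appends the `$ nvidia-smi` output when that tool is available. A minimal sketch of capturing it as a string for pasting into a report (the wrapper name `environment_report` is hypothetical):

```python
import jax


def environment_report() -> str:
    """Collect the jax/jaxlib/numpy/python/device summary as a string.

    `environment_report` is a hypothetical wrapper; return_string=True makes
    jax.print_environment_info return the text instead of printing it.
    """
    return jax.print_environment_info(return_string=True)


if __name__ == "__main__":
    print(environment_report())
```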