google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

linalg.solve produces NaNs on GPU, but not on CPU

fkaab opened this issue

Description

This may be related to #20047 or lineax/#79.

When using fmmax to simulate a system with a vector RCWA formulation, jax.numpy.linalg.solve returns an array of NaNs on the GPU backend but not on the CPU backend, which is why I suspect the error lies with JAX. This is the error message produced with the NaN debugging flag enabled:

FloatingPointError: invalid value (nan) encountered in jit(triangular_solve). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 
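For readers unfamiliar with the NaN debugging flag mentioned above, the following is a minimal sketch of how it can be enabled. The function `log_of` and its input are hypothetical and chosen only because they reliably produce a NaN; they are not part of the fmmax code path.

```python
import jax
import jax.numpy as jnp

# Turn on NaN checking: JAX raises FloatingPointError when a computation
# produces a NaN instead of silently returning it.
jax.config.update("jax_debug_nans", True)

@jax.jit
def log_of(x):
    # Hypothetical example function: log of a negative real number is NaN.
    return jnp.log(x)

try:
    log_of(jnp.float32(-1.0))
    raised = False
except FloatingPointError as err:
    raised = True
    print("caught:", err)
finally:
    # Restore the default so later code is not affected.
    jax.config.update("jax_debug_nans", False)
```

With the flag set, the jitted call is re-run without `jit` to localize the NaN, which is the step the error message above says did not reproduce the invalid value.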

In my use case the error occurred here. The input arrays had shapes (490, 490) and (490,), and the dtype was complex64.
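A backend comparison along these lines might help isolate the problem. This is only a sketch: the matrix and right-hand side below are random, well-conditioned stand-ins with the reported shapes and dtype (an assumption, not the actual fmmax arrays), so they are not expected to trigger the NaNs by themselves. The helper `solve_on` is a hypothetical name. Substituting the real arrays for `a` and `b` would show whether the two backends disagree.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-in data (assumption): random complex64 system with the reported shapes.
rng = np.random.default_rng(0)
n = 490
a = (rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))).astype(np.complex64)
a += n * np.eye(n, dtype=np.complex64)  # diagonal shift keeps the matrix well conditioned
b = (rng.standard_normal(n) + 1j * rng.standard_normal(n)).astype(np.complex64)

def solve_on(device):
    """Run jnp.linalg.solve with both operands placed on the given device."""
    a_d = jax.device_put(a, device)
    b_d = jax.device_put(b, device)
    return np.asarray(jnp.linalg.solve(a_d, b_d))

cpu = jax.devices("cpu")[0]   # the CPU backend is available alongside the GPU
default = jax.devices()[0]    # the GPU when one is present, otherwise the CPU

x_cpu = solve_on(cpu)
x_default = solve_on(default)

print("default device:", default)
print("NaNs on CPU:", bool(np.isnan(x_cpu).any()))
print("NaNs on default device:", bool(np.isnan(x_default).any()))
print("max abs difference:", float(np.max(np.abs(x_cpu - x_default))))
```

On a machine without a GPU both solves run on the CPU, so the comparison is only informative when `jax.devices()` reports a CUDA device.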

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.27
jaxlib: 0.4.27
numpy:  1.26.4
python: 3.11.5 (main, Oct 25 2023, 16:19:59) [GCC 8.5.0 20210514 (Red Hat 8.5.0-20)]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', release='4.18.0-513.24.1.el8_9.x86_64', version='#1 SMP Thu Apr 4 18:13:02 UTC 2024', machine='x86_64')


$ nvidia-smi
Wed May  8 00:34:18 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100                    On  |   00000000:9D:00.0 Off |                    0 |
| N/A   36C    P0             82W /  700W |     534MiB /  95830MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    157152      C   python                                        524MiB |
+-----------------------------------------------------------------------------------------+