google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/


[CPU] `np.linalg.eigh` outputs NaNs for matrix side size >=32767

romanngg opened this issue · comments

jax-0.3.7 jaxlib-0.3.7 numpy-1.21.6

import jax.numpy as np
import jax

jax.config.update('jax_enable_x64', False)

n = 50_000
array = np.eye(n)


vals, vecs = np.linalg.eigh(array, 'U', symmetrize_input=False)

print(vals.shape, vecs.shape)
print(np.mean(vals), np.mean(vecs))

Outputs, for n in 1000, 10_000, 32_766, and 32_767:

romann@romann2 ~/P/pythonProject> python main.py       (base)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(1000,) (1000, 1000)
1.0 0.001
romann@romann2 ~/P/pythonProject> python main.py       (base)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(10000,) (10000, 10000)
1.0 1e-04
romann@romann2 ~/P/pythonProject> python main.py       (base)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(32766,) (32766, 32766)
1.0 3.051944e-05
romann@romann2 ~/P/pythonProject> python main.py       (base)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
 ** On entry to SSTEDC parameter number  8 had an illegal value
(32767,) (32767, 32767)
nan nan

The correct output for each n is:

(n,), (n, n)
1.0, 1 / n
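The expected output can be sanity-checked at a small n with plain NumPy (whose eigh calls the same LAPACK syevd driver); note that eigenvector signs are not uniquely determined, though for the identity matrix LAPACK returns the standard basis in practice:

```python
import numpy as np

# Verify the expected output at a size well below the failure threshold.
# For np.eye(n), every eigenvalue is 1 and the eigenvector matrix is (up
# to sign) the identity, so mean(vals) == 1 and mean(|vecs|) == 1 / n.
n = 1000
vals, vecs = np.linalg.eigh(np.eye(n))

print(vals.shape, vecs.shape)  # (1000,) (1000, 1000)
print(np.mean(vals))           # 1.0
print(np.mean(np.abs(vecs)))   # 0.001, i.e. 1 / n
```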

Interestingly, for n = 50_000, a different issue occurs, described in #10411:

romann@romann2 ~/P/pythonProject> python main.py       (base)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
 ** On entry to SORMTRULSPBCONSafe minimumNon-unitTransposeUpperNo transposeLower parameter number 12 had an illegal value
(50000,) (50000, 50000)
1.0 2e-05
free(): invalid pointer
fish: Job 1, 'python main.py' terminated by signal SIGABRT (Abort)

The outputs are printed correctly, but the program also emits new error messages and is terminated by signal SIGABRT (Abort). I think this is what causes the same computation to crash in Colab (#10411).

Related to: #4358.

I'm reasonably sure what's happening here is that JAX is passing parameters to LAPACK correctly, but we're using a LAPACK built with 32-bit integers. sstedc, which LAPACK (not JAX) calls internally, requires a workspace that is O(n**2). Naturally that overflows a 32-bit integer for n around 2**15.
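The boundary can be checked arithmetically. The LAPACK docs give the minimum float workspace for ssyevd (the driver that calls sstedc) with JOBZ='V' as 1 + 6*N + 2*N**2; assuming JAX requests roughly that much, the request crosses INT32_MAX exactly between n = 32766 and n = 32767:

```python
# Documented minimum LWORK for ssyevd with JOBZ='V' (sstedc's workspace is
# carved out of it). Assuming roughly this much is requested, the size
# overflows a 32-bit signed integer exactly between n = 32766 and
# n = 32767 -- matching the observed failure boundary.
INT32_MAX = 2**31 - 1

def syevd_min_lwork(n):
    return 1 + 6 * n + 2 * n * n

print(syevd_min_lwork(32766) <= INT32_MAX)  # True  -> works
print(syevd_min_lwork(32767) <= INT32_MAX)  # False -> LWORK wraps negative
```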

I think we probably need to switch to an ILP64-built LAPACK. However, that means we'll have to get LAPACK from somewhere other than SciPy's Cython exports.

Another option that might work for this specific function is to use a different LAPACK driver. I believe ssyevr needs a substantially smaller workspace than ssyevd, and it is what SciPy uses by default.
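The difference is asymptotic, not just constant-factor: per the LAPACK documentation, ssyevr's minimum float LWORK is max(1, 26*N), i.e. O(n) rather than ssyevd's O(n**2), so a 32-bit workspace size would not overflow until n is vastly larger:

```python
# ssyevr's documented minimum LWORK is max(1, 26*N) -- linear in n, unlike
# ssyevd's 1 + 6*N + 2*N**2. At the failing size n = 32767 it is tiny
# compared to INT32_MAX, and it only overflows for n beyond ~8.2e7.
INT32_MAX = 2**31 - 1

def ssyevr_min_lwork(n):
    return max(1, 26 * n)

print(ssyevr_min_lwork(32767))  # 851942, far below INT32_MAX
print(INT32_MAX // 26)          # largest n before ssyevr's LWORK overflows
```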

As a first step toward resolving this issue, I suggest catching 32-bit integer overflow before calling LAPACK functions, since on overflow they will most likely produce garbage or crash anyway. The overflow problem applies to all LAPACK functions, not just the ones used in linalg.eigh.
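A guard along these lines could reject a call whose workspace request no longer fits in a 32-bit int, raising a clear Python error instead of passing a wrapped-negative LWORK to Fortran. This is a hypothetical sketch, not JAX's actual binding code; the helper names are illustrative, and the formula shown is ssyevd's documented minimum:

```python
# Hypothetical sketch of an overflow guard in the Python wrapper layer.
INT32_MAX = 2**31 - 1

def check_lapack_workspace_fits_int32(n, min_lwork):
    """Raise before calling LAPACK if the workspace size would wrap."""
    if min_lwork > INT32_MAX:
        raise ValueError(
            f"Matrix dimension {n} needs a LAPACK workspace of {min_lwork} "
            f"elements, which overflows the 32-bit integers this LAPACK "
            f"build uses; consider an ILP64 LAPACK build.")

def eigh_guard(n):
    # ssyevd's documented minimum LWORK with JOBZ='V'.
    check_lapack_workspace_fits_int32(n, 1 + 6 * n + 2 * n * n)

eigh_guard(32766)       # fine
try:
    eigh_guard(32767)   # raises ValueError instead of returning NaNs
except ValueError as e:
    print(e)
```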

Switching to an ILP64-built LAPACK would be the next step. There are several options, each requiring some effort: for instance, using an ILP64-enabled SciPy (this likely requires some work on the SciPy side, which has ILP64 support but not in its releases), or depending on another LAPACK implementation that provides ILP64 support, such as Intel MKL.

Does scipy have plans to make an ILP64 release?

I'm not excited about adding our own build of LAPACK, Fortran toolchain and all, but another possibility is to scavenge a different LAPACK from the environment: e.g., if the user installs an ILP64 LAPACK through dpkg or something, we could dlopen that, and fall back to SciPy's if there's nothing else.
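A minimal sketch of that scavenging idea using ctypes; the candidate library names are guesses, and a real implementation would also have to verify that the loaded symbols actually use 64-bit integers:

```python
import ctypes
import ctypes.util

# Hypothetical sketch: try to dlopen an ILP64 LAPACK from the environment,
# falling back to None (i.e. "use SciPy's 32-bit LAPACK") if none is found.
def find_ilp64_lapack():
    for name in ("lapack64", "openblas64", "openblas64_"):
        path = ctypes.util.find_library(name)
        if path is not None:
            return ctypes.CDLL(path)
    return None

lib = find_ilp64_lapack()
print("found ILP64 LAPACK" if lib else "falling back to 32-bit LAPACK")
```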

There is https://pypi.org/project/scipy-openblas64/, which provides LAPACK in ILP64 mode.