google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

[CPU] `np.linalg.eigh` crash on a 50k x 50k matrix

romanngg opened this issue

On a CPU with plenty of RAM (> 150 GB), using jax-0.3.7, jaxlib-0.3.7 and numpy-1.21.6, I can't perform an eigendecomposition of a 50k x 50k matrix.

The code below

import jax.numpy as np
import jax

jax.config.update('jax_enable_x64', False)

n = 50_000
array = np.diag(jax.random.normal(jax.random.PRNGKey(1), (n,), np.float32))


vals, vecs = np.linalg.eigh(array, 'U', symmetrize_input=False)

print(vals.shape, vecs.shape)
print(np.mean(vals), np.mean(vecs))

prints the output, but also prints the error message below, and then aborts:

romann@romann2 ~/P/pythonProject> python main.py               (base)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
 ** On entry to SORMTRULSPBCONSafe minimumNon-unitTransposeUpperNo transposeLower parameter number 12 had an illegal value
(50000,) (50000, 50000)
-0.0050003906 2e-05
free(): invalid pointer
fish: Job 1, 'python main.py' terminated by signal SIGABRT (Abort)

In a Colab, when running the same code, the runtime crashes for an unknown reason. Logs:
app.log

I have verified that the same computation on a 32_767 x 32_767 matrix works: https://colab.research.google.com/gist/icml2022anon/7137667c6ded0c0043ba2391c7ee9b05/32767_ok.ipynb
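Because the test matrix is diagonal, its eigenvalues are just the diagonal entries, and `eigh` returns them in ascending order. That makes it possible to check the result for correctness rather than only for the absence of a crash. The following is a minimal sketch of such a check; the helper name `check_diag_eigh`, the default size and the tolerance are my own choices for illustration, not something taken from the repro above.

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', False)


def check_diag_eigh(n, seed=1, atol=1e-5):
    """Hypothetical helper: does eigh recover the sorted diagonal?

    Builds the same kind of random diagonal float32 matrix as the repro
    and compares the computed eigenvalues with the sorted diagonal.
    """
    diag = jax.random.normal(jax.random.PRNGKey(seed), (n,), jnp.float32)
    vals, vecs = jnp.linalg.eigh(jnp.diag(diag), 'U', symmetrize_input=False)
    # For a diagonal matrix the eigenvalues are the diagonal entries,
    # returned by eigh in ascending order.
    return bool(jnp.allclose(vals, jnp.sort(diag), atol=atol))


# Small size so it runs quickly; raise n to probe larger matrices.
print(check_diag_eigh(256))
```

Running the same helper with larger `n` should show whether a given size returns correct values, at the cost of the memory and time a dense n x n decomposition needs.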

Potentially related to #4358, although that particular bug is indeed fixed. The errors above occur for n = 50_000; n = 32_767 does not crash, but may still compute wrong outputs (see #10420).
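One possible explanation, which nothing in this report confirms, is a 32-bit integer limit. The garbled message appears to name the LAPACK routine SORMTR, whose 12th argument in the reference signature is LWORK, the workspace size. The sketch below assumes that the float32 CPU path goes through SSYEVD (documented minimum workspace `1 + 6*N + 2*N**2` when eigenvectors are requested) and that LAPACK integers are 32-bit. Both are assumptions on my part. It only does the arithmetic, without calling JAX.

```python
INT32_MAX = 2**31 - 1  # largest value of a 32-bit signed LAPACK integer


def ssyevd_min_lwork(n):
    """Documented minimum LWORK for SSYEVD with JOBZ='V' and N > 1."""
    return 1 + 6 * n + 2 * n * n


for n in (32_766, 32_767, 50_000):
    elements = n * n
    lwork = ssyevd_min_lwork(n)
    print(
        f"n={n}: n*n={elements} (exceeds int32: {elements > INT32_MAX}), "
        f"min lwork={lwork} (exceeds int32: {lwork > INT32_MAX})"
    )
```

Under these assumptions, n = 50_000 exceeds the 32-bit range both for the element count and for the workspace size, while n = 32_767 exceeds it only for the workspace size. Whether that is connected to the behaviour reported here or in #10420 is a guess.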