[CPU] `np.linalg.eigh` crash on a 50k x 50k matrix
romanngg opened this issue
On a CPU machine with plenty of RAM (> 150 GB), using jax-0.3.7, jaxlib-0.3.7 and numpy-1.21.6, I can't perform an eigendecomposition of a 50k x 50k matrix.
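The following script:
```python
import jax.numpy as np
import jax

jax.config.update('jax_enable_x64', False)  # stay in float32

n = 50_000
# Diagonal (hence symmetric) n x n matrix with random float32 entries.
array = np.diag(jax.random.normal(jax.random.PRNGKey(1), (n,), np.float32))
# Read only the upper triangle; the input is already symmetric, so skip symmetrization.
vals, vecs = np.linalg.eigh(array, 'U', symmetrize_input=False)
print(vals.shape, vecs.shape)
print(np.mean(vals), np.mean(vecs))
```
prints its output, but also emits the error message below and then aborts: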
```
romann@romann2 ~/P/pythonProject> python main.py (base)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
** On entry to SORMTRULSPBCONSafe minimumNon-unitTransposeUpperNo transposeLower parameter number 12 had an illegal value
(50000,) (50000, 50000)
-0.0050003906 2e-05
free(): invalid pointer
fish: Job 1, 'python main.py' terminated by signal SIGABRT (Abort)
```
In Colab, running the same code crashes the runtime for an unknown reason. Logs:
app.log
I have verified that the same computation on a 32_767 x 32_767 matrix works: https://colab.research.google.com/gist/icml2022anon/7137667c6ded0c0043ba2391c7ee9b05/32767_ok.ipynb
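One plausible explanation for why 32_767 works but 50_000 fails, an assumption not confirmed anywhere in this issue, is a 32-bit integer overflow in a size handed to LAPACK, whose integer arguments are 32-bit in common (LP64) builds. A 32_767 x 32_767 matrix has about 1.07e9 elements, and even twice that stays below 2^31 - 1, while a 50_000 x 50_000 matrix has 2.5e9 elements, which does not fit. The sketch below only does that arithmetic: `as_int32` is a hypothetical helper that emulates signed 32-bit wraparound, and `2*n*n` is an assumed order of magnitude for an eigenvector workspace, not a value read from jaxlib.

```python
# Emulate signed 32-bit overflow to see which matrix sizes fit in an int32.
INT32_MAX = 2**31 - 1  # 2_147_483_647, the largest value a signed 32-bit int can hold


def as_int32(x: int) -> int:
    """Hypothetical helper: wrap a Python int the way a signed 32-bit int would."""
    return (x + 2**31) % 2**32 - 2**31


for n in (32_767, 50_000):
    # n*n is the element count; 2*n*n is an assumed order of magnitude for a workspace.
    for label, size in (("n*n", n * n), ("2*n*n", 2 * n * n)):
        print(f"n={n}: {label} = {size:,} fits_int32={size <= INT32_MAX} "
              f"as_int32={as_int32(size):,}")
```

For n = 50_000, `n*n` wraps to -1,794,967,296 and `2*n*n` wraps to 705,032,704, a positive but far too small number. Either kind of corrupted size could plausibly be rejected as an illegal argument or lead to out-of-bounds memory access, which would be consistent with the `parameter number 12 had an illegal value` message and the `free(): invalid pointer` abort.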
This is potentially related to #4358, although that particular bug is indeed fixed. The errors above occur for n = 50_000; there is no crash for n = 32_767, although wrong outputs may be computed there (see #10420).
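Because the repro uses a diagonal matrix, its exact eigenvalues are known (the diagonal entries, sorted ascending), which gives a cheap way to tell whether a run that does not crash, such as n = 32_767, still returns wrong results. A minimal sketch, assuming NumPy is available to post-process the outputs; `check_diag_eigh` and its `atol` threshold are hypothetical and not part of JAX. NumPy is imported as `onp` here because the repro binds `np` to `jax.numpy`. Every check is elementwise, so it needs O(n²) memory but no n x n matrix product.

```python
import numpy as onp  # plain NumPy; onp.asarray also accepts jax arrays


def check_diag_eigh(d, vals, vecs, atol=1e-4):
    """Hypothetical helper: sanity-check eigh output for the diagonal matrix A = diag(d)."""
    d, vals, vecs = onp.asarray(d), onp.asarray(vals), onp.asarray(vecs)
    # The eigenvalues of a diagonal matrix are its diagonal entries; eigh returns them ascending.
    val_err = float(onp.max(onp.abs(vals - onp.sort(d))))
    # Residual of A V = V diag(vals), computed via broadcasting instead of A @ V.
    resid = float(onp.max(onp.abs(d[:, None] * vecs - vecs * vals[None, :])))
    # Each eigenvector (column of V) should have unit norm.
    norm_err = float(onp.max(onp.abs(onp.linalg.norm(vecs, axis=0) - 1.0)))
    return {"val_err": val_err, "resid": resid, "norm_err": norm_err,
            "ok": max(val_err, resid, norm_err) <= atol}


# Small-scale demo using NumPy's own eigh; for the issue, pass the repro's diagonal
# (the jax.random.normal draw) and the vals, vecs returned by jax instead.
rng = onp.random.default_rng(1)
d = rng.standard_normal(200).astype(onp.float32)
vals, vecs = onp.linalg.eigh(onp.diag(d))
print(check_diag_eigh(d, vals, vecs))
```

The residual and norm checks together catch eigenvectors that are zero, mis-scaled, or paired with the wrong eigenvalue, which the printed means alone cannot reveal.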