google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/


jax.scipy.sparse.linalg.cg inconsistent results between runs

Dinple opened this issue · comments

Hi all,

The conjugate gradient function inside jax.scipy.sparse seems to be very inconsistent when running JAX on GPU. I'm new to JAX, so I'm not sure if this issue has been addressed somewhere; I believe it is somewhat related to #565 and #9784.

To get the full picture, I have saved both the input A and b so that every run uses exactly the same data. No preconditioning is applied.

I tested on three platforms: CPU [colab + local], GPU [colab + local], and TPU [colab].

Across all the runs I have done, the three platforms produce different results from each other, but only GPU is inconsistent between runs.

  • On my local machine, JAX on CPU produces exactly the same result as colab CPU, and it is CONSISTENT between runs.
  • On colab, JAX on TPU is also CONSISTENT between runs.
  • On GPU, both colab and my local machine show large INCONSISTENCY between runs, sometimes even outputting a nan matrix.

I have seen people mention issues related to the CUDA version, so I tried CUDA 11.1, 11.2, and 11.4; they all show the same problem.
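
Concretely, the repro loop is roughly the following (a sketch; the file names and the use of a scipy sparse input here are illustrative, the exact code is in the colab linked further down):

import numpy as np
import scipy.sparse
from jax.experimental import sparse as jsparse
from jax.scipy.sparse.linalg import cg

# Load the saved inputs so every run sees exactly the same A and b.
A_sp = scipy.sparse.load_npz("A.npz")      # illustrative file name
b = np.load("b.npy")                       # illustrative file name
A = jsparse.BCOO.from_scipy_sparse(A_sp)

# Run cg several times on identical inputs and compare the solutions.
for _ in range(3):
    x, _ = cg(lambda v: A @ v, b)
    print(x[:10])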

To show how much the output changes, here is what three different runs produce:
DeviceArray([ 9.28246680e+03, 1.50545068e+04, 1.90608145e+04, 2.23634746e+04, 2.50702012e+04, 2.76033926e+04, 2.99257559e+04, 3.21613457e+04, 3.42872852e+04,...

DeviceArray([-8.13425984e-03, -1.17020588e-02, -1.27483038e-02, -1.18785836e-02, -9.67487786e-03, -6.41405629e-03, -2.11878261e-03, 3.24898120e-03, 9.95288976e-03,...

DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,...

I am using
jax 0.3.4
jaxlib 0.3.2+cuda11.cudnn82
scipy 1.8.0
numpy 1.22.3

FYI: here is a minimal example: https://colab.research.google.com/drive/1Z802HuUZ_TCTRxeQNRvC6XJppEiGAiyQ?zusp=sharing

For nearly singular matrices (e.g. conditioned graph Laplacians), cg returns different solutions on GPU between runs (sometimes valid, sometimes nan) when x0 is not set. Is this expected? GMRES seems a bit more stable.
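
Roughly what I am comparing (a sketch; A and b here are the saved inputs from the repro):

from jax.scipy.sparse.linalg import cg, gmres

matvec = lambda v: A @ v      # A is the nearly singular matrix, b the right-hand side

x_cg, _ = cg(matvec, b)       # x0 left unset: different results (and occasional nans) between GPU runs
x_gm, _ = gmres(matvec, b)    # gmres on the same inputs has been a bit more stable for me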

cond = np.linalg.cond(A.todense())
print(cond)
1.1747166e+17

It seems the condition number is very large. Is this what you expected?

Yes, so that's one thing. It's mainly an issue for poorly conditioned matrices; I guess typical, well-posed use cases won't suffer from this problem.

But the main thing is the inconsistency we see between cg on GPU and cg on CPU. Is there any reason that running cg on different devices leads to different results despite identical input?

I guess my question is less about finding a solution (although that would be great) and more about verifying that this is expected behavior (and not a sign of a deeper problem).

Looks like the input is symmetric positive definite. I have checked
np.all(np.linalg.eigvals(A.todense()) > 0) and
np.allclose(A_dense, A_dense.T.conj(), rtol=rtol, atol=atol)
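
For completeness, a self-contained version of those checks (a sketch; the rtol/atol values shown are illustrative):

import numpy as np

A_dense = np.asarray(A.todense())

# All eigenvalues strictly positive -> positive definite.
print(np.all(np.linalg.eigvals(A_dense) > 0))

# Symmetric (Hermitian) up to the given tolerances.
print(np.allclose(A_dense, A_dense.T.conj(), rtol=1e-5, atol=1e-8))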

Have you tried using float64?
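
JAX defaults to float32, so double precision has to be enabled explicitly, e.g.:

import jax
jax.config.update("jax_enable_x64", True)   # enable 64-bit types; do this before creating arrays

import jax.numpy as jnp
print(jnp.ones(3).dtype)                    # float64 once x64 is enabled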

The nondeterministic behavior may come from multithreading.

Was this issue resolved? @Dinple

I reran the colab repro.

On Colab GPU, I get values that are consistently large and broadly similar between runs (with the occasional nan or inf):

[0.0000000e+00 9.8054398e+10 1.9872968e+11 6.2595138e+10 2.0462353e+11
 2.4414941e+10 1.1615307e+11 7.3973126e+12 2.6439701e+13 4.8272494e+11]

[0.0000000e+00 2.8103875e+13 4.0418297e+13 4.0165729e+13 4.6595383e+13
 4.9564275e+13 4.1783241e+13 9.9750607e+13 3.7144668e+13 4.0531598e+13]
 
[0.0000000e+00 2.3296439e+11 9.7694744e+13 4.8008027e+11 1.4699356e+11
 7.5435302e+12 9.5971885e+10 9.2001908e+11 2.6396608e+11 9.4231685e+11]

[0.0000000e+00 3.2017095e+11 2.1751036e+13 4.9183644e+11 2.9387155e+11
 2.1419085e+12 3.1061000e+11 1.8692791e+11 1.2441883e+13 4.7176105e+12]

[0.0000000e+00 5.7433942e+11 3.3398978e+12 3.7421702e+11 2.7343428e+12
 1.1824465e+12 5.5289119e+11 2.3275011e+13 8.5892014e+10 3.8059177e+12]

[0.0000000e+00 4.9893282e+10 1.5721413e+11 7.4679969e+13 3.8181389e+12
 1.5548703e+10 1.5535525e+12 2.3955174e+11 1.4923589e+13           nan]

[0.0000000e+00 6.1637303e+11 5.2090693e+11 6.1480265e+11 6.0722269e+11
           inf 2.5805787e+13 4.8744081e+11 6.9813535e+11 3.0362653e+13]

[0.0000000e+00 2.8956185e+13 8.2401290e+11 7.0151563e+11 9.3606137e+11
 4.9036857e+11 2.1517848e+12 1.5174398e+12 6.1058523e+12 7.8914683e+11]

[0.0000000e+00 2.8998506e+12 3.1845496e+11 3.1819612e+11           inf
 1.7479309e+13 1.3300310e+11 2.3351959e+13 7.2910045e+10 1.4908412e+13]

[0.0000000e+00 6.4928147e+11 8.0207806e+11 1.4854672e+12 7.3838559e+11
 1.3758692e+12 8.9137250e+11 1.3934976e+12 9.1503893e+11 2.3913224e+12]

On Colab TPU, I get consistently all zeros:

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

After some debugging, I found out that cuSPARSE isn't being called.
The from_scipy_sparse function sets indices_sorted=False, and _bcoo_dot_general_gpu_lowering then falls back to _bcoo_dot_general_default_lowering based on that indices_sorted value, so the issue is in _bcoo_dot_general_impl.
Sorting the indices (by adding A = A.sort_indices()) makes the code call cuSPARSE, and then the results are consistent.
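
A sketch of the workaround (A_scipy stands for whatever scipy sparse matrix the BCOO was built from):

from jax.experimental import sparse as jsparse
from jax.scipy.sparse.linalg import cg

A = jsparse.BCOO.from_scipy_sparse(A_scipy)   # indices_sorted is False at this point
A = A.sort_indices()                          # sorted indices let the GPU lowering hit the cuSPARSE path
print(A.indices_sorted)                       # True

x, _ = cg(lambda v: A @ v, b)                 # results are now consistent between runs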

Thanks for looking into this! We recently enabled the lowering of BCOO dot_general to cuSPARSE (#12138). Yes, indices_sorted=True is one of the requirements for using cuSPARSE.

@tlu7
Just another remark: currently JAX uses CUSPARSE_MV_ALG_DEFAULT as the algorithm parameter for cusparseSpMV in https://github.com/google/jax/blob/main/jaxlib/cuda/cusparse_kernels.cc and https://github.com/google/jax/blob/main/jaxlib/cuda/cusparse.cc, which is deprecated and may produce non-deterministic results in general. I would suggest using CUSPARSE_SPMV_COO_ALG2 instead; according to the docs (https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-function-spmv):

Provides deterministic (bit-wise) results for each run. If opA != CUSPARSE_OPERATION_NON_TRANSPOSE, it is identical to CUSPARSE_SPMV_COO_ALG1

Thanks for the suggestions! Can you also share insights on the cuSPARSE matmat (SpMM) algorithms? Which one should we use as the default for JAX? @marsaev

https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-function-spmm

@jakevdp do you have a suggestion for what may be wrong in _bcoo_dot_general_impl? Is it the same algorithm that is used for TPUs/CPUs?

@tlu7

  1. If determinism (reproducibility) is needed (which I assume is true for JAX), then there is only one option: for SpMV use CUSPARSE_SPMV_COO_ALG2 or CUSPARSE_SPMV_CSR_ALG2, and for SpMM use CUSPARSE_SPMM_COO_ALG2 or CUSPARSE_SPMM_CSR_ALG3, according to the matrix format.
  2. If there is no such requirement, unfortunately there are no heuristics available, only the guidance in the docs for those functions, i.e. https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-function-spmm . There are also studies like https://arxiv.org/pdf/2202.08556.pdf that show other users' experience with the different algorithm values.

@marsaev @fbusato

Thank you. A follow-up question about the CUDA versions required for these new algorithms.

I found this in the release notes

2.5.12. cuSPARSE: Release 11.2 Update 1 (https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cusparse-11.2.1)
...
New algorithms for CSR/COO Sparse Matrix - Vector Multiplication (cusparseSpMV) with better performance.
....
New algorithm (CUSPARSE_SPMM_CSR_ALG3) for Sparse Matrix - Matrix Multiplication
...

Shall I assume that all these algorithms are available as of CUDA 11.2 and onwards? Is there any document where I can find this information? I need the version information to make sure the JAX routines are backward compatible.

@marsaev @fbusato

Compared to the default algorithms for SpMV and SpMM, do the four algorithms that provide determinism have accuracy trade-offs? I broke a few accuracy tests when switching from the default algorithms to the four aforementioned ones.

Hi @tlu7,

Shall I assume that all these algorithms are available as of CUDA 11.2 and onwards? Is there any document where I can find this information? I need the version information to make sure the JAX routines are backward compatible.

Yes, these algorithm enumerators are compatible with CUDA 11.x and 12.x. They could change in CUDA 13.x.

Compared to the default algorithms for SpMV and SpMM, do the four algorithms that provide determinism have accuracy trade-offs? I broke a few accuracy tests when switching from the default algorithms to the four aforementioned ones.

No, we cannot say that one algorithm is more accurate than another.

Thanks @fbusato !

Can you share more information on which 11.x versions the four algorithms became available in?

It seems CUSPARSE_SPMM_CSR_ALG3 has been available since 11.2, but the availability of the other three is unclear from the release notes.

2.5.12. cuSPARSE: Release 11.2 Update 1
...
New algorithms for CSR/COO Sparse Matrix - Vector Multiplication (cusparseSpMV) with better performance.
....
New algorithm (CUSPARSE_SPMM_CSR_ALG3) for Sparse Matrix - Matrix Multiplication
...

There is a small trick you can use to check old toolkit documentation 😀 https://developer.nvidia.com/cuda-toolkit-archive
CUSPARSE_SPMM_CSR_ALG3 and the new SpMV algorithms were introduced in CUDA 11.2u1: https://docs.nvidia.com/cuda/archive/11.2.1/cusparse/index.html#cusparse-generic-function-spmm

Thanks! It works like a charm. Needs some patience though :)