google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

jax.lax.linalg.eigh on GPU and multi-core CPU doesn't parallelize appropriately.

YouJiacheng opened this issue:

import jax
import jax.numpy as jnp

def timer(f):
    from time import time
    f() # warmup and compile
    t = time()
    for _ in range(3):
        f()
    print((time() - t) / 3)

# batch of 16 symmetric positive-definite 1024x1024 matrices
y = jax.random.uniform(jax.random.PRNGKey(0), (16, 1024, 1024)) / 16
s = jax.block_until_ready(y @ y.transpose(0, 2, 1) + jnp.eye(1024))

from jax.lax.linalg import eigh as jeigh
f = jax.jit(jax.vmap(jeigh))
timer(lambda: jax.block_until_ready(f(s))) # 0.90s for 16 problems

from scipy.linalg import eigh as seigh
import numpy as np
ss = np.array(s[0])
timer(lambda: seigh(ss)) # 0.21s for 1 problem

GPU: V100-PCIE 16G
CPU: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz

jax.lax.linalg.eigh on 1 GPU takes 0.90s for 16 problems; on all CPU cores (top reports 2340% peak CPU usage) it takes 2.44s for 16 problems.
scipy.linalg.eigh on 1 CPU core (top reports 200% peak CPU usage) takes 0.21s for 1 problem.
This means the GPU has less than 4x the throughput of a single CPU core, and more than 11x the CPU usage yields less than 1.4x the throughput, even though the problem should be embarrassingly parallel given vmap.

Ultimately JAX is at the mercy of the algorithms provided by cusolver here. For small matrices (smaller than 32x32), JAX currently uses the batched Jacobi solver that Nvidia provides. For larger matrices, JAX currently iterates over the batch elements sequentially, so you should expect no speedup from vmap.

There are a number of things one could try here.

One would be to try the batched Jacobi solver at larger sizes (the "if n <= 32:" threshold in jaxlib's cusolver.py, and its HLO-only cousin a few lines above); see also https://docs.nvidia.com/cuda/cusolver/index.html#cuSolverDN-lt-t-gt-syevjbatch
Note this code is in jaxlib, although it's in the Python part of jaxlib so you can just locally edit your copy to play with it.
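
To find the file to experiment with, something like the following works; note that the module path jaxlib.cusolver is an assumption based on the file name discussed here, and newer jaxlib releases may have moved or renamed it.

import importlib

# Locate the installed copy of jaxlib's cusolver.py (module path is an assumption,
# based on the file name mentioned above; adjust if your jaxlib version differs).
mod = importlib.import_module("jaxlib.cusolver")
print(mod.__file__)  # open this file and raise the "if n <= 32:" threshold to experiment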

Another would be for jaxlib to solve multiple eigendecomposition problems in parallel on multiple CUDA streams. That would only be profitable if you aren't already fully occupying the GPU and CPU.

Thanks for the speedy reply! IIUC, I can change the jaxlib Python code without building jaxlib myself, and make jaxlib use the batched Jacobi solver for large matrices as well.

Yes, you could just edit the (installed) copy of cusolver.py to alter the threshold. Does it help?

It helps! 0.90s -> 0.55s. Thank you so much! (But it is still much slower than I expected.)
I also wonder why the CPU version of jax.lax.linalg.eigh + vmap doesn't get a linear speedup over single-core scipy, even though it shows >11x peak CPU usage.

You could send a PR altering the threshold, if you like, although we'd probably need to collect a wider range of timings at different sizes and batch sizes.
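
A rough sweep one could use to gather such timings is sketched below; it reuses the construction from the repro above, and the particular sizes and batch counts are arbitrary choices, not ones anyone has asked for.

import time
import jax
import jax.numpy as jnp
from jax.lax.linalg import eigh

def bench(batch, n, repeats=3):
    # random symmetric positive-definite batch, same construction as the repro above
    y = jax.random.uniform(jax.random.PRNGKey(0), (batch, n, n)) / n
    s = jax.block_until_ready(y @ y.transpose(0, 2, 1) + jnp.eye(n))
    f = jax.jit(jax.vmap(eigh))
    jax.block_until_ready(f(s))  # warmup / compile
    t = time.time()
    for _ in range(repeats):
        jax.block_until_ready(f(s))
    return (time.time() - t) / repeats

for n in (32, 64, 128, 256, 512, 1024):
    for batch in (1, 4, 16, 64):
        print(f"n={n:4d} batch={batch:3d}: {bench(batch, n):.3f}s")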

The CPU version also just calls a LAPACK function in a loop to handle batches. In fact, the LAPACK function we use is the one provided by scipy, so I'd be surprised if you saw any speedup over scipy at all. That said, the algorithm does use parallelism internally, at least for some of the phases. If we aren't getting enough parallelism, we could consider using multiple threads.

We don't have a batched eigh on CPU (as far as I am aware, no-one does on CPU, although some of the algorithms that work well when vectorized on GPU and TPU might work well on CPU also particularly for small matrix sizes, e.g., a vectorized Jacobi solver).
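
A quick way to check that claim is to time the vmapped JAX eigh on CPU against a plain Python loop over scipy.linalg.eigh on the same batch; this is just a sketch of the two user-facing code paths (run it with JAX_PLATFORM_NAME=cpu, or on a CPU-only jaxlib, so JAX takes the CPU path).

from time import time
import numpy as np
from scipy.linalg import eigh as seigh
import jax
import jax.numpy as jnp
from jax.lax.linalg import eigh as jeigh

def timeit(f, repeats=3):
    f()  # warmup (and compile, in the JAX case)
    t = time()
    for _ in range(repeats):
        f()
    return (time() - t) / repeats

y = jax.random.uniform(jax.random.PRNGKey(0), (16, 1024, 1024)) / 16
s = jax.block_until_ready(y @ y.transpose(0, 2, 1) + jnp.eye(1024))
s_np = np.asarray(s)

f = jax.jit(jax.vmap(jeigh))
print("jax  :", timeit(lambda: jax.block_until_ready(f(s))))      # batched JAX on CPU
print("scipy:", timeit(lambda: [seigh(s_np[i]) for i in range(16)]))  # scipy, one problem at a time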

In fact, the LAPACK function we use is the one provided by scipy, so I'd be surprised if you saw any speedup over scipy at all. That said, the algorithm does use parallelism internally, at least for some of the phases.

JAX will use multiple cores, while scipy only uses a single core. But JAX with multiple cores only gets a small speedup, at the cost of preventing the user from manually using SPMD/data parallelism.

Can we have PyTorch-like set_num_threads and set_num_interop_threads to control the parallelism?

import torch

torch.set_num_threads(1)           # one intra-op thread per eigh
torch.set_num_interop_threads(24)  # up to 24 forked tasks run concurrently

@torch.jit.script
def mt_eigh(x: torch.Tensor):
    # fork one batched eigh per chunk of the input, then wait on all of them
    futs = [torch.jit._fork(torch.linalg.eigh, x[i]) for i in range(24)]
    return [torch.jit._wait(fut) for fut in futs]

I find that this (7.5s for 24*1024 matrices of size 320*320) is 50x faster than JAX on a 24-core CPU (15.6s for 1024 matrices of size 320*320) and 40x faster than naively letting PyTorch use intra-op parallelism with 24 threads (12.4s for 1024 matrices), which is itself 1.8x slower than a single thread (6.9s) and 2.3x slower than 4 threads (5.4s).
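
For reference, one way to approximate this fork/wait pattern in JAX on CPU today is to ask XLA for several host "devices" and pmap the batched eigh over them. This is only a sketch of a workaround, not an official threading control: the --xla_force_host_platform_device_count flag is normally used for emulating multi-device setups, it must be set before importing jax, and the batch per device below is deliberately smaller than in the PyTorch example to keep memory modest.

# Rough sketch: data-parallel CPU eigh via multiple XLA host devices, standing in
# for a fork/wait-style inter-op thread pool. Not an official parallelism control;
# the flag below must be set before jax is imported.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=24"

import jax
import jax.numpy as jnp
from jax.lax.linalg import eigh

n_dev = jax.local_device_count()  # 24 with the flag above
# one chunk of 64 symmetric positive-definite 320x320 matrices per "device"
y = jax.random.uniform(jax.random.PRNGKey(0), (n_dev, 64, 320, 320)) / 320
s = y @ y.transpose(0, 1, 3, 2) + jnp.eye(320)

f = jax.pmap(jax.vmap(eigh))   # each device processes its chunk independently
out = jax.block_until_ready(f(s))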