Is lineax slower than the linear solver in JAX?
ToshiyukiBandai opened this issue · comments
Hi, thank you for creating the awesome libraries in JAX. I recently started using lineax and compared it with the linear solver in JAX. The code below took 931 µs with lineax and 171 µs with jnp.linalg.solve. Is there anything wrong with my implementation, or should I just stick to jnp.linalg.solve? Is there no way to use the _gesv Fortran routine through lineax?
from jax import random
import jax.numpy as jnp
import lineax as lx
matrix_key, vector_key = random.split(random.PRNGKey(0))
matrix = random.normal(matrix_key, (10, 10))
vector = random.normal(vector_key, (10,))
operator = lx.MatrixLinearOperator(matrix)
solution = lx.linear_solve(operator, vector)
%timeit lx.linear_solve(operator, vector, solver=lx.LU())
%timeit jnp.linalg.solve(matrix, vector)
Looks like the overhead is from two things:
- Error-checking on the Lineax output. By default Lineax has an extra check that the return doesn't have NaNs etc., i.e. that the solve was successful. This can be disabled by passing linear_solve(..., throw=False).
- Pytree flattening/unflattening across JIT boundaries. matrix and vector are simpler PyTrees than operator and lx.LU().
With this benchmark I obtain identical performance:
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import timeit
matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 10))
vector = jr.normal(vector_key, (10,))
@jax.jit
def solve_lineax(matrix, vector):
    operator = lx.MatrixLinearOperator(matrix)
    sol = lx.linear_solve(operator, vector, throw=False)
    return sol.value
@jax.jit
def solve_jax(matrix, vector):
    return jnp.linalg.solve(matrix, vector)
time_lineax = lambda: jax.block_until_ready(solve_lineax(matrix, vector))
time_jax = lambda: jax.block_until_ready(solve_jax(matrix, vector))
print(min(timeit.repeat(time_jax, number=1, repeat=10)))
print(min(timeit.repeat(time_lineax, number=1, repeat=10)))
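The benchmark above relies on min() to discard the first call, which includes tracing and compilation. As a rough sketch of the same idea with the warm-up made explicit, a small helper could look like the following. The benchmark function name and its return format are hypothetical, chosen just for illustration, and the identity-like test matrix is an arbitrary stand-in.

```python
import timeit

import jax
import jax.numpy as jnp


def benchmark(fn, *args, repeat=10):
    """Return (first_call_seconds, best_later_seconds) for a jitted function."""
    # block_until_ready waits for the asynchronous computation to finish,
    # so the timer measures the actual work and not just the dispatch.
    run = lambda: jax.block_until_ready(fn(*args))
    first = timeit.timeit(run, number=1)  # includes tracing + compilation
    best = min(timeit.repeat(run, number=1, repeat=repeat))  # already compiled
    return first, best


@jax.jit
def solve_jax(matrix, vector):
    return jnp.linalg.solve(matrix, vector)


# Arbitrary well-posed test problem: 2 * I * x = 1.
matrix = 2.0 * jnp.eye(10)
vector = jnp.ones(10)

first, best = benchmark(solve_jax, matrix, vector)
print(f"first call: {first:.6f} s, best of 10 later calls: {best:.6f} s")
```

The same helper can be passed solve_lineax in place of solve_jax to compare the two on equal footing.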
Hi Patrick,
I got the same results too. Thank you!