patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax

Is lineax slower than the linear solver in JAX?

ToshiyukiBandai opened this issue.

Hi, thank you for creating these awesome libraries in JAX. I started using lineax recently and compared it with the linear solver in JAX. With the code below, lineax took 931 µs per solve and jnp.linalg.solve took 171 µs. Is there anything wrong with my implementation, or should I just stick to jnp.linalg.solve? Is there no way to use the _gesv Fortran routine through lineax?

from jax import random
import jax.numpy as jnp
import lineax as lx

matrix_key, vector_key = random.split(random.PRNGKey(0))
matrix = random.normal(matrix_key, (10, 10))
vector = random.normal(vector_key, (10,))

operator = lx.MatrixLinearOperator(matrix)
solution = lx.linear_solve(operator, vector)

%timeit lx.linear_solve(operator, vector, solver=lx.LU())

%timeit jnp.linalg.solve(matrix, vector)

It looks like the overhead comes from two things:

  1. Error-checking on the Lineax output. By default, Lineax performs an extra check that the returned solution contains no NaNs etc., i.e. that the solve was successful. This can be disabled by passing linear_solve(..., throw=False).

  2. Pytree flattening/unflattening across JIT boundaries. matrix and vector are simpler PyTrees than operator and lx.LU(), so flattening and unflattening the latter on every call adds more overhead.

With this benchmark I obtain identical performance:

import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import timeit

matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 10))
vector = jr.normal(vector_key, (10,))

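# Build the operator and solve inside jax.jit, so only plain arrays cross
# the JIT boundary; throw=False skips Lineax's success check.
@jax.jit
def solve_lineax(matrix, vector):
    operator = lx.MatrixLinearOperator(matrix)
    sol = lx.linear_solve(operator, vector, throw=False)
    return sol.value

@jax.jit
def solve_jax(matrix, vector):
    return jnp.linalg.solve(matrix, vector)

# block_until_ready waits for JAX's asynchronous dispatch to finish, so the
# whole computation is timed rather than just the dispatch.
time_lineax = lambda: jax.block_until_ready(solve_lineax(matrix, vector))
time_jax = lambda: jax.block_until_ready(solve_jax(matrix, vector))

# Best of 10 single runs; the first run also includes compilation time.
print(min(timeit.repeat(time_jax, number=1, repeat=10)))
print(min(timeit.repeat(time_lineax, number=1, repeat=10)))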

Hi Patrick,

I got the same results too. Thank you!