patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax

Batched and recursive computations cause trouble for lx.linear_solve and sometimes break lx.NormalCG with a fixed max_steps.

pierreablin opened this issue

Working example:

```python
import jax
import jax.numpy as jnp

import lineax as lx

p = 10
A = jax.random.normal(jax.random.PRNGKey(0), (p, p))
# define a linear function that multiplies each row of its input by A
def linear_fn(x):
    x = x @ A.T
    return x

# use it in a batched way:

batch_size = 128
x = jnp.ones((batch_size, p))
target = linear_fn(x)  # get output

# define the operator

operator = lx.FunctionLinearOperator(
    fn=linear_fn,
    input_structure=jax.eval_shape(lambda: x)
)
```

and then, on my machine,

```python
lx.linear_solve(operator, target, solver=lx.NormalCG(1e-2, 1e-2, max_steps=10))
```

raises

```
XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The maximum number of solver steps was reached. Try increasing max_steps.
```

This error happens regardless of batch_size in the code above, but whether it happens at all depends on the shape of A.

Another concern is the slowness of the automatic solver: on my machine, %timeit operator.mv(x) gives 28.6 µs, and computing the transpose with transpose = operator.transpose() and then doing %timeit transpose.mv(target) gives 113 µs. But %timeit lx.linear_solve(operator, target, solver=lx.AutoLinearSolver(well_posed=False)) gives 466 ms, which is 3000 times more costly than a .mv and a transpose evaluation. This ratio gets worse as the batch size increases, which I don't think should happen -- the number of iterations should stay about the same.

With a different shape of A so that lx.linear_solve(operator, target, solver=lx.NormalCG(1e-2, 1e-2, max_steps=10)) does not break, I also observe the same problematic scaling.

Thanks for your help

PS: Thanks for the awesome library; it is really helpful!

I don't think this issue is actually caused by the batching, which is consistent with it
failing for every batch size (or even with no batch).

CG is quite susceptible to floating-point errors, especially NormalCG, since working with
the normal equations squares the condition number. In practice this means taking a few
more than N iterations for an N x N matrix, to correct the floating-point errors that have
accumulated. Indeed, setting max_steps=12 or max_steps=13 should fix the issue.
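To make the two points above concrete, here is a minimal NumPy sketch (not lineax's implementation) of plain CG applied to the normal equations A^T A x = A^T b. The helper `cg_normal_equations` and the test matrix, built with singular values from 1 to 100, are hypothetical choices for illustration. It checks that cond(A^T A) is cond(A) squared and prints how many steps CG needed; in exact arithmetic that would be at most n = 10, while in floating point the count may be larger.

```python
import numpy as np


def cg_normal_equations(A, b, rtol=1e-8, max_steps=10_000):
    """Plain CG applied to the normal equations A^T A x = A^T b.

    Returns the approximate solution and the number of iterations taken.
    """
    x = np.zeros(A.shape[1])
    r = A.T @ b                      # normal-equation residual for x = 0
    p = r.copy()
    rs_old = r @ r
    target = rtol * np.linalg.norm(A.T @ b)
    for step in range(1, max_steps + 1):
        Ap = A.T @ (A @ p)           # apply A^T A without forming it
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) <= target:
            return x, step
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x, max_steps


# Hypothetical test problem: a 10 x 10 matrix with singular values from 1 to 100.
rng = np.random.default_rng(0)
n = 10
U, _ = np.linalg.qr(rng.normal(size=(n, n)))
V, _ = np.linalg.qr(rng.normal(size=(n, n)))
A = U @ np.diag(np.logspace(0, 2, n)) @ V.T
b = rng.normal(size=n)

cond_A = np.linalg.cond(A)          # about 1e2
cond_AtA = np.linalg.cond(A.T @ A)  # about 1e4, i.e. cond_A ** 2
x, steps = cg_normal_equations(A, b)
print(f"cond(A) = {cond_A:.3g}, cond(A^T A) = {cond_AtA:.3g}, CG steps = {steps}")
```

The step count depends on the matrix and the tolerance, so treat it as an illustration of the effect rather than a prediction for the operator in the issue.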

Note that although this doesn't match the exact-arithmetic theory of CG (convergence in at
most N steps), it is common numerically. JAX and SciPy also both fail to solve this in 10
iterations for the same reason, but JAX does not raise an error at all (quite bad!), and
SciPy reports the failure silently via its info return value, which is equivalent to running
lx.linear_solve(..., throw=False).
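A small sketch of SciPy's convention, on a hypothetical symmetric positive definite 10 x 10 system: scipy.sparse.linalg.cg returns a pair (x, info), where info > 0 means the tolerance was not reached within maxiter and info == 0 means it converged; no exception is raised in either case. Only maxiter is passed, so the tolerance stays at SciPy's default.

```python
import numpy as np
from scipy.sparse.linalg import cg

rng = np.random.default_rng(0)
n = 10
Q, _ = np.linalg.qr(rng.normal(size=(n, n)))
M = Q @ np.diag(np.linspace(1.0, 10.0, n)) @ Q.T   # SPD, eigenvalues 1..10
M = (M + M.T) / 2                                  # symmetrise against round-off
b = rng.normal(size=n)

# Too few iterations: no error, the failure is only visible through info > 0.
x_short, info_short = cg(M, b, maxiter=1)
# Enough iterations: info == 0 signals convergence.
x_long, info_long = cg(M, b, maxiter=1000)
print(info_short, info_long)
```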

Regarding lx.AutoLinearSolver, this will never default to an iterative method. This
is because iterative methods struggle when dealing with dense matrices, which
is the typical use case we expect for lx.AutoLinearSolver. Specifically,
lx.AutoLinearSolver(well_posed=False) will default to lx.SVD() unless the operator
has the tag lx.diagonal_tag.
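One plausible reading of the scaling reported in the issue (an assumption): a dense solver such as lx.SVD() works on the operator as an explicit matrix, and for the batched linear_fn that matrix has shape (batch_size * p) x (batch_size * p). The NumPy sketch below materialises it with np.kron (block-diagonal, one copy of A per row of x) and solves with a hypothetical `svd_solve` helper, a pseudo-inverse with a cutoff for small singular values; it is not lineax's code. A dense SVD of an m x m matrix costs roughly O(m^3), so going from batch_size = 8 to 128 multiplies m by 16 and the SVD cost by about 4096, whereas the number of CG iterations would not be expected to grow.

```python
import numpy as np

rng = np.random.default_rng(0)
p, batch_size = 10, 8
A = rng.normal(size=(p, p))


def linear_fn(x):
    # Same map as in the issue: each row of x is multiplied by A.
    return x @ A.T


# Materialise the operator as a matrix acting on the row-major flattened
# input of size batch_size * p: block-diagonal, batch_size copies of A.
M = np.kron(np.eye(batch_size), A)
x = rng.normal(size=(batch_size, p))
assert np.allclose(M @ x.ravel(), linear_fn(x).ravel())


def svd_solve(M, y, rcond=1e-12):
    # Hypothetical helper: least-squares solve via the SVD pseudo-inverse,
    # zeroing the reciprocals of singular values below rcond * max(s).
    U, s, Vt = np.linalg.svd(M)
    s_inv = np.zeros_like(s)
    keep = s > rcond * s.max()
    s_inv[keep] = 1.0 / s[keep]
    return Vt.T @ (s_inv * (U.T @ y))


target = linear_fn(x)
x_hat = svd_solve(M, target.ravel()).reshape(batch_size, p)
print(M.shape)
```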

Thanks for your swift answer!

So the expected behaviour is that if after max_steps iterations convergence has not been reached, an error is returned?

Whether an error is thrown or not is controlled by the keyword-only argument throw, which defaults to True. Note that max_steps is an argument of the solver, not of lx.linear_solve. So solution = lx.linear_solve(operator, vector, solver=lx.NormalCG(rtol, atol, max_steps=max_steps)) will throw an error if it does not converge in max_steps iterations, but solution = lx.linear_solve(operator, vector, solver=lx.NormalCG(rtol, atol, max_steps=max_steps), throw=False) will not. Either way, this information is still included in solution.result; see: https://docs.kidger.site/lineax/api/solution/.

Ok, this is very clear! Thanks.