patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax

How to use PyTreeLinearOperator

ToshiyukiBandai opened this issue

Hi lineax community,

I came here from Patrick's comment (jax-ml/jax#17203 (comment)). I believe PyTreeLinearOperator will do the job, but I am struggling to use it correctly. In the example below, I want to solve a Newton system $Jx = -F$, where the Jacobian $J$ is a PyTree. How can I use PyTreeLinearOperator correctly?

import jax
from jax import jit, vmap, lax, jacfwd, jacrev, grad, vjp, jvp, random
import jax.numpy as jnp
from jax.tree_util import tree_structure, tree_flatten, tree_unflatten
import equinox as eqx
from jaxtyping import Float, Array, Bool
import lineax as lx

class Parameter(eqx.Module):
    alpha: Float[Array, ""]
    beta: Float[Array, ""]
    def __init__(self, alpha, beta):
        self.alpha = alpha
        self.beta = beta

class State(eqx.Module):
    x_0: Array
    x_1: Array
    def __init__(self, x_guess):
        self.x_0 = x_guess[0]
        self.x_1 = x_guess[1]

class Model(eqx.Module):
    parameters: eqx.Module
    def __init__(self, parameters):
        self.parameters = parameters
    def residual(self, state):
        F_0 = state.x_0**2 + state.x_1**2 - self.parameters.alpha
        F_1 = self.parameters.beta*state.x_0**3 - state.x_1
        return jnp.array([F_0, F_1])

alpha = 4.0
beta = 1.0
x_test = jnp.array([2.0, 3.0])

parameter = Parameter(alpha, beta)
state = State(x_test)
model = Model(parameter)

Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = model.residual(state)
J = Jacobian_JAX_class(state)
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: state))
# lx.linear_solve(J, F) # failed

So the issue in this example is that F is a JAX array, but the output structure of J has the PyTree structure of State, as set by the output_structure argument of lx.PyTreeLinearOperator(pytree, output_structure). So $Jx$ has the PyTree structure of State but $F$ has the PyTree structure of a standard JAX array. Throwing an error is the correct thing to do here, since $Jx = F$ doesn't make sense when $Jx$ and $F$ have different PyTree structures.

There are a number of ways to fix this. The cleanest is probably to use lx.JacobianLinearOperator, which exists specifically to simplify cases like this. In this case, the final code block becomes:

F = model.residual(state)
J = lx.JacobianLinearOperator(lambda x, a: model.residual(x), state)
lx.linear_solve(J, F)

where the lambda is there because, in general, the function passed to JacobianLinearOperator is expected to take an extra args argument. There's no need to define Jacobian_JAX_class here; lx.JacobianLinearOperator takes care of it.

If you wanted to stick with PyTreeLinearOperator, then either replace the eval_shape argument with jax.eval_shape(lambda: F):

Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = model.residual(state)
J = Jacobian_JAX_class(state)
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: F))
soln = lx.linear_solve(J, F)
soln.value # A PyTree with the same structure as State (the input space of J)

or make F have the PyTree structure of State:

Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = State(model.residual(state))
J = Jacobian_JAX_class(state)
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: state))
soln = lx.linear_solve(J, F)
soln.value # A JAX array of shape (2,): here J maps arrays to State, i.e. its matrix is the transpose of the Jacobian

However, I'm assuming this third option is not what you had in mind.

Hi Jason,

Thank you so much! It worked well. As you might have guessed, this is used to solve some PDEs. Now I am thinking of dropping my own solver and switching to Optimistix, since it can do all of this under the hood (provided it supports Newton with backtracking line search, which it seems not to).

It's true, Newton with backtracking line search isn't something we've implemented in Optimistix yet. (patrick-kidger/optimistix#4)

It should be essentially straightforward to do, though: Newton and Gauss-Newton are basically the same algorithm, just applied to different problems. As such, copy-pasting optx.AbstractGaussNewton would get us 95% of the way there. (If you feel up to it, we'd be happy to take a PR on that.)

Just to be clear, are you handling a general minimisation problem, or a nonlinear least-squares problem? In your example problem, $F$ is a residual and $J$ the Jacobian, so the solution $x = -J^{-1}F$ is actually the Gauss-Newton step, assuming your loss is $F_0^2 + F_1^2$.

If your actual PDE is also of this form, i.e. you have some vector of residuals and you'd like to minimise the sum of their squares, then check out optx.AbstractGaussNewton and optx.BacktrackingArmijo with the mix-and-match API.

I am solving the system of non-linear equations (resulting from nonlinear PDEs) by the Newton method with backtracking line search. So, it's a root finding problem.
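For concreteness, a hand-rolled version of that iteration on the toy system above might look like the sketch below. This is plain JAX, not the Optimistix API. The merit function $\frac{1}{2} \Vert F(x) \Vert_2^2$, the Armijo constant c, the shrink factor and the tolerances are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp


# Toy residual from the first post, with alpha = 4 and beta = 1.
def residual(x):
    return jnp.stack([x[0] ** 2 + x[1] ** 2 - 4.0, x[0] ** 3 - x[1]])


# Merit function used by the line search: 0.5 * ||F(x)||^2.
def merit(x):
    return 0.5 * jnp.sum(residual(x) ** 2)


def newton_backtracking(x, tol=1e-5, max_iters=50, c=1e-4, shrink=0.5):
    for _ in range(max_iters):
        F = residual(x)
        if jnp.linalg.norm(F) < tol:
            break
        J = jax.jacfwd(residual)(x)
        dx = jnp.linalg.solve(J, -F)  # Newton direction: J dx = -F
        # Directional derivative of the merit function along dx.
        # For the Newton direction this equals -||F||^2.
        slope = jnp.dot(J.T @ F, dx)
        t = 1.0
        # Armijo backtracking: shrink t until sufficient decrease holds.
        while merit(x + t * dx) > merit(x) + c * t * slope and t > 1e-8:
            t *= shrink
        x = x + t * dx
    return x


root = newton_backtracking(jnp.array([2.0, 3.0]))
```

A Python loop keeps the sketch readable; a jit-compatible version would need lax.while_loop instead.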

Gauss-Newton and backtracking Armijo with the mix-and-match API should work then. Gauss-Newton is mathematically equivalent to Newton for square nonlinear systems with an invertible Jacobian. As Patrick said, they're just applied a little differently.

The aim is to solve $F(x) = 0$. Recasting as a nonlinear least-squares problem $\min_x \frac{1}{2} \Vert F(x) \Vert_2^2$, the residual vector is $F(x)$. With $J$ the Jacobian of $F$ at $x_k$ (assumed square and invertible), the Gauss-Newton update is

$$ \begin{aligned} x_{k + 1} &= x_k - (J^T J)^{-1} J^T F(x_k) \\ &= x_k - J^{-1} (J^T)^{-1} J^T F(x_k) \\ &= x_k - J^{-1} F(x_k), \end{aligned}$$

which is the Newton update. You don't have to do this conversion manually: calling optx.root_find(fn, YourFancyNewSolver, ...) will automatically convert to the least-squares problem $\min_x \frac{1}{2} \Vert F(x) \Vert_2^2$ and solve it as described above.
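The equivalence can be checked numerically on the toy system from the first post. The sketch below uses plain JAX with alpha = 4 and beta = 1 hard-coded. It forms both steps explicitly, which is fine for a 2x2 check but not how one would solve a large system.

```python
import jax
import jax.numpy as jnp


# Toy residual from the first post, with alpha = 4 and beta = 1.
def F(x):
    return jnp.stack([x[0] ** 2 + x[1] ** 2 - 4.0, x[0] ** 3 - x[1]])


x_k = jnp.array([2.0, 3.0])
J = jax.jacfwd(F)(x_k)  # square 2x2 Jacobian, invertible at this point
r = F(x_k)

# Newton step: -J^{-1} F(x_k)
newton_step = -jnp.linalg.solve(J, r)

# Gauss-Newton step via the normal equations: -(J^T J)^{-1} J^T F(x_k)
gauss_newton_step = -jnp.linalg.solve(J.T @ J, J.T @ r)
```

The two steps agree up to floating-point error because $J$ is square and invertible here; for a non-square or singular $J$ only the Gauss-Newton form (or a least-squares solve) is defined.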

Okay, that sounds good too. I will give it a shot. I will keep you posted.