How to use PyTreeLinearOperator
ToshiyukiBandai opened this issue.
Hi lineax community,
I came here from Patrick's comment in jax-ml/jax#17203. I believe `PyTreeLinearOperator` will do the job, but I'm struggling to use it correctly. In the example below, I want to solve a Newton system, i.e. the linear system $J x = F$, where $J$ is the Jacobian of the residual $F$. How do I use `PyTreeLinearOperator` correctly?
```python
import jax
import jax.numpy as jnp
import equinox as eqx
import lineax as lx
from jax import jacfwd
from jaxtyping import Array, Float


class Parameter(eqx.Module):
    alpha: Float[Array, ""]
    beta: Float[Array, ""]

    def __init__(self, alpha, beta):
        self.alpha = alpha
        self.beta = beta


class State(eqx.Module):
    x_0: Array
    x_1: Array

    def __init__(self, x_guess):
        self.x_0 = x_guess[0]
        self.x_1 = x_guess[1]


class Model(eqx.Module):
    parameters: eqx.Module

    def __init__(self, parameters):
        self.parameters = parameters

    def residual(self, state):
        # The nonlinear system to solve is residual(state) = 0.
        F_0 = state.x_0**2 + state.x_1**2 - self.parameters.alpha
        F_1 = self.parameters.beta * state.x_0**3 - state.x_1
        return jnp.array([F_0, F_1])


alpha = 4.0
beta = 1.0
x_test = jnp.array([2.0, 3.0])
parameter = Parameter(alpha, beta)
state = State(x_test)
model = Model(parameter)

Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = model.residual(state)
J = Jacobian_JAX_class(state)
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: state))
# lx.linear_solve(J, F)  # fails
```
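So the issue in this example is that `F` is a JAX array, but the output of `J` has the PyTree structure of `State`, as set by the output-structure argument in `lx.PyTreeLinearOperator(pytree, output_structure)`. So `J` produces a `State`, but `lx.linear_solve` needs the right-hand side to have that same output structure, and `F` is a plain array.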
There are a number of ways to fix this. The cleanest is probably to use `lx.JacobianLinearOperator`, which exists specifically to simplify cases like this. In this case, the final code block becomes:
```python
F = model.residual(state)
J = lx.JacobianLinearOperator(lambda x, a: model.residual(x), state)
lx.linear_solve(J, F)
```
where the lambda is there because, in general, we anticipate that the function passed to `JacobianLinearOperator` can take an extra `args` argument. There's no need to define `Jacobian_JAX_class` here; `lx.JacobianLinearOperator` takes care of it.
If you wanted to stick with `PyTreeLinearOperator`, then either replace the `eval_shape` with `jax.eval_shape(lambda: F)`:
```python
Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = model.residual(state)
J = Jacobian_JAX_class(state)
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: F))
soln = lx.linear_solve(J, F)
soln.value  # A PyTree with the structure of State (the operator's input structure)
```
or make `F` have the PyTree structure of `State`:
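```python
Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = State(model.residual(state))
J = Jacobian_JAX_class(state)
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: state))
soln = lx.linear_solve(J, F)
# A JAX array of shape (2,) (the operator's input structure). Note that this operator
# is the transpose of the Jacobian, because jacfwd stores Jacobian columns as State leaves.
soln.value
```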
However, I'm assuming this third option is not what you had in mind.
Hi Jason,
Thank you so much! It worked pretty well. As you might have guessed, I'm using it to solve some PDEs. In fact, I'm now thinking of dropping my own solver and switching to Optimistix, since it can do all of this under the hood (if it supports Newton with backtracking line search, but it seems not to).
It's true, Newton with backtracking line search isn't something we've implemented in Optimistix yet. (patrick-kidger/optimistix#4)
It should be fairly straightforward to do, though: Newton and Gauss-Newton are basically the same algorithm, just applied to different problems. As such, copy-pasting `optx.AbstractGaussNewton` would get us 95% of the way there. (If you feel up to it, we'd be happy to take a PR on that.)
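Just to be clear, are you handling a general minimisation problem, or a nonlinear least-squares problem? In your example problem, you have a vector of residuals $F(x)$ and want $F(x) = 0$, which can be posed as minimising the sum of squares $\lVert F(x)\rVert_2^2$.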
If your actual PDE is also of this form, i.e. you have some vector of residuals and you'd like to minimise the sum of their squares, then check out `optx.AbstractGaussNewton` and `optx.BacktrackingArmijo` with the mix-and-match API.
I am solving a system of nonlinear equations (resulting from nonlinear PDEs) using Newton's method with a backtracking line search. So it's a root-finding problem.
Gauss-Newton and backtracking Armijo with the mix-and-match API should work then. For a square system with a nonsingular Jacobian, Gauss-Newton is mathematically equivalent to Newton; as Patrick said, they're just applied a little differently.
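The aim is to solve

$$F(x) = 0,$$

which Gauss-Newton treats as the least-squares problem

$$\min_x \tfrac{1}{2}\lVert F(x)\rVert_2^2 .$$

Each Gauss-Newton step $\Delta x$ solves the normal equations $J^\top J\,\Delta x = -J^\top F$, where $J$ is the Jacobian of $F$. When $J$ is square and nonsingular, this reduces to

$$J\,\Delta x = -F,$$

which is the Newton update. You don't have to do this conversion manually: calling `optx.root_find(fn, YourFancyNewSolver, ...)` will automatically convert to the least-squares problem.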
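A quick numerical check of this equivalence, assuming the same two-equation system (alpha = 4 and beta = 1 hard-coded in a hypothetical `residual`) and a dense Jacobian from `jax.jacfwd`: for a square, nonsingular Jacobian, the Gauss-Newton normal equations and the Newton system give the same step.

```python
import jax
import jax.numpy as jnp


def residual(x):
    # Same system as the example, state flattened to shape (2,).
    return jnp.stack([x[0] ** 2 + x[1] ** 2 - 4.0, x[0] ** 3 - x[1]])


x = jnp.array([2.0, 3.0])
F = residual(x)
J = jax.jacfwd(residual)(x)  # dense 2x2 Jacobian

newton_step = jnp.linalg.solve(J, -F)                    # J d = -F
gauss_newton_step = jnp.linalg.solve(J.T @ J, -J.T @ F)  # (J^T J) d = -J^T F
print(newton_step, gauss_newton_step)
```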
Okay, that sounds good too. I will give it a shot. I will keep you posted.