JuliaNLSolvers / NLsolve.jl

Julia solvers for systems of nonlinear equations and mixed complementarity problems

Reverse differentiation through nlsolve

antoine-levitt opened this issue

OK, so this is pretty speculative as the reverse differentiation packages are not there yet, but let's dream for a moment. It would be awesome to be able to just use reverse-mode differentiation on code like

function G(α)
    F(x) = x - α
    x = nlsolve(F, zeros(length(α))).zero
    H(x) = sum(x)
    H(x)
end

and take the gradient of G wrt α. Of course, both F and H are examples, and can be arbitrary functions.

So how to get the gradient of G? One can of course forward-diff through G, which is not too hard to support from the perspective of nlsolve (although I haven't tried). But that's pretty inefficient if α is high-dimensional. One can try reverse-diffing through G, but that's pretty heavy since it basically has to record all the iterations. A better idea is to exploit the mathematical structure of the problem, and in particular the relationship dx/dα = -(∂F/∂x)^-1 ∂F/∂α (differentiate F(x(α),α)=0 wrt α), assuming nlsolve has converged perfectly. Reverse-mode autodiff requires the user to compute (dx/dα)^T δx, which is -(∂F/∂α)^T (∂F/∂x)^-T δx. If the jacobian is not provided (Broyden or Anderson), this can be done using an iterative solver such as GMRES, with the individual matvecs with (∂F/∂x)^T performed by reverse diff.
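As a tiny sanity check of that relation (a hedged sketch, not from the original thread, anticipating the x³ - p example used further down): for F(x, α) = x³ - α we have ∂F/∂x = 3x² and ∂F/∂α = -1, so dx/dα = -(∂F/∂x)⁻¹ ∂F/∂α = 1/(3x²) at the solution, which matches the derivative of x(α) = α^(1/3).

using NLsolve

α = 8.0
xsol = nlsolve(x -> [x[1]^3 - α], [1.0]).zero[1]   # converges to 2.0
implicit = 1 / (3 * xsol^2)                        # -(∂F/∂x)⁻¹ ∂F/∂α at the converged solution
analytic = (1/3) * α^(1/3 - 1)                     # derivative of α^(1/3)
@assert isapprox(implicit, analytic; rtol=1e-6)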

The action point for nlsolve here is to write a reverse ChainRule (https://github.com/JuliaDiff/ChainRules.jl) for nlsolve. This might be tricky because nlsolve takes a function as argument, but we might get by with just calling a diff function on F recursively. CC @jrevels to check this isn't a completely stupid idea. Of course, this isn't necessarily specific to nlsolve; the same ideas apply to optim (writing ∇F = 0) and diffeq (adjoint equations) for instance.

You might want to take a look at https://arxiv.org/abs/1812.01892 . Yes, forward is fast but doesn't scale, and hard-coded adjoints do well. But I think the golden solution might be to just wait for source-to-source like Zygote.jl since then reverse mode can be done without operation tracking.

Even then, Zygote still has to record all the history of the iterative method and then run it backwards, so that'll likely be slow and memory-consuming, won't it?

I don't think it has to build a tape to handle loops? But then again, I don't know how it knows how many times to go back through the loop. @MikeInnes

Well, if you backward-differentiate through a loop, don't you have to keep track of all the intermediary steps?

I took a look at your (very nice, btw) paper; I think one key difference between adjoints for solving equations and for differential equations is that in diff eqs you are interested in the whole solution, and so have no choice but to keep it in memory, at which point reverse-mode differentiating through the thing doesn't look so bad. When solving equations you just want the final solution and discard the iterates. This enables more efficient adjoints for nonlinear solves, where you can just discard the convergence history.

To make my point above clearer, take a simpler case: computing the gradient of y -> <b, A^-1 y>, where the A solve is done through a simple iterative method like CG (let's say). If you run reverse AD through CG, you're going to need to store all iterates (or use fancy techniques), and potentially need a lot of memory. Instead, you can just write the gradient as A^-T b and CG-solve that. Obviously this is a trivial example, but it generalizes to arbitrary nonlinear systems and outputs.
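As a hedged illustration of that simpler case (a small made-up example, not from the thread): since y -> <b, A^-1 y> is linear in y, its gradient is A^-T b, so one extra adjoint solve suffices and nothing from the CG iterations needs to be stored.

using LinearAlgebra, IterativeSolvers

A = [4.0 1.0; 1.0 3.0]      # small SPD matrix so CG applies
b = [1.0, 2.0]
y = [0.5, -1.0]

g(y) = dot(b, cg(A, y))     # forward evaluation of <b, A⁻¹ y> through CG
grad = cg(A', b)            # gradient of g: A⁻ᵀ b, a single adjoint solve, no AD through CG

# finite-difference sanity check (g is linear in y, so this is essentially exact)
ϵ = 1e-6
fd = map(1:2) do i
    e = zeros(2); e[i] = ϵ
    (g(y + e) - g(y)) / ϵ
end
@assert isapprox(grad, fd; rtol=1e-4)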

I see what you're saying and it totally makes sense in this case. I don't know the right place to overload for this though. We did it directly on Flux.Tracker types.

This sounds very useful. How about starting from something very simple, like this?

using NLsolve
using Zygote
using Zygote: @adjoint, forward

@adjoint nlsolve(f, j, x0; kwargs...) =
    let result = nlsolve(f, j, x0; kwargs...)
        result, function(vresult)
            # This backpropagator returns (- v' (df/dx)⁻¹ (df/dp))'
            v = vresult[].zero
            x = result.zero
            J = j(x)
            _, back = forward(f -> f(x), f)
            return (back(-(J' \ v))[1], nothing, nothing)  # cotangents w.r.t. (f, j, x0)
        end
    end

It looks like it's working:

julia> d, = gradient(p -> nlsolve(x -> [x[1]^3 - p],
                                  x -> fill(3x[1]^2, (1, 1)),
                                  [1.0]).zero[1],
                     8.0)
(0.08333333333333333,)

julia> d ≈ 1/3 * 8.0^(1/3 - 1)
true

Is Zygote ready to be used in the wild?

Wow, that is so cool. Now we just need JuliaDiff/ChainRulesCore.jl#22 to be able to take a dependency on ChainRulesCore, put that code in nlsolve, add the iterative solve for the zeroth-order methods, and we rule the world! (well, except for mutation...)

(well, except for mutation...)

What do you mean?

@antoine-levitt FYI it looks like the complex number interface could become another blocker for Zygote users (see FluxML/Zygote.jl#142 (comment)), although I guess we can still use other AD packages based on ChainRulesCore? But are there other AD packages closer to production-ready than Zygote.jl? I'm wondering if it makes sense to define it just for Zygote.jl for now. Reading FluxML/Zygote.jl#291 it seems that ChainRulesCore's API would be close to Zygote.jl's, so migration doesn't sound hard.

Of course, people can just define their own wrapper. I did it already (https://github.com/tkf/SteadyStateFit.jl/blob/239b18252ea5a596b780ddfbeb483b7e80b17572/src/znlsolve.jl) so this is not a blocker for me personally anymore.

Don't think it's such a blocker: both zygote and chainrules support complex differentials in their full generality; it's just a question of putting the APIs together and of optimization.

It looks like zygote and chainrules are going to mesh in the short term, so we might as well wait until then. @oxinabox, does that sound reasonable? I think it's better for nlsolve to take on a dependency on ChainRulesCore than on Zygote.

Sounds reasonable to me

Don't think it's such a blocker: both zygote and chainrules support complex differentials in their full generality; it's just a question of putting the APIs together and of optimization.

Yeah, I don't think it will be much of a blocker.

It looks like zygote and chainrules are going to mesh in the short term, so we might as well wait until then. @oxinabox, does that sound reasonable? I think it's better for nlsolve to take on a dependency on ChainRulesCore than on Zygote.

No later than end of the year. Hopefully much sooner.

There is also ZygoteRules.jl, which is, I think, a Zygote-specific equivalent of ChainRulesCore.
One could use that in the short term, but I am hoping for it to either be retired, or marked as "use only if you need rules that depend on the internals of Zygote".
ChainRules is more general.
Most notably, it also supports the upcoming ForwardDiff2.

OK, that's good news.

Lyndon mentioned it, but just linking ZygoteRules explicitly.

RE using Zygote in the wild: the marker for that is really going to be when we release Flux + Zygote; once it's ready for that it's going to have been pretty heavily user-tested. OTOH it's still a relatively large dependency. My suggestion would be to use ZygoteRules to add this adjoint for now, and then switch to ChainRules once it's ready.

So here's a (very crude) prototype of reverse diffing through a nonlinear PDE solve (based on @tkf's code, but with fully iterative methods, to get something representative of large-scale applications)

using NLsolve
using Zygote
using Zygote: @adjoint, forward
using IterativeSolvers
using LinearMaps
using SparseArrays
using LinearAlgebra   # for Diagonal below
using BenchmarkTools  # for @btime below

# nlsolve maps f to the solution x of f(x) = 0
# We have ∂x = -(df/dx)^-1 ∂f, and so the adjoint is df = -(df/dx)^-T dx
@adjoint nlsolve(f, x0; kwargs...) =
    let result = nlsolve(f, x0; kwargs...)
        result, function(vresult)
            dx = vresult[].zero
            x = result.zero
            _, back_x = forward(f, x)

            JT(df) = back_x(df)[1]
            # solve JT*df = -dx
            L = LinearMap(JT, length(x0))
            df = gmres(L,-dx)

            _, back_f = forward(f -> f(x), f)
            return (back_f(df)[1], nothing, nothing)
        end
    end

const N = 10000
const nonlin = 0.1
const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
const p0 = randn(N)
f(x, p) = A*x + nonlin*x.^2 - p
solve_x(p) = nlsolve(x -> f(x, p), zeros(N), method=:anderson, m=10).zero
obj(p) = sum(solve_x(p))

Zygote.refresh()
g_auto, = gradient(obj, p0)
g_analytic = gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))
display(g_auto)
display(g_analytic)

@btime gradient(obj, p0)
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))

Performance is not great, essentially 20x slower than the analytic version. However, profiling shows that this overhead is pretty localized, so it might be possible to optimize it away and get essentially the same perf as the analytic one (this should be a relatively easy case for reverse diff, since there are only vector operations and no loop). I'm not quite sure what's going on here; one possibility is that Zygote tries to diff wrt globally defined constants.

You could try explicitly dropping gradients of globals to see if that's the issue.

OK, but how do I do that?
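(A hedged guess at what that could look like with the Zygote API of the time, reusing A and nonlin from the snippet above; Zygote.dropgrad wraps a value so the pullback does not propagate into it, and newer Zygote spells this Zygote.ignore / ChainRulesCore.ignore_derivatives.)

# hedged sketch: keep Zygote from tracking the global constants A and nonlin
f(x, p) = Zygote.dropgrad(A) * x + Zygote.dropgrad(nonlin) * x.^2 - p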

I see a similar hiccup with closures and would like to know of any solution/workaround: FluxML/Zygote.jl#323

@tkf I see you closed the issue there, but the discussion there was too technical for me to follow. Could you summarize what it means for the code above? Will it be fixed by the chainrules integration?

@antoine-levitt Short answer is, IIUC, it'll be solved by switching to ChainRulesCore. I posted a longer answer with step-by-step code in FluxML/Zygote.jl#323 (comment) explaining why I thought it was solved.

If I use LinearAlgebra.diagm in f(x, p), it raises "Need an adjoint for constructor Pair". How can I write an adjoint method similar to the one above? Many thanks!

I don't know; you probably need to take it up with Zygote (or ChainRules). Also note that the above code was for an older version of Zygote, so it needs updating (if anyone does so, please post the result and check whether the above-mentioned slowdown is still present!)

Here's @antoine-levitt's example as of today

using NLsolve
using Zygote
using Zygote: @adjoint
using IterativeSolvers
using LinearMaps
using SparseArrays
using LinearAlgebra
using BenchmarkTools

# nlsolve maps f to the solution x of f(x) = 0
# We have ∂x = -(df/dx)^-1 ∂f, and so the adjoint is df = -(df/dx)^-T dx
@adjoint nlsolve(f, x0; kwargs...) =
    let result = nlsolve(f, x0; kwargs...)
        result, function(vresult)
            dx = vresult[].zero
            x = result.zero
            _, back_x = Zygote.pullback(f, x)

            JT(df) = back_x(df)[1]
            # solve JT*df = -dx
            L = LinearMap(JT, length(x0))
            df = gmres(L,-dx)

            _, back_f = Zygote.pullback(f -> f(x), f)
            return (back_f(df)[1], nothing, nothing)
        end
    end

const N = 10000
const nonlin = 0.1
const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
const p0 = randn(N)
f(x, p) = A*x + nonlin*x.^2 - p
solve_x(p) = nlsolve(x -> f(x, p), zeros(N), method=:anderson, m=10).zero
obj(p) = sum(solve_x(p))

Zygote.refresh()
g_auto, = gradient(obj, p0)
g_analytic = gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))
display(g_auto)
display(g_analytic)
@show sum(abs.(g_auto - g_analytic))

@btime gradient(obj, p0); 
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N));

My local timings:

@btime gradient(obj, p0);  # 2.823 s (1141 allocations: 5.99 GiB) 
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N));  # 21.230 ms (908 allocations: 24.03 MiB)
Status `~/.../Project.toml`
   [6e4b80f9] BenchmarkTools v0.5.0
   [42fd0dbc] IterativeSolvers v0.9.0
   [7a12625a] LinearMaps v3.2.0
   [2774e3e8] NLsolve v4.5.1
   [e88e6eb3] Zygote v0.6.3
   [2f01184e] SparseArrays

Thanks @niklasschmitz for providing the updated code.
The runtime for gradient is now about 100x that for gmres. @antoine-levitt reported 20x. Do you know what happened here?
Also, the number of memory allocations for Zygote is about the same, but the total allocated memory is about 250x that of GMRES. What is the cause of these large allocations?

It now seems to me that the big slowdown is caused by the sparse matrix A somehow not being handled efficiently. Here's what I get for a dense matrix A (and choosing N=1000) by changing the above snippet:

- const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
+ const A = Array(spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))) # try dense A, for comparison only

For this I get the following timings:

@btime gradient(obj, p0); #   26.382 ms (624 allocations: 63.30 MiB) 
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N));  #   16.002 ms (446 allocations: 9.52 MiB)

So the previous 100x relative slowdown seems gone, cc @antoine-levitt @rkube

The large performance penalties when using SparseMatrixCSC for A might be worth its own Zygote issue, if it doesn't exist already.

Could it also be that the cost of matvecs with dense matrices is much larger than with sparse ones, so that it hides the overhead?

I now tried to double-check by trying the sparse case again but with a custom rrule for the inner function f:

using ChainRulesCore
function ChainRulesCore.rrule(::typeof(f), x, p)
    y = f(x, p)
    function f_pullback(ȳ)
        ∂x = @thunk(A' * ȳ + 2nonlin * x .* ȳ)
        ∂p = @thunk(-ȳ)
        return (NO_FIELDS, ∂x, ∂p)
    end
    return y, f_pullback
end
Zygote.refresh()

Going back to the original example problem from above (i.e. N=10000 and A=spdiagm(...)) I now get

@btime gradient(obj, p0);  # 22.756 ms (986 allocations: 23.99 MiB) 
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N));  # 23.065 ms (786 allocations: 21.23 MiB)

JuliaDiff/ChainRulesCore.jl#363 is required to avoid the Zygote dependency.

https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/steadystate_adjoint.jl#L2-L81 is an implementation in DiffEqSensitivity. It has a ton of options for how the vjp can be calculated, https://diffeq.sciml.ai/stable/analysis/sensitivity/#Internal-Automatic-Differentiation-Options-(ADKwargs), but that should get replaced by AbstractDifferentiation.jl. See JuliaDiff/AbstractDifferentiation.jl#1

And @YingboMa did one for NonlinearSolve.jl: https://gist.github.com/YingboMa/4e4496f828c6a3179004f6d0ca224d2a
NonlinearSolve.jl is hyper-specialized for very small problems to be non-allocating and all of that, so hardcoding to use ForwardDiff kinds of things makes sense in that case.

Thanks Chris for the pointers!
Here's a generic ChainRules version of the above Zygote.@adjoint, based on the new calling-back-into-AD mechanism introduced in JuliaDiff/ChainRulesCore.jl#363

using NLsolve, ChainRulesCore, LinearMaps, IterativeSolvers

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        Δx = Δresult[].zero
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2] # w.r.t. x
        # solve JT*Δfx = -Δx
        L = LinearMap(JT, length(x0))
        Δfx = gmres(L, -Δx)
        ∂f = f_pullback(Δfx)[1] # w.r.t. f itself (implicitly closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end

Full gist is here: https://gist.github.com/niklasschmitz/b00223b9e9ba2a37ed09539a264bf423
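For instance (a hedged usage sketch reusing the test problem from earlier in the thread, and assuming the rrule above has been defined): any ChainRules-aware reverse AD with call-back-into-AD support, e.g. Zygote, should then pick it up.

using NLsolve, Zygote, SparseArrays

N = 10000
nonlin = 0.1
A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
p0 = randn(N)
f(x, p) = A*x + nonlin*x.^2 - p
solve_x(p) = nlsolve(x -> f(x, p), zeros(N), method=:anderson, m=10).zero

g, = Zygote.gradient(p -> sum(solve_x(p)), p0)   # hits the nlsolve rrule via rrule_via_ad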

Yay! So, @pkofod, it seems everything is in place for putting that rule into nlsolve. That means taking on a dependency on ChainRulesCore and IterativeSolvers; are you OK with that? Also, should we do it in the new nlsolvers or not?

And can we add a keyword argument for dispatching, i.e. sensealg = ImplicitAdjoint(), allowing for sensealg = ADPassThrough() to make the AD mechanism not hit the rrule? That would be good for handling/testing stability issues. Then I think this rule should be a method-dependent default, i.e. only fixedpoint and anderson should really default to it, while without more stabilization it might not make sense to have the Newton methods default to this, especially with no preconditioner on the GMRES.

And can we add a keyword argument for dispatching,

This should work

i.e. sensealg = ImplicitAdjoint(), allowing for sensealg = ADPassThrough() to make the AD mechanism not hit the rrule?

This is harder.
Not using a rule when you have one defined is harder, especially if we can't use dispatch (since it is a kwarg); we need to work that out (it is closely related to JuliaDiff/ChainRulesCore.jl#377).
It will work with Diffractor if you just make that return nothing to say don't AD, but not with Zygote (or Nabla) right now (I haven't checked Yota).

One way that could be done for this particular case is having nlsolve call nlsolve_inner, and having a rrule for nlsolve but not for nlsolve_inner; then, if set to ADPassThrough, it can use rrule_via_ad(nlsolve_inner, ...)

Other possibilities are to write the dispatch on the kwargs, which I think is possible with sufficient evil?

The way we do this in SciML is to make a drop method:

https://github.com/SciML/DiffEqBase.jl/blob/v6.64.0/src/solve.jl#L66-L71

and then define the adjoint dispatch on a given set of types, making the rrule undefined on the ADPassThrough.

https://github.com/SciML/DiffEqBase.jl/blob/v6.64.0/src/solve.jl#L297-L302

by making it not part of the abstract type.
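A hedged sketch of how that drop-method pattern could look for nlsolve (the wrapper names nlsolve_sens / nlsolve_up are illustrative assumptions, not existing NLsolve API; ImplicitAdjoint and ADPassThrough are the types proposed above, and the rrule body reuses the nlsolve rrule defined earlier in the thread):

using NLsolve, ChainRulesCore

struct ImplicitAdjoint end
struct ADPassThrough end

# user-facing wrapper: drop the sensealg kwarg into a positional argument
nlsolve_sens(f, x0; sensealg = ImplicitAdjoint(), kwargs...) =
    nlsolve_up(f, x0, sensealg; kwargs...)

# the actual work; no rrule is attached to the ADPassThrough dispatch,
# so reverse-mode AD simply traces through the solver iterations
nlsolve_up(f, x0, sensealg; kwargs...) = nlsolve(f, x0; kwargs...)

# attach the implicit adjoint only for ImplicitAdjoint
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(nlsolve_up), f, x0, ::ImplicitAdjoint; kwargs...)
    result, nlsolve_pullback = rrule(config, nlsolve, f, x0; kwargs...)
    pullback_up(Δ) = (nlsolve_pullback(Δ)..., NoTangent())   # extra NoTangent for sensealg
    return result, pullback_up
end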

Then I think this rule should be a method-dependent default, i.e. only fixedpoint and anderson should really default to it, while without more stabilization it might not make sense to have the Newton methods default to this, especially with no preconditioner on the GMRES.

Yes of course, the linear solver should follow the nonlinear one. Anderson and Broyden should go to GMRES (which is the linearized version of anderson anyway), and the Newton solvers should go to whatever is used for solving the update equations. Another possibility is to just call nlsolve recursively on the linear equation.

We can also make this optional, either by having the user explicitly call an enable_implicit_adjoints() or by having a separate NLSolveAdjoints package (I'm less keen on that one because it splits the code).

I don't think it needs to be separate, just switchable.

Another possibility is to just call nlsolve recursively on the linear equation.

That would make a lot of sense and naturally make it use the same linear solver.

That would make a lot of sense and naturally make it use the same linear solver.

Yeah that sounds like a sensible default, although a bit suboptimal (GMRES is more stable than anderson, the newton method might not be able to figure out it doesn't need any stabilization, etc).

Well the default should probably still be GMRES on the fixed point methods, and only do the recursive iteration on Newton. That would likely get pretty close to optimal, since someone would only choose Newton if they need it (in theory).

The way we do this in SciML is to make a drop method:

Cool, easy, then we can do this.

We can also make methods that are only defined if you have access to a forward-mode AD.
So we can write things that do reverse mode by computing the Jacobian via forward mode.

That's what we call sensealg=ForwardDiffSensitivity()... and now you can see how we ended up at 8 of them:

https://diffeq.sciml.ai/stable/analysis/sensitivity/#Sensitivity-Algorithms

Yep, and we can make those work without a direct dependency on ForwardDiff,
so that it will work with either Zygote calling into ForwardDiff, or Diffractor using Diffractor's forward mode,
or something calling into FiniteDifferencing for forward mode, etc.

Hello! Just wondering about the progress on this ticket.

#205 (comment) works, and you can fine-tune it to suit your needs. Putting this into the actual nlsolve requires more thought about API, solvers, tolerances, defaults, etc.

BTW, someone should check whether Zygote.jl currently just fails on NLsolve.jl. If it does, then you might as well take the working adjoint, slap it on there, and do a quick merge. It's not numerically robust, but it's better than failing, and it's step 1 toward making something better.

Just checking in again to say that I'm not really sure how these things would be implemented and used in practice, but if @oxinabox can help or at least hold my hand I'm happy to include this feature.

I am happy to hold your hand through this.
I think this is the solution?
#205 (comment)

Yeah, the snippet above should be fine, with the caveat that the linear solve should be replaced by a recursive nlsolve call.

I updated the gist example at https://gist.github.com/niklasschmitz/b00223b9e9ba2a37ed09539a264bf423#gistcomment-3830191
to use a recursive nlsolve call inside the rrule.
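Roughly (a hedged sketch of the idea, not the gist verbatim), the change is to replace the gmres call in the pullback with a nonlinear solve of the linear residual v -> Jᵀ v + Δx, so the adjoint reuses the same nonlinear solver:

using NLsolve, ChainRulesCore

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        Δx = Δresult[].zero
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2]                    # vᵀ-Jacobian products w.r.t. x
        # recursive nlsolve on the linear residual v -> Jᵀ v + Δx, instead of gmres
        Δfx = nlsolve(v -> JT(v) + Δx, zero(x); kwargs...).zero
        ∂f = f_pullback(Δfx)[1]                     # cotangent w.r.t. f (its closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end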

Note that this example is adapted to the matrix-free case (where the jacobian is not computed explicitly). In the case where an explicit jacobian is provided, it should be used for JT (instead of being computed through AD), and the residual function v -> Jᵀ v + Δx should be built from that explicit jacobian. @pkofod, can you help with the API here? I'm a bit fuzzy on how to get at that information with the nlsolversbase wrapper types & co.
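For the explicit-Jacobian case, a hedged sketch of what that could look like (ignoring the nlsolversbase wrapper types for now and assuming f and j are passed as plain out-of-place functions, as in the very first @adjoint in this thread):

using NLsolve, ChainRulesCore

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, j, x0; kwargs...)
    result = nlsolve(f, j, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        Δx = Δresult[].zero
        x = result.zero
        J = j(x)                                    # explicit jacobian at the solution
        # solve Jᵀ Δfx = -Δx, again via a recursive nlsolve (or J' \ -Δx for small dense problems)
        Δfx = nlsolve(v -> J' * v + Δx, zero(x); kwargs...).zero
        _, f_pullback = rrule_via_ad(config, f, x)  # still needed for the cotangent w.r.t. f's parameters
        ∂f = f_pullback(Δfx)[1]
        # j only affects how the solver reaches the solution, not the solution itself,
        # so (assuming exact convergence) its cotangent is zero
        return (NoTangent(), ∂f, ZeroTangent(), ZeroTangent())
    end
    return result, nlsolve_pullback
end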