JuliaNLSolvers / NLsolve.jl

Julia solvers for systems of nonlinear equations and mixed complementarity problems

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Allow ReverseDiff Propagation

taylormcd opened this issue · comments

Derivatives currently propagate through `nlsolve` with ForwardDiff, but not with ReverseDiff. A minimal example:

using NLsolve, ForwardDiff, ReverseDiff

"""
    residual!(r, y, x)

Evaluate, in place, the residual of the 2-equation test system for unknowns `y`
and parameters `x`. Only `x[1]` through `x[4]` are read.

    r[1] = (y[1] + x[1]) * (y[2]^3 - x[2]) + x[3]
    r[2] = sin(y[2] * exp(y[1]) - 1) * x[4]

Returns the mutated buffer `r`. The previous version returned the incidental value
of the last assignment; returning the mutated argument follows the `!` convention.
"""
function residual!(r, y, x)
    r[1] = (y[1] + x[1]) * (y[2]^3 - x[2]) + x[3]
    r[2] = sin(y[2] * exp(y[1]) - 1) * x[4]
    return r
end

"""
    solve(x)

Find `y` such that `residual!(r, y, x[1:4])` drives `r` to zero.

The initial guess `[0.1, 1.2]` is built with element type `eltype(x)`, so the
solve runs in whatever number type `x` carries. The Jacobian is requested with
`autodiff = :forward`. Returns the `zero` field of the `nlsolve` result.
"""
function solve(x)
    elT = eltype(x)
    guess = elT[0.1; 1.2]
    result = nlsolve((r, y) -> residual!(r, y, x[1:4]), guess; autodiff = :forward)
    return result.zero
end

"""
    program(x)

Toy pipeline used to test differentiation through a nonlinear solve.

Steps:
1. Form `w = 2.0 * x + x .^ 2`.
2. Solve the residual system with parameters `w`.
3. Return `sol[1] .+ w * sol[2]`, which has the same length as `x`.
"""
function program(x)
    w = 2.0 * x + x .^ 2
    sol = solve(w)
    return sol[1] .+ w * sol[2]
end

# Evaluation point for both Jacobians: [1.0, 2.0, 3.0, 4.0, 5.0].
x = collect(1.0:5.0)

# Forward mode differentiates through the nonlinear solve:
ForwardDiff.jacobian(program, x)
# 5×5 Matrix{Float64}:
#   8.05247   1.94271   -0.95879   2.90746e-25   0.0
#   8.55572  14.3307    -3.0819    8.60063e-26   0.0
#  16.8073   12.2672     4.72726  -2.0063e-25    0.0
#  27.4165   20.0105    -9.87583  13.4769        0.0
#  40.3833   29.4746   -14.5467   -1.01959e-24  16.1723

# Reverse mode fails while constructing the SolverResults:
ReverseDiff.jacobian(program, x)
# ERROR: UndefVarError: rT not defined

The reverse pass fails because the default constructor of `SolverResults` cannot infer the type parameter `rT` from its arguments. One fix is to define a more robust inner constructor that computes the type parameters explicitly, for example:

"""
    SolverResults(method, initial_x, zero, residual_norm, iterations, x_converged,
                  xtol, f_converged, ftol, trace, f_calls, g_calls)

Container for the outcome of a nonlinear solve.

The inner constructor computes the type parameters from its arguments instead of
relying on the default constructor. The default constructor cannot determine `rT`
in some cases, e.g. when the arrays hold ReverseDiff tracked numbers.

- `rT`: promotion of the real parts of both array element types and the types of
  `residual_norm`, `xtol` and `ftol`.
- `T`: `Complex{rT}` if either array's element type is complex, otherwise `rT`.
- `initial_x` and `zero` are converted to `AbstractArray{T}` when necessary. No
  copy is made if they already are. `I` and `Z` are their resulting types.
- `residual_norm`, `xtol` and `ftol` are converted to `rT` by `new`, because the
  fields are declared `::rT`.
"""
mutable struct SolverResults{rT<:Real,T<:Union{rT,Complex{rT}},I<:AbstractArray{T},Z<:AbstractArray{T}}
    method::String
    initial_x::I
    zero::Z
    residual_norm::rT
    iterations::Int
    x_converged::Bool
    xtol::rT
    f_converged::Bool
    ftol::rT
    trace::SolverTrace
    f_calls::Int
    g_calls::Int

    # Inner constructor: the default one cannot infer the parameters in all cases.
    function SolverResults(method, initial_x, zero, residual_norm, iterations, x_converged,
                           xtol, f_converged, ftol, trace, f_calls, g_calls)
        # Common real scalar type of every real-valued input.
        rT = promote_type(real(eltype(initial_x)), real(eltype(zero)),
                          typeof(residual_norm), typeof(xtol), typeof(ftol))

        # Element type shared by both arrays: complex if either input is complex.
        T = promote_type(eltype(initial_x), eltype(zero)) <: Complex ? Complex{rT} : rT

        # Type parameters are invariant, so the arrays must be exactly
        # AbstractArray{T}. An eltype that is merely a subtype of T would make
        # `new` throw. `convert` returns the array itself if it already qualifies.
        initial_x = convert(AbstractArray{T}, initial_x)
        zero = convert(AbstractArray{T}, zero)

        # Each array keeps its own concrete type (Z comes from `zero`, not `initial_x`).
        I = typeof(initial_x)
        Z = typeof(zero)

        return new{rT,T,I,Z}(method, initial_x, zero, residual_norm, iterations,
                             x_converged, xtol, f_converged, ftol, trace, f_calls, g_calls)
    end
end

With this constructor, ReverseDiff derivatives propagate through the solve as expected:

# With the modified SolverResults constructor, reverse mode now succeeds. The result
# matches the ForwardDiff Jacobian; only the round-off-sized entries in column 4
# differ (~1e-25 vs ~1e-27).
ReverseDiff.jacobian(program, x)
# 5×5 Matrix{Float64}:
#   8.05247   1.94271   -0.95879  -4.91066e-28   0.0
#   8.55572  14.3307    -3.0819   -2.04315e-27   0.0
#  16.8073   12.2672     4.72726  -2.00371e-27   0.0
#  27.4165   20.0105    -9.87583  13.4769        0.0
#  40.3833   29.4746   -14.5467    0.0          16.1723

Note that this issue is about propagating derivatives through the iterations of the nonlinear solve, not about defining a custom pullback for it (as discussed in #205).