odow / MathOptSymbolicAD.jl

User-defined functions

odow opened this issue · comments

Symbolics has some support, but ideally we'd be able to keep the analytic derivative and then replace the Derivative in the expression: https://symbolics.juliasymbolics.org/dev/manual/derivatives/#Adding-Analytical-Derivatives-1

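using Symbolics

# User-defined function and its analytic first and second derivatives.
foo(x) = log(x)
foo′(x) = 1 / x
foo′′(x) = -1 / x^2

# Tell Symbolics how to differentiate `foo` with respect to its first argument.
# Note that this inlines `1 / x` instead of keeping a call to `foo′`.
function Symbolics.derivative(
    ::typeof(foo),
    args::NTuple{N,Any},
    ::Val{1},
) where {N}
    return 1 / args[1]
end

@variables(x)
# Register `foo` so that it is kept as an opaque call rather than traced.
@register_symbolic foo(x)

f = x + foo(x)
xs = [x]

# Symbolic gradient and sparse Hessian, each compiled to a Julia function.
∇f = Symbolics.gradient(f, xs)
∇²f = Symbolics.sparsejacobian(∇f, xs)
f_f = build_function(f, xs; expression = Val{false})
f_∇f, f_∇f! = build_function(∇f, xs; expression = Val{false})
f_∇²f, f_∇²f! = build_function(∇²f, xs; expression = Val{false})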

Hello Oscar, and thanks for working on this cool package! I'd like to hear your thoughts on user-defined functions that take and return vectors. The JuMP docs currently outline a workaround, and Symbolics.jl is also struggling a bit in this area at the moment. Do you have anything in the pipeline that would move this particular needle?

No ideas or plans for this. Vector-valued nonlinear is something we're thinking about, but the lack of a good AD system is a blocker. To begin with, we'll probably implement vector-valued nonlinear in JuMP as syntactic sugar that scalarizes everything before passing it on.
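For illustration only, here is a minimal sketch of what "scalarizing" a vector-valued function could mean in plain Julia. It is not JuMP's API or a planned implementation: the function `g`, the helper `scalarize`, and the counter `CALLS` are all hypothetical names made up for this example. Each output component becomes its own scalar-valued function of scalar arguments, and the counter shows that evaluating every component this way calls `g` once per component.

```julia
# Hypothetical counter tracking how many times `g` is evaluated.
CALLS = Ref(0)

# Hypothetical vector-valued user function: takes a vector, returns a vector.
function g(x::AbstractVector)
    CALLS[] += 1
    return [x[1]^2 + x[2], sin(x[1]) * x[2]]
end

# Scalarize: build one scalar-valued closure per output component.
# Each closure takes scalar arguments and returns the i-th output of `f`.
scalarize(f, n_out) = [(x...) -> f(collect(x))[i] for i in 1:n_out]

g_scalar = scalarize(g, 2)

# Evaluate every component at the point (1.0, 2.0).
y = [gi(1.0, 2.0) for gi in g_scalar]

println("components: ", y)
println("calls to g: ", CALLS[])
```

In this naive form, the full vector-valued function is re-evaluated for every component, which is the kind of redundancy a cache or memoization layer would be needed to avoid.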

> To begin with, we'll probably implement vector-valued nonlinear in JuMP as syntactic sugar that scalarizes everything before passing it on.

That sounds like a good first step. Will this lead to redundant calls to nonlinear functions, like what's currently suggested in the docs, or will the nonlinear functions be traced through with scalar symbols, similar to how it would work with scalar symbols from Symbolics.jl?

> Will this lead to redundant calls to nonlinear functions, like what's currently suggested in the docs, or will the nonlinear functions be traced through with scalar symbols, similar to how it would work with scalar symbols from Symbolics.jl?

We haven't gotten to thinking about how this will be implemented, but nonlinear functions won't be traced with scalar symbols.