Non-Concrete Error when running on an ODE system

j-murphy-slingshot opened this issue

When differentiating an ODE solve with TaylorDiff, the solver throws a non-concrete element type error. I've included an MWE of a Lotka-Volterra system demonstrating the behavior. I get the same error with an in-place solve.

using OrdinaryDiffEq, ForwardDiff, FiniteDifferences, TaylorDiff

function lv_ode(u, p, t)
    x, y = u
    α, β, δ, γ = p

    return [α*x - β*x*y;
            δ*x*y - γ*y]
end

u0_lv = [10.0; 100.0]
p_lv = [.3; .015; .015; .7]
tspan_lv = (0.0, 100.0)

lv_prob = ODEProblem(lv_ode, u0_lv, tspan_lv, p_lv)

function taylordiff_lv(x, prob)
    _prob = remake(prob, u0=x)
    return Array(solve(_prob, Vern9(), abstol=1E-13, reltol=1E-13, save_start=false, save_everystep=false))[:, end]
end

# Checking other Derivative Methods
ad_jac = ForwardDiff.jacobian((x) -> taylordiff_lv(x, lv_prob), u0_lv)
fd_jac = FiniteDifferences.jacobian(central_fdm(5, 1), (x) -> taylordiff_lv(x, lv_prob), u0_lv)[1]

td_jac = derivative((x)->taylordiff_lv(x, lv_prob), u0_lv, [1.0; 0.0], 1)

ERROR: Non-concrete element type inside of an Array detected.
Arrays with non-concrete element types, such as
Array{Union{Float32,Float64}}, are not supported by the
differential equation solvers. Anyways, this is bad for
performance so you don't want to be doing this!

If this was a mistake, promote the element types to be
all the same. If this was intentional, for example,
using Unitful.jl with different unit values, then use
an array type which has fast broadcast support for
heterogeneous values such as the ArrayPartition
from RecursiveArrayTools.jl.
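For reference, the condition this error message describes can be reproduced with plain Base arrays. This is only a small sketch of what "non-concrete element type" and "promote the element types" mean in Julia; the variable names are illustrative, and it does not call the DiffEq check itself.

```julia
# A Vector whose declared element type is a Union is not concretely typed.
mixed = Union{Float32, Float64}[1.0f0, 2.0]
@show eltype(mixed)                      # Union{Float32, Float64}
@show isconcretetype(eltype(mixed))      # false

# Promoting to a common element type makes it concrete again.
common = promote_type(Float32, Float64)  # Float64
promoted = convert(Vector{common}, mixed)
@show eltype(promoted)                   # Float64
@show isconcretetype(eltype(promoted))   # true
```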

I tried the fix suggested in the error message, wrapping the initial condition in an ArrayPartition:
_prob = remake(prob, u0=ArrayPartition(x))  # ArrayPartition comes from RecursiveArrayTools.jl

and also converting to an in-place solve:

function lv_ode!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p

    # write the derivatives into du in place
    du[1] = α*x - β*x*y
    du[2] = δ*x*y - γ*y
    return nothing
end

but got the following error instead:

ERROR: MethodError: no method matching zero(::Type{Any})

@tansongchen do you have a silent conversion to Float32 somewhere? That seems pretty critical.

This has nothing to do with Float32. The problem is that DifferentialEquations.jl tries to determine the unitless eltype of u0, using a function defined here: https://github.com/SciML/RecursiveArrayTools.jl/blob/cba251a986e85c3bd677760127250817009e8900/src/utils.jl#L237

recursive_unitless_eltype(a) = recursive_unitless_eltype(eltype(a))
recursive_unitless_eltype(a::Type{Any}) = Any

So when prob.u0 is passed to this function, it returns Any.

@ChrisRackauckas Is this an expected behavior of RecursiveArrayTools.recursive_unitless_eltype?

The array should be concretely typed, so eltype(a) shouldn't be Any. If all arrays of TaylorScalars are Any-typed, then performance is screwed anyway, so that should be fixed.

Arrays of TaylorScalar are concretely typed:

>>> _prob.u0
TaylorScalar{Float64, 2}[TaylorScalar{Float64, 2}((10.0, 1.0)), TaylorScalar{Float64, 2}((100.0, 0.0))]

>>> eltype(_prob.u0)
TaylorScalar{Float64, 2}

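However, I haven't defined what eltype(TaylorScalar{Float64, 2}) should be, so it falls back to Base's default for types:

>>> eltype(TaylorScalar{Float64, 2})
Any

so the call chain becomes

recursive_unitless_eltype(_prob.u0)
= recursive_unitless_eltype(TaylorScalar{Float64, 2})
= recursive_unitless_eltype(eltype(TaylorScalar{Float64, 2}))
= recursive_unitless_eltype(Any)
= Any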
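A self-contained sketch of that fallback chain, assuming a hypothetical struct `PlainScalar` that stands in for TaylorScalar v0.2.0 (not `<: Number`), and a hypothetical `toy_unitless_eltype` that mirrors only the two RecursiveArrayTools methods quoted above, not the package's full method table:

```julia
# Hypothetical stand-in for TaylorScalar v0.2.0: a struct that is NOT <: Number.
struct PlainScalar
    value::NTuple{2, Float64}
end

# Mirror of the two quoted methods, renamed so no package function is redefined.
toy_unitless_eltype(a) = toy_unitless_eltype(eltype(a))
toy_unitless_eltype(::Type{Any}) = Any

u0 = [PlainScalar((10.0, 1.0)), PlainScalar((100.0, 0.0))]

@show eltype(u0)                  # PlainScalar: the array itself is concrete
@show isconcretetype(eltype(u0))  # true
@show eltype(PlainScalar)         # Any: Base's fallback eltype(::Type) = Any
@show toy_unitless_eltype(u0)     # Any
```

The array is concretely typed; the `Any` only appears one step later, when `eltype` is asked about the element type itself.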

Starting with v0.2.0, TaylorScalar is no longer a subtype of Number, so it cannot make use of https://github.com/SciML/RecursiveArrayTools.jl/blob/cba251a986e85c3bd677760127250817009e8900/src/utils.jl#L246

@ChrisRackauckas the real problem here is that I have no way to make this correct unless I make a PR to RecursiveArrayTools.jl adding recursive_unitless_eltype(a::Type{T}) where {T <: TaylorScalar} = T

eltype(TaylorScalar{Float64, 2}) == TaylorScalar{Float64, 2} should be true. That's a property that numbers generally have.

But if I defined Base.eltype(::Type{T}) where {T<:TaylorScalar} = T, then recursive_unitless_eltype would go into an infinite loop, since there is no halting branch for TaylorScalar...
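A sketch of both points under the same assumptions: a hypothetical `SelfEltypeScalar` plays the role of TaylorScalar, and a hypothetical `toy_unitless` mirrors the two quoted methods plus the halting method suggested above. With the eltype override alone, `toy_unitless(SelfEltypeScalar)` would keep calling itself; the extra method is what stops the recursion.

```julia
# Hypothetical scalar type that is not <: Number (stand-in for TaylorScalar).
struct SelfEltypeScalar
    value::NTuple{2, Float64}
end

# Numbers are their own eltype; give the toy type the same property.
Base.eltype(::Type{SelfEltypeScalar}) = SelfEltypeScalar

# Mirror of the recursive definition. Without the third method below,
# toy_unitless(SelfEltypeScalar) would recurse forever, because
# eltype(SelfEltypeScalar) === SelfEltypeScalar and nothing stops there.
toy_unitless(a) = toy_unitless(eltype(a))
toy_unitless(::Type{Any}) = Any

# Halting branch: stop as soon as the scalar type itself is reached.
toy_unitless(::Type{T}) where {T<:SelfEltypeScalar} = T

@show eltype(Float64) === Float64                    # true: the property numbers have
@show toy_unitless([SelfEltypeScalar((1.0, 0.0))])   # SelfEltypeScalar
```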

Oh yeah, because it's not a Number type. So we need an extension for this.

Why is it not <: Number? That doesn't seem to make sense.

td_jac = TaylorDiff.derivative((x)->taylordiff_lv(x, lv_prob), u0_lv, [1.0; 0.0], 1)

Base.real(a::TaylorScalar) = a 
DiffEqBase.UNITLESS_ABS2(x::TaylorScalar) = abs2(x)
Base.abs2(x::TaylorScalar) = abs(x^2)
Base.abs(x::TaylorScalar) = x^2 / x
DiffEqBase.value(x::TaylorScalar) = x.value[1]
Base.eltype(a::TaylorScalar) = a 
Base.iterate(a::TaylorScalar) = a
DiffEqBase.recursive_length(v::Vector{<:TaylorScalar}) = length(v)

TaylorScalar{Float64, 2}(x::Int) = TaylorScalar{Float64, 2}(float(x))

function DiffEqBase.UNITLESS_ABS2(x::AbstractArray{<:TaylorScalar})
    sum(abs2, x)
end

almost gets it working. But it really just shows that it's a number, so it needs to subtype Number. Otherwise it needs hundreds of definitions that duplicate what Number already provides.
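To make the hundreds-of-definitions point concrete: Base defines generic methods on `Number` that a type inherits just by subtyping it. This sketch compares two hypothetical structs, one `<: Number` and one not, and checks only a handful of those inherited methods (note that Base's scalar `iterate` returns `(x, nothing)`).

```julia
# Hypothetical minimal scalar types: one subtypes Number, one does not.
struct NumLike <: Number
    x::Float64
end
struct NotNumLike
    x::Float64
end

a = NumLike(2.0)
b = NotNumLike(2.0)

# Inherited from Base's generic Number methods, with no extra definitions:
@show eltype(NumLike)                # NumLike
@show length(a)                      # 1
@show size(a)                        # ()
@show ndims(a)                       # 0
@show iterate(a) === (a, nothing)    # true: a number iterates once, over itself

# The non-Number type gets none of this:
@show eltype(NotNumLike)             # Any
@show applicable(length, b)          # false
```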

With TaylorScalar <: Number, I'm getting close:

using OrdinaryDiffEq, TaylorDiff

function lv_ode(u, p, t)
    x, y = u
    α, β, δ, γ = p

    return [α*x - β*x*y;
            δ*x*y - γ*y]
end

u0_lv = [10.0; 100.0]
p_lv = [.3; .015; .015; .7]
tspan_lv = (0.0, 100.0)

lv_prob = ODEProblem(lv_ode, u0_lv, tspan_lv, p_lv)

function taylordiff_lv(x, prob)
    _prob = remake(prob, u0=x)
    return Array(solve(_prob, Vern9(), abstol=1E-13, reltol=1E-13, save_start=false, save_everystep=false, dt = 0.1))[:, end]
end

Base.real(a::TaylorScalar) = a 
Base.abs(x::TaylorScalar) = x^2 / x
Base.convert(::Type{TaylorScalar{T,N}}, x::Union{Int, Rational, AbstractFloat}) where {T,N} = TaylorScalar{T,N}(float(x))
TaylorScalar{T, N}(x::Union{Int, Rational}) where {T,N} = TaylorScalar{T,N}(float(x))
Base.AbstractFloat(x::TaylorScalar) = x
TaylorScalar{Float64, N}(x::BigFloat) where {N} = TaylorScalar{Float64,N}(Float64(x))
Base.isless(x::TaylorScalar, y::TaylorScalar) = x.value[1] < y.value[1]
Base.isless(x::TaylorScalar, y::Float64) = x.value[1] < y
Base.isless(x::Float64, y::TaylorScalar) = x < y.value[1]
DiffEqBase.value(x::TaylorScalar) = x.value[1]

td_jac = TaylorDiff.derivative((x)->taylordiff_lv(x, lv_prob), u0_lv, [1.0; 0.0], 1)

@tansongchen can we chat about this later today?

To <: Number or not to <: Number is indeed at the heart of the high-level design of this package...

Initially I chose <: Number because, as you said, it is very intuitive. However, I realized that Taylor polynomials don't behave like numbers when it comes to AD (specifically, when using AD to differentiate programs that contain TaylorScalar). For example, ChainRules.jl has rules like

rrule(::typeof(*), a::Number, b::Number) = a * b, Ω -> (NoTangent(), b * Ω, a * Ω)

but this is incorrect when back-propagating gradients through *(a::TaylorScalar, b::TaylorScalar).
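One way to see the problem, as a hedged sketch rather than TaylorDiff's or ChainRules' actual code: represent Taylor coefficients as plain vectors, assume the cotangent Ω has the same shape, and compare reusing the scalar product rule (multiply Ω by b with the same truncated product) against the true adjoint of the truncated product. For a single coefficient (an ordinary number) the two agree; beyond that they do not. `tmul`, `number_style_grad_a` and `correct_grad_a` are illustrative names.

```julia
# Truncated polynomial product: c[k] = sum over i of a[i] * b[k - i + 1]
# (1-based coefficients, result truncated to the same length).
function tmul(a::Vector{Float64}, b::Vector{Float64})
    n = length(a)
    c = zeros(n)
    for k in 1:n, i in 1:k
        c[k] += a[i] * b[k - i + 1]
    end
    return c
end

# "Number-style" pullback for a: reuse the scalar rule b * Ω,
# multiplying the cotangent by b as if both were plain numbers.
number_style_grad_a(b, Ω) = tmul(b, Ω)

# Correct pullback for a when L = sum(Ω .* tmul(a, b)):
# dL/da[i] = sum over k >= i of Ω[k] * b[k - i + 1]  (a correlation, not a product).
function correct_grad_a(b::Vector{Float64}, Ω::Vector{Float64})
    n = length(b)
    g = zeros(n)
    for i in 1:n, k in i:n
        g[i] += Ω[k] * b[k - i + 1]
    end
    return g
end

a = [1.0, 2.0]
b = [3.0, 4.0]
Ω = [1.0, 1.0]

@show number_style_grad_a(b, Ω)   # [3.0, 7.0]
@show correct_grad_a(b, Ω)        # [7.0, 3.0]
```

With these inputs the loss is L = 7a[1] + 3a[2], so the correct gradient is [7.0, 3.0]; the Number-style rule puts the coefficients in the wrong positions.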

So if it were a subtype of Number (as it used to be), I would need to write a ton of @opt_outs to make cases like PINNs work. To make things worse, many of the rules that need to be opted out don't exist in ChainRules.jl; they live in specific packages like Flux.jl.

Therefore in v0.2.0 I removed <: Number. But now it seems I have to add it back.

At this point I think we've hit a fundamental limitation of the operator-overloading approach: in Julia, a language where even the most intricate functionality is implemented through dispatch, overloading methods on a custom type to support full-language AD is unlikely to be sound. Maybe we should move to source-code transformation, or something in between?

I would like to chat about these design problems. But right now I need to write up what I have already done as a thesis so that I can graduate on time. 😂 Let's do that in two weeks?

For any normal operation it should act like a number; it's just that you don't want to inherit the AD rules defined for Number. But I'm pretty sure you don't want to inherit any AD rules, for that matter, and instead need to trace through the operations to reconstruct their Taylor series approximation.
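As a hedged illustration of acting like a number for ordinary operations while propagating the Taylor series through each one, here is a hypothetical second-order type `Taylor2 <: Real`. It is not TaylorDiff's TaylorScalar and defines only `+`, `*` and promotion from plain integers and floats; generic numeric code runs unchanged, and each overloaded operation carries the truncated coefficients along.

```julia
# Hypothetical second-order Taylor number: v0 + v1*t + v2*t^2, truncated after t^2.
struct Taylor2 <: Real
    v0::Float64
    v1::Float64
    v2::Float64
end

# Plain numbers become constants (zero higher-order coefficients).
Taylor2(x::Union{Integer, AbstractFloat}) = Taylor2(Float64(x), 0.0, 0.0)
Base.promote_rule(::Type{Taylor2}, ::Type{<:Real}) = Taylor2

# Each operation propagates the truncated Taylor coefficients.
Base.:+(a::Taylor2, b::Taylor2) = Taylor2(a.v0 + b.v0, a.v1 + b.v1, a.v2 + b.v2)
Base.:*(a::Taylor2, b::Taylor2) = Taylor2(a.v0 * b.v0,
                                         a.v0 * b.v1 + a.v1 * b.v0,
                                         a.v0 * b.v2 + a.v1 * b.v1 + a.v2 * b.v0)

f(x) = x * x * x + 2 * x   # ordinary generic Julia code

# Seed x = 2 + t: the coefficients of f(2 + t) are f(2), f'(2), f''(2)/2.
y = f(Taylor2(2.0, 1.0, 0.0))
@show (y.v0, y.v1, 2 * y.v2)   # (12.0, 14.0, 12.0)
```

For f(x) = x³ + 2x this recovers f(2) = 12, f′(2) = 14 and f″(2) = 12. The `2 * x` term works only because `Taylor2 <: Real` lets Base's generic promotion and `convert` methods apply.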

But yes, we can chat about this in two weeks.

I'm a bit out of my depth, but I'll ask some questions that may help others, fully appreciating that you've tabled this for more pressing matters:

  1. ForwardDiff.jl seems to allow its dual numbers to subtype Number. Indeed, this package used to as well. Do you have a specific example of an operation that will fail (or silently give a wrong answer) if TaylorScalar does? Or is it only a problem when TaylorDiff.jl is used as a layer inside another AD package?
  2. Does ForwardDiff.jl have a failure mode related to the above? (I'm a little worried now.)

Should be fine with the latest version.