JuliaMath / HypergeometricFunctions.jl

A Julia package for calculating hypergeometric functions

Stack overflow from AD for specific argument range

cgeoga opened this issue

Hello---

First, thanks for writing this package. I am amazed at how easy you have made it to work with these functions.

While working with some functions that use 2F1, I've stumbled on a problem: for a specific range of argument values, I get a stack overflow when I try to use automatic differentiation. Here's a minimal reproduction:

```julia
using HypergeometricFunctions, ForwardDiff

fun(arg, v) = _₂F₁(0.5, 0.5*(v+1), 1.5, -(arg^2)/v)
ForwardDiff.derivative(v->fun(2.4, v), 1.825) # error: StackOverflowError
ForwardDiff.derivative(v->fun(2.4, v), 1.8)   # works fine
```

The error that I get is this:

```
ERROR: StackOverflowError:
Stacktrace:
 [1] unsafe_gamma(::ForwardDiff.Dual{ForwardDiff.Tag{var"#5#6",Float64},Float64,1}) at /home/cg/Scratch/HypergeometricFunctions.jl/src/specialfunctions.jl:144 (repeats 79984 times)
```

I've tracked the issue down to line 225 of `./src/specialfunctions.jl`: the call `unsafe_gamma(z+ϵ)` is what overflows the stack. I see that `unsafe_gamma` has a specialization for `x::Dual`, but if my print debugging is correct, dispatch doesn't pick it. Even though `z+ϵ` is a dual number, the call goes to the method on line 144, which is `unsafe_gamma(x::Real) = unsafe_gamma(float(x))`.
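As an alternative to print debugging, Julia can report directly which method a call dispatches to. The sketch below is hypothetical and Base-only: `FakeDual` and `dispatch_demo` are made-up names standing in for a dual number type and `unsafe_gamma`. It shows that a `Real` subtype with no method of its own lands on the generic `x::Real` fallback.

```julia
# Hypothetical stand-in for a dual number type; all that matters is `<: Real`
struct FakeDual <: Real
    value::Float64
    partial::Float64
end

# Hypothetical methods with the same shape as a Float64 kernel plus a
# generic `x::Real` fallback
dispatch_demo(x::Float64) = :float64_method
dispatch_demo(x::Real) = :real_fallback

# `which` reports the method dispatch selects for the given argument types
m = which(dispatch_demo, Tuple{FakeDual})
println(m.sig)

# Interactively, InteractiveUtils offers the same check on a call expression:
# @which dispatch_demo(FakeDual(1.825, 1.0))
```

Running the same kind of check against the real dual type would show whether the `Dual` specialization is even a candidate for it.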

I see that this package uses some of the aliases from DualNumbers.jl, so I tried changing `realpart` to `value` and `dualpart` to `epsilon`, plus a couple of other tweaks like that, but I haven't been able to get dispatch to work. Is this at least the right place to be looking? Do you have any idea what might be causing this?

I think this is a variant of JuliaLang/julia#26552

Apologies for the double post, but after some helpful clarification on Discourse, I see that `ForwardDiff.Dual <: Real`, which is precisely why that stack overflow was happening.
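The loop can be reproduced without either package. In this hypothetical sketch, `LoopDual` and `toy_gamma` stand in for `ForwardDiff.Dual` and `unsafe_gamma`, and having `float` return its argument unchanged is an assumption chosen to match the identical repeated frames in the stack trace. Under that assumption the generic `x::Real` method keeps calling itself until the stack runs out, and a more specific method for the dual type is what breaks the cycle.

```julia
# Hypothetical stand-in for ForwardDiff.Dual{T,Float64,1}: just a Real subtype
struct LoopDual <: Real
    value::Float64
    partial::Float64
end

# Assumption: `float` returns the argument unchanged, as the identical
# repeated frames in the stack trace suggest for a Float64-based Dual
Base.float(x::LoopDual) = x

toy_gamma(x::Float64) = x                 # placeholder for the Float64 kernel
toy_gamma(x::Real) = toy_gamma(float(x))  # generic fallback: convert, then retry

# With no LoopDual-specific method, the fallback keeps calling itself
overflowed = try
    toy_gamma(LoopDual(1.825, 1.0))
    false
catch err
    err isa StackOverflowError
end

# Adding a more specific method makes dispatch pick it and ends the loop
# (the partial is a placeholder; a real rule would apply the chain rule)
toy_gamma(x::LoopDual) = LoopDual(toy_gamma(x.value), x.partial)
```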

If I add an extra method `unsafe_gamma(x::ForwardDiff.Dual)`, dispatch selects it. The following new method seems to work, and its results agree with finite differences:

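```julia
# Extra method so a ForwardDiff.Dual no longer falls through to unsafe_gamma(x::Real)
function unsafe_gamma(z::ForwardDiff.Dual{T,V,N}) where {T,V,N}
  r  = z.value                        # primal value
  du = ForwardDiff.partials(T, z, 1)  # first partial only (a scalar)
  w  = unsafe_gamma(r)                # Γ(r)
  ForwardDiff.Dual{T}(w, w*digamma(r)*du)  # chain rule: Γ'(r) = Γ(r)ψ(r)
end
```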
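The partial in this method comes from the chain rule, since Γ'(x) = Γ(x)ψ(x), where ψ is the digamma function. Writing a for `r`, b for `du`, and noting that Γ(a) is `w`, the first line is the single-partial case (ε² = 0); the second line is the same rule for a dual carrying N partials b₁, …, b_N (with εᵢεⱼ = 0), where every partial is scaled by the same factor Γ(a)ψ(a):

```math
\begin{aligned}
\Gamma(a + b\,\varepsilon) &= \Gamma(a) + \Gamma(a)\,\psi(a)\,b\,\varepsilon, \\
\Gamma\Big(a + \sum_{i=1}^{N} b_i\,\varepsilon_i\Big) &= \Gamma(a) + \Gamma(a)\,\psi(a) \sum_{i=1}^{N} b_i\,\varepsilon_i .
\end{aligned}
```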

So here's the follow-up question: I have no idea how maintaining a real Julia package works. Is there some way to include this code here without adding a dependency on ForwardDiff, which I would guess you don't want to do since you already support DualNumbers?

Also, thanks for the link, @tpapp! That is helpful context.

We should definitely support ForwardDiff.jl.

Okay, well in that case, would it make sense for me to submit a PR (after learning how to do that)?

Yes, definitely. (For open source software, the only way to guarantee that such bugs get fixed is to make PRs. Maintainers are usually much happier to review a PR than to debug the problem themselves. So it's a very useful skill to learn!)

Awesome, thank you. I'm in the process of preparing one now. But I actually caught an issue with my implementation that I'm having some trouble debugging.

This now works:

```julia
ForwardDiff.derivative(v->fun(2.4, v), 1.825)
```

But this does not:

```julia
ForwardDiff.gradient(v->fun(v...), [2.4, 1.825])
```

with the (truncated) error

```
ERROR: MethodError: no method matching _mul_partials(::ForwardDiff.Partials{1,Float64}, ::ForwardDiff.Partials{2,Float64}, ::Float64, ::Float64)
Closest candidates are:
  _mul_partials(::ForwardDiff.Partials{0,A}, ::ForwardDiff.Partials{N,B}, ::Any, ::Any) where {N, A, B} at /home/cg/.julia/packages/ForwardDiff/sdToQ/src/partials.jl:141
  _mul_partials(::ForwardDiff.Partials{N,A}, ::ForwardDiff.Partials{0,B}, ::Any, ::Any) where {N, A, B} at /home/cg/.julia/packages/ForwardDiff/sdToQ/src/partials.jl:142
  _mul_partials(::ForwardDiff.Partials{N,V} where V, ::ForwardDiff.Partials{N,V} where V, ::Any, ::Any) where N at /home/cg/.julia/packages/ForwardDiff/sdToQ/src/partials.jl:119
  ...
```

I am clearly not doing the manual type construction in my `unsafe_gamma` correctly. I'm confused about what the `0` in the type signature `ForwardDiff.Partials{0,A} where {A}` means, since that would suggest taking a derivative with respect to zero arguments (right?). I've tried looking at the ForwardDiff source for `derivative` and `gradient`, but I'm having real trouble identifying where a `1` turns into a `0` in that type somewhere in the `extract_gradient` or `extract_derivative` process. Do you have any suggestions or thoughts on what I am doing incorrectly?
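The error message is consistent with a mismatch in the number of partials rather than anything turning into 0. In the `gradient` call each dual carries two partials (`Partials{2,Float64}`), but the method above rebuilds its result from only the first one via `ForwardDiff.partials(T, z, 1)`, producing a `Partials{1,Float64}`, and `_mul_partials` only has methods for equal `N` or for the `Partials{0}` case (which, as far as I understand ForwardDiff, is a dual carrying no partials at all). The hypothetical, Base-only sketch below mimics this; `ToyDualN`, `lift_all`, `lift_first`, and `toy_mul` are made-up names, not ForwardDiff API.

```julia
# Hypothetical toy dual number carrying N partials (not ForwardDiff's type)
struct ToyDualN{N}
    value::Float64
    partials::NTuple{N,Float64}
end

# Correct lift of a scalar function f with derivative fp:
# every one of the N partials is scaled by fp(value), so N is preserved
lift_all(f, fp, d::ToyDualN{N}) where {N} =
    ToyDualN{N}(f(d.value), map(p -> fp(d.value) * p, d.partials))

# Buggy lift mirroring `partials(T, z, 1)`: keeps only the first partial,
# so the result always has N = 1
lift_first(f, fp, d::ToyDualN) =
    ToyDualN{1}(f(d.value), (fp(d.value) * d.partials[1],))

# Product rule, defined only when both operands carry the same N
# (analogous to the equal-N `_mul_partials` method in the error message)
toy_mul(a::ToyDualN{N}, b::ToyDualN{N}) where {N} =
    ToyDualN{N}(a.value * b.value,
                map((pa, pb) -> pa * b.value + a.value * pb, a.partials, b.partials))

# Seeds for a two-argument gradient: one partial slot per input
x = ToyDualN{2}(2.4, (1.0, 0.0))
v = ToyDualN{2}(1.825, (0.0, 1.0))

good = lift_all(exp, exp, v)    # still a ToyDualN{2}
bad  = lift_first(exp, exp, v)  # now a ToyDualN{1}

toy_mul(x, good)                # works
# toy_mul(x, bad)               # would throw a MethodError (N = 2 vs N = 1)
```

By analogy, the ForwardDiff method would presumably need to scale all partials at once, e.g. something like `ForwardDiff.Dual{T}(w, w*digamma(r)*ForwardDiff.partials(z))`; that variant has not been tested here.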