JuliaDiff / Diffractor.jl

Next-generation AD

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

backing error when nesting forward-mode

oxinabox opened this issue · comments

Here is a trivial example of nested forward mode:

using Diffractor
# Scalar function being differentiated (d/dx of 3x^2 is 6x).
f(x) = 3x^2
# Inner forward-mode pass: wraps `f` in a ZeroBundle (no tangent for the function
# itself) and pushes a first-order Taylor bundle with primal `x` and seed tangent
# 1.0 through it. Differentiating `g` with ∂☆ therefore nests forward mode.
g(x) = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(f), Diffractor.TaylorBundle{1}(x, (1.0,)))
# Outer forward-mode pass over `g` at x = 10 (an Int) with seed tangent 1.0.
# Because `g` itself calls ∂☆, this is nested forward mode; it raises an
# ArgumentError from ChainRulesCore's `Tangent` constructor ("should be backed by
# a NamedTuple type") inside Diffractor's `first_partial`.
Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(g), Diffractor.TaylorBundle{1}(10, (1.0,)))

Here is the output from running that:

julia> Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(g), Diffractor.TaylorBundle{1}(10, (1.0,)))
ERROR: ArgumentError: Tangent for the primal Diffractor.UniformTangent{ChainRulesCore.ZeroTangent} should be backed by a NamedTuple type, not by Tuple{ChainRulesCore.ZeroTangent}.
Stacktrace:
  [1] _backing_error(P::Type, G::Type, E::Type)
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/tangent.jl:62
  [2] ChainRulesCore.Tangent{Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}, Tuple{ChainRulesCore.ZeroTangent}}(backing::Tuple{ChainRulesCore.ZeroTangent})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/tangent.jl:36
  [3] (ChainRulesCore.Tangent{Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}})(args::ChainRulesCore.ZeroTangent)
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/tangent.jl:48
  [4] partial(x::Diffractor.CompositeBundle{1, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}, Tuple{Diffractor.TangentBundle{1, ChainRulesCore.ZeroTangent, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}}}, i::Int64)
    @ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:7
  [5] first_partial(x::Diffractor.CompositeBundle{1, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}, Tuple{Diffractor.TangentBundle{1, ChainRulesCore.ZeroTangent, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}}})
    @ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:11
  [6] map
    @ ./tuple.jl:291 [inlined]
  [7] map(f::typeof(Diffractor.first_partial), t::Tuple{Diffractor.TangentBundle{1, typeof(Diffractor._TangentBundle), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, Diffractor.TangentBundle{1, Val{1}, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, Diffractor.TangentBundle{1, typeof(f), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, Diffractor.CompositeBundle{1, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}, Tuple{Diffractor.TangentBundle{1, ChainRulesCore.ZeroTangent, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}}}})
    @ Base ./tuple.jl:292
  [8] (::Diffractor.∂☆internal{1})(::Diffractor.TangentBundle{1, typeof(Diffractor._TangentBundle), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
    @ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:110
  [9] (::Diffractor.∂☆{1})(::Diffractor.TangentBundle{1, typeof(Diffractor._TangentBundle), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
    @ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:139
 [10] TangentBundle
    @ ~/.julia/packages/Diffractor/HBYjZ/src/tangent.jl:251 [inlined]
 [11] (::Diffractor.∂☆recurse{1})(::Diffractor.TangentBundle{1, Type{Diffractor.TangentBundle{1, B, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}} where B}, Diffractor.UniformTangent{ChainRulesCore.NoTangent}}, ::Diffractor.TangentBundle{1, typeof(f), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}})
    @ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/recurse_fwd.jl:0
 [12] (::Diffractor.∂☆internal{1})(::Diffractor.TangentBundle{1, Type{Diffractor.TangentBundle{1, B, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}} where B}, Diffractor.UniformTangent{ChainRulesCore.NoTangent}}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
    @ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:112
 [13] (::Diffractor.∂☆{1})(::Diffractor.TangentBundle{1, Type{Diffractor.TangentBundle{1, B, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}} where B}, Diffractor.UniformTangent{ChainRulesCore.NoTangent}}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
    @ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:139
 [14] g
    @ ~/JuliaEnvs/DAECompiler.jl/scratch/jac_scratch.jl:54 [inlined]
 [15] (::Diffractor.∂☆recurse{1})(::Diffractor.TangentBundle{1, typeof(g), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, ::Diffractor.TangentBundle{1, Int64, Diffractor.TaylorTangent{Tuple{Float64}}})
    @ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/recurse_fwd.jl:0
 [16] (::Diffractor.∂☆internal{1})(::Diffractor.TangentBundle{1, typeof(g), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
    @ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:112
 [17] (::Diffractor.∂☆{1})(::Diffractor.TangentBundle{1, typeof(g), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
    @ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:139
 [18] top-level scope
    @ ~/JuliaEnvs/DAECompiler.jl/scratch/jac_scratch.jl:55

I am not sure how we should handle this.
I suspect it is possible to rewrite some (maybe all?) cases to use ∂☆{2} instead,
but I would need to think about it more.

We definitely shouldn't just be erroring, though.

I believe the cause is that `CompositeBundle` only works (and is only tested) as a representation of the tangent bundle for `Tuple`s.
Here it is being asked to represent the tangent of a struct (in particular, the `TangentBundle` struct itself).
When you pull the partial out for a struct, it constructs a `ChainRulesCore.Tangent{P}` backed by a `Tuple`, but `Tangent` requires the tangent of a struct primal to be backed by a `NamedTuple`; that mismatch is what produces the error.

Here is a simpler case that triggers it without nesting AD:

# Minimal two-field struct (fields deliberately left untyped, i.e. `Any`) used to
# show that taking a partial of a struct-valued forward-mode result fails.
struct Foo
    x
    y
end
# Returns a struct rather than a Tuple, so forward mode yields a
# `CompositeBundle` whose primal type is `Foo`.
foo_dub(x) = Foo(x, 2x)
# Single (non-nested) forward-mode pass at x = 10.0 with seed tangent 1.0.
dz = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(foo_dub), Diffractor.TaylorBundle{1}(10.0, (1.0,)))
# Extracting the first-order partial calls `Tangent{Foo}` with positional
# (Tuple-backed) components, which ChainRulesCore rejects with an ArgumentError
# because the tangent of a struct must be backed by a NamedTuple.
Diffractor.first_partial(dz)

which errors with:

julia> Diffractor.first_partial(dz)
ERROR: ArgumentError: Tangent for the primal Foo should be backed by a NamedTuple type, not by Tuple{Float64, Float64}.
Stacktrace:
 [1] _backing_error(P::Type, G::Type, E::Type)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/tangent.jl:62
 [2] ChainRulesCore.Tangent{Foo, Tuple{Float64, Float64}}(backing::Tuple{Float64, Float64})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/tangent.jl:36
 [3] (ChainRulesCore.Tangent{Foo})(::Float64, ::Vararg{Float64})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/tangent.jl:48
 [4] partial(x::Diffractor.CompositeBundle{1, Foo, Tuple{Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}}, Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}}}}, i::Int64)
   @ Diffractor ~/JuliaEnvs/DAECompiler.jl/dev/Diffractor/src/stage1/forward.jl:7
 [5] first_partial(x::Diffractor.CompositeBundle{1, Foo, Tuple{Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}}, Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}}}})
   @ Diffractor ~/JuliaEnvs/DAECompiler.jl/dev/Diffractor/src/stage1/forward.jl:11
 [6] top-level scope
   @ ~/JuliaEnvs/DAECompiler.jl/scratch/jac_scratch.jl:65

Fixed by #137