FluxML / Zygote.jl

21st century AD

Home page: https://fluxml.ai/Zygote.jl/

Confusing error / silent failure with broadcasted functions with type instability

DomCRose opened this issue.

When a broadcasted function is type-unstable for Dual inputs, there is a good chance that the element type of the resulting output will be abstract, which defeats the check at

T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
(thanks to @ToucheSir, who helped debug this). The Duals then leak into the pullback, so an error can surface much later than the point where they originated, e.g. when the output gradient is pulled back onto an input gradient that assumes a non-Dual eltype. This makes the problem confusing to debug. Worse still, in some cases the gradient fails silently, returning either nothing or Duals.
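The widening mechanism can be reproduced without Zygote or ForwardDiff. In this minimal sketch (the function name h is illustrative), Int stands in for the constant 1.0 branch and Float64 stands in for the Dual branch; broadcast starts from the first element's type and widens with promote_typejoin when it meets a different one, ending up with an abstract eltype.

```julia
# Stand-in for the Dual case using only Base types:
# Int plays the role of the constant branch, Float64 the role of x^2 on a Dual.
h(x) = x > 1.0 ? 1 : x^2

y = h.(collect(0.5:0.25:1.5))
T = eltype(y)

# Broadcast widens the container via promote_typejoin(Float64, Int) == Real,
# so any check against a concrete expected type fails, just as
# T <: Union{Dual, Complex{<:Dual}} fails inside Zygote.
println(T)                  # Real
println(isconcretetype(T))  # false
println(T <: Float64)       # false
```

This mirrors the Vector{Real} that appears in the pullback types of the stack trace further down.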

A minimal working example (MWE) of the silent failure, on Julia 1.9.0 in a temporary environment containing only Zygote:

using Zygote
f(x) = x > 1.0 ? 1.0 : x^2  # returns Float64 on one branch, Dual on the other: type-unstable
g(x) = sum(f.(x))
gradient(g, collect(0.5:0.25:1.5)) # (nothing,)

Compare this with the expected behaviour when f is made type-stable by returning one(x):

using Zygote
f(x) = x > 1.0 ? one(x) : x^2  # one(x) has the same type as x, so both branches agree
g(x) = sum(f.(x))
gradient(g, collect(0.5:0.25:1.5)) # ([1.0, 1.5, 2.0, 0.0, 0.0],)
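One way to check in advance whether a function is type-stable for Dual inputs is to ask Julia's inference for its return type. A minimal sketch, assuming ForwardDiff is installed; the names f_unstable, f_stable and D are illustrative. For f_unstable the inferred type should be a non-concrete Union of Float64 and the Dual type, while f_stable should infer to the Dual type itself.

```julia
using ForwardDiff: Dual

f_unstable(x) = x > 1.0 ? 1.0 : x^2     # as in the first MWE
f_stable(x)   = x > 1.0 ? one(x) : x^2  # as in the type-stable version

# The concrete Dual type ForwardDiff produces here (Dual{Nothing, Float64, 1}).
D = typeof(Dual(1.0, 1.0))

# Ask inference what each function returns for a Dual argument.
R_unstable = only(Base.return_types(f_unstable, (D,)))
R_stable   = only(Base.return_types(f_stable, (D,)))

println(R_unstable, "  concrete: ", isconcretetype(R_unstable))
println(R_stable, "  concrete: ", isconcretetype(R_stable))
```

@code_warntype f_unstable(Dual(1.0, 1.0)) gives the same information interactively, with the non-concrete return type highlighted.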

An MWE of the error case, using repeat with the inner keyword as an example of an operation whose pullback does not let the Duals leak through:

using Zygote
f(x) = x > 1.0 ? 1.0 : x^2
g(x) = sum(repeat(x, inner=2) .* f.(repeat(x, inner=2)))
gradient(g, collect(0.5:0.25:1.5))

results in:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing, Float64, 1})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
   @ Base char.jl:50
  ...

Stacktrace:
 [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 1})
   @ Base .\number.jl:7
 [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 1}, i1::Int64)
   @ Base .\array.jl:969
 [3] (::Zygote.var"#626#634"{Int64, Vector{Float64}})(Δ::Vector{ForwardDiff.Dual{Nothing, Float64, 1}})
   @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\lib\array.jl:137
 [4] (::Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}})(Δ::Vector{ForwardDiff.Dual{Nothing, Float64, 1}})
   @ Zygote C:\Users\domin\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:80
 [5] Pullback
   @ .\REPL[29]:1 [inlined]
 [6] (::Zygote.Pullback{Tuple{typeof(g), Vector{Float64}}, Tuple{Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Vector{Real}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Real}}, Tuple{}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Vector{Float64}, Vector{Real}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(f), Vector{Float64}}, Tuple{Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4123#back#1356"{Zygote.var"#1386#1388"}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2881#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1162"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2017#back#200"{typeof(identity)}}}}})(Δ::ForwardDiff.Dual{Nothing, Float64, 1})
   @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface2.jl:0
 [7] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(g), Vector{Float64}}, Tuple{Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Vector{Real}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Real}}, Tuple{}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Vector{Float64}, Vector{Real}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(f), Vector{Float64}}, Tuple{Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4123#back#1356"{Zygote.var"#1386#1388"}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2881#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1162"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2017#back#200"{typeof(identity)}}}}}})(Δ::ForwardDiff.Dual{Nothing, Float64, 1})
   @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface.jl:45
 [8] gradient(f::Function, args::Vector{Float64})
   @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface.jl:97
 [9] top-level scope
   @ REPL[30]:1

This leaves it unclear where the Duals originate, since the forward pass succeeds but returns a Dual where a Float64 is expected:

julia> pullback(g, collect(0.5:0.25:1.5))
(Dual{Nothing}(8.59375,7.25), Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(g), Vector{Float64}}, Tuple{Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Vector{Real}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Real}}, Tuple{}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Vector{Float64}, Vector{Real}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(f), Vector{Float64}}, Tuple{Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4123#back#1356"{Zygote.var"#1386#1388"}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2881#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1162"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2017#back#200"{typeof(identity)}}}}}}(∂(g)))

In the long run this should be fixed properly. In the short term, however, simply raising an error before

T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
when the element type is abstract, warning that the broadcasted function needs to be type-stable for Dual inputs, would at least make this much easier to debug. I'm happy to open a PR adding that.

An error would be better than the present state, e.g. isconcretetype(T) || error(...)
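The proposed guard could look roughly like the following. This is a minimal sketch of the idea, not Zygote's actual code: check_dual_eltype is a hypothetical helper name, and the message wording is illustrative.

```julia
# Hypothetical guard (not Zygote's code): fail loudly when the broadcast output
# has an abstract eltype, instead of silently returning a `nothing` pullback.
function check_dual_eltype(out::AbstractArray)
    T = eltype(out)
    isconcretetype(T) || error(
        "broadcast produced an output with abstract eltype $T; " *
        "make the broadcasted function type-stable for Dual inputs " *
        "(e.g. return one(x) rather than 1.0)")
    return T
end

check_dual_eltype([0.25, 0.5625])  # concrete Float64 eltype: passes

# A Vector{Real}, as produced by the type-unstable broadcast, is rejected.
err = try
    check_dual_eltype(Real[0.25, 1])
    nothing
catch e
    e
end
println(err isa ErrorException)  # true
```

A concrete non-Dual eltype still passes, so the existing fast path that returns a nothing pullback would be unaffected.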

When T is abstract, could it simply assume that there are Duals in there? If not, could it construct an array of zeros instead of returning nothing?

An array of zeros doesn't seem quite right: if I understand correctly, in the first MWE above it would produce incorrect zero gradients?

Assuming Duals seems like it might work, since calling partials on a plain real or complex number simply returns zero anyway, although it might require reworking the branching on complex inputs. However, I don't think the compiler can elide that work when the eltype is abstract, even if no element is actually a Dual, so the fast path should be kept for when the compiler can confirm that the eltype isn't Dual.
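To illustrate why assuming Duals could work, here is a minimal sketch, assuming ForwardDiff is installed. It hand-builds a Vector{Real} that mixes Duals and plain Float64s, shaped like the output the first MWE should produce (values and partials of f at 0.5, 0.75, 1.0, 1.25 and 1.5), and applies the same y1 * partials(o1, i) broadcast used in the pullback. ForwardDiff's partials(x, i) falls back to zero(x) for non-Dual x, so the plain elements contribute zero rather than breaking the computation.

```julia
using ForwardDiff: Dual, partials

# Mixed container like the type-unstable broadcast produces:
# Duals from the x^2 branch, plain Float64 from the constant branch.
out = Real[Dual(0.25, 1.0), Dual(0.5625, 1.5), Dual(1.0, 2.0), 1.0, 1.0]
ȳ = ones(5)  # incoming gradient of sum

# partials(o1, 1) is the stored partial for a Dual and zero for a plain Float64.
grad = broadcast((y1, o1) -> y1 * partials(o1, 1), ȳ, out)
println(grad)  # [1.0, 1.5, 2.0, 0.0, 0.0]
```

The result matches the gradient from the type-stable version of the MWE.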

Perhaps the dispatch on complex outputs could be moved inside the broadcasts in _broadcast_forward and _broadcast_forward_complex, using another internal function? For example, this line

unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out))

could instead split on whether o1 is complex, and for complex o1 do
unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1),i) + imag(y1)*partials(imag(o1), i)), ȳ, out))

only when required, so that the choice is dispatched element-wise. Presumably this would produce the same code when the eltype is uniform?
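The element-wise dispatch suggested above could be sketched as follows, assuming ForwardDiff's partials; _partial_back is a hypothetical internal helper name, not an existing Zygote function. Each element picks the real or complex formula by its own type, so a container with an abstract eltype still works.

```julia
using ForwardDiff: Dual, partials

# Hypothetical helper (not in Zygote): choose the pullback formula per element.
_partial_back(y1, o1::Real, i) = y1 * partials(o1, i)
_partial_back(y1, o1::Complex, i) =
    real(y1) * partials(real(o1), i) + imag(y1) * partials(imag(o1), i)

# Mixed outputs: one real Dual and one complex number with Dual parts.
out = Number[Dual(4.0, 3.0), Complex(Dual(1.0, 2.0), Dual(0.0, 5.0))]
ȳ = Number[1.0, 2.0 + 1.0im]

grads = broadcast((y1, o1) -> _partial_back(y1, o1, 1), ȳ, out)
println(grads)  # [3.0, 9.0]
```

For a plain non-Dual real element the real method returns zero, because partials falls back to zero(x) there.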

Small update: I think I have a fix for this written; I just need to add tests.