Confusing error / silent failure with broadcasted functions with type instability
DomCRose opened this issue · comments
When a broadcasted function is type-unstable for Dual inputs, there is a good chance that the element type of the resulting output will be abstract, which causes the logic at Zygote.jl/src/lib/broadcast.jl, line 284 (commit 2f49370), to fail.
An MWE of the silent failure on Julia 1.9.0, in a temporary environment with only Zygote installed:
using Zygote
f(x) = x > 1.0 ? 1.0 : x^2
g(x) = sum(f.(x))
gradient(g, collect(0.5:0.25:1.5)) # (nothing,)
In contrast with the expected behaviour of:
using Zygote
f(x) = x > 1.0 ? one(x) : x^2
g(x) = sum(f.(x))
gradient(g, collect(0.5:0.25:1.5)) # ([1.0, 1.5, 2.0, 0.0, 0.0],)
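The difference between the two definitions of f can be seen by broadcasting over Duals by hand. This is only a minimal sketch: it assumes ForwardDiff is available in the active environment (the temporary environment above adds only Zygote), and the names f_unstable and f_stable are labels chosen for this illustration.

```julia
using ForwardDiff: Dual

# Same two definitions as in the MWEs, renamed so both can coexist.
f_unstable(x) = x > 1.0 ? 1.0 : x^2      # returns Float64 on one branch, Dual on the other
f_stable(x)   = x > 1.0 ? one(x) : x^2   # returns the input's type on both branches

# Seed each input with a unit partial, roughly what a forward-mode broadcast would do.
xs = Dual.(collect(0.5:0.25:1.5), 1.0)

T_unstable = eltype(f_unstable.(xs))  # Real: the typejoin of Dual and Float64
T_stable   = eltype(f_stable.(xs))    # a concrete Dual type

println(T_unstable, " concrete? ", isconcretetype(T_unstable))
println(T_stable, " concrete? ", isconcretetype(T_stable))
```

The abstract element type in the first case matches the Vector{Real} that appears in the pullback types of the stack trace further down.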
An MWE of the error, using repeat with the inner keyword as an example of a function that doesn't allow the Dual to leak:
using Zygote
f(x) = x > 1.0 ? 1.0 : x^2
g(x) = sum(repeat(x, inner=2) .* f.(repeat(x, inner=2)))
gradient(g, collect(0.5:0.25:1.5))
results in:
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing, Float64, 1})
Closest candidates are:
(::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
@ Base rounding.jl:207
(::Type{T})(::T) where T<:Number
@ Core boot.jl:792
(::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
@ Base char.jl:50
...
Stacktrace:
[1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 1})
@ Base .\number.jl:7
[2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 1}, i1::Int64)
@ Base .\array.jl:969
[3] (::Zygote.var"#626#634"{Int64, Vector{Float64}})(Δ::Vector{ForwardDiff.Dual{Nothing, Float64, 1}})
@ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\lib\array.jl:137
[4] (::Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}})(Δ::Vector{ForwardDiff.Dual{Nothing, Float64, 1}})
@ Zygote C:\Users\domin\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:80
[5] Pullback
@ .\REPL[29]:1 [inlined]
[6] (::Zygote.Pullback{Tuple{typeof(g), Vector{Float64}}, Tuple{Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Vector{Real}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Real}}, Tuple{}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Vector{Float64}, Vector{Real}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(f), Vector{Float64}}, Tuple{Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4123#back#1356"{Zygote.var"#1386#1388"}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2881#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1162"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2017#back#200"{typeof(identity)}}}}})(Δ::ForwardDiff.Dual{Nothing, Float64, 1})
@ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface2.jl:0
[7] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(g), Vector{Float64}}, Tuple{Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Vector{Real}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Real}}, Tuple{}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Vector{Float64}, Vector{Real}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(f), Vector{Float64}}, Tuple{Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4123#back#1356"{Zygote.var"#1386#1388"}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2881#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1162"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2017#back#200"{typeof(identity)}}}}}})(Δ::ForwardDiff.Dual{Nothing, Float64, 1})
@ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface.jl:45
[8] gradient(f::Function, args::Vector{Float64})
@ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface.jl:97
[9] top-level scope
@ REPL[30]:1
which leaves it unclear where the Duals originate from, since the forward pass succeeds with incorrect outputs:
julia> pullback(g, collect(0.5:0.25:1.5))
(Dual{Nothing}(8.59375,7.25), Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(g), Vector{Float64}}, Tuple{Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Vector{Real}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Real}}, Tuple{}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Vector{Float64}, Vector{Real}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(f), Vector{Float64}}, Tuple{Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4123#back#1356"{Zygote.var"#1386#1388"}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2881#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1162"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2017#back#200"{typeof(identity)}}}}}}(∂(g)))
In the long run it would be better to fix this properly. In the short term, however, simply adding an error before Zygote.jl/src/lib/broadcast.jl, line 284 (commit 2f49370), would help.
An error would be better than the present state, e.g. isconcretetype(T) || error(...)
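The suggested guard could look roughly like the following. This is a sketch only: check_concrete_eltype is a hypothetical helper written for this illustration, not a function in Zygote, and the error message is an assumption about what would be useful to report.

```julia
# Hypothetical helper (not part of Zygote): refuse to continue when the
# broadcast output has an abstract element type, instead of silently
# returning `nothing` as the gradient.
function check_concrete_eltype(out::AbstractArray)
    T = eltype(out)
    isconcretetype(T) || error(
        "broadcast output has abstract eltype $T; " *
        "the broadcasted function is probably type-unstable")
    return out
end

check_concrete_eltype([1.0, 2.0])        # passes: Float64 is concrete

try
    check_concrete_eltype(Real[1.0, 2])  # throws: Real is abstract
catch err
    println("caught: ", err.msg)
end
```

Failing loudly at this point would at least tell the user that the broadcasted function is the source of the problem.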
When T is abstract, could it just assume that there are Duals in there? If not, construct an array of zeros, instead of nothing?
An array of zeros doesn't seem quite right. In the first MWE above, that would lead to incorrect zero gradients, if I understand correctly?
Assuming Dual seems like it might work, since calling partials on a real or complex number simply returns 0.0 anyway, although it might require reworking the branching on complex inputs. However, I don't think the compiler will remove that extra work when no element is a Dual, so the quicker branch should be kept for cases where the compiler can confirm that the eltype isn't Dual.
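To illustrate why assuming Dual might work on a mixed output, here is a small sketch. It assumes ForwardDiff is available in the active environment, and it builds the mixed array by hand rather than going through Zygote's broadcast machinery.

```julia
using ForwardDiff: Dual, value, partials

# A hand-built stand-in for the output of the type-unstable broadcast:
# one element kept its Dual (x = 0.5, so x^2 = 0.25 with derivative 1.0),
# the other took the constant branch and is a plain Float64.
out = Real[Dual(0.25, 1.0), 1.0]

# `value` and `partials` have generic fallbacks for non-Dual numbers,
# so they can be applied to every element without branching on its type.
vals  = map(value, out)               # primal values
grads = map(o -> partials(o, 1), out) # derivative is zero for the plain Float64

println(vals)
println(grads)
```

The constant-branch element contributes a zero derivative, which is the correct value for that branch in the first MWE.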
Perhaps the dispatch on complex outputs could be moved inside the _broadcast_forward and _broadcast_forward_complex loops using another internal function? E.g. on this line (Zygote.jl/src/lib/broadcast.jl, line 298 in 2f49370), split on complex o1 and instead do what Zygote.jl/src/lib/broadcast.jl, line 311 in 2f49370, does when required, so that it is dispatched element-wise. That should produce the same code when the eltype is uniform?
Small update: I think I have a fix for this written; I just need to add tests.