FluxML / Zygote.jl

21st century AD

Home page: https://fluxml.ai/Zygote.jl/

Pullback for Tuple(::Vector) gives the wrong type

ptiede opened this issue

This is a copy of issue JuliaDiff/ChainRulesCore.jl#623: the pullback of `Tuple(::Vector)` returns a cotangent of the wrong type (a `Tuple` instead of a `Vector`).

@mcabbott found a simple MWE:

julia> gradient(x -> sum(Tuple(Zygote.@showgrad x)), [2.2, 3.3])
(x) = (1.0, 1.0)  # a Tuple, not a Vector
([1.0, 1.0],)  # final answer fixed by _project

julia> pullback(x -> sum(Tuple(x)), [2.2, 3.3])[2](1.0)  # avoids final _project
((1.0, 1.0),)
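For comparison, here is what the chain rule should produce, written out by hand in plain Julia with no Zygote involved. It is only a sketch, and `tuple_rrule` and `sum_rrule` are hypothetical names: the point is that the cotangent leaving `sum` is shaped like the tuple, so the pullback of `Tuple(x)` has to turn it back into something shaped like `x`, i.e. a `Vector`.

```julia
# Hypothetical hand-written rules, independent of Zygote and ChainRules.
function tuple_rrule(x::AbstractVector)
    # The input cotangent must be shaped like x (a Vector), not like the output.
    # `nothing` entries (Zygote's convention for a zero gradient) become zeros.
    pb(dy::Tuple) = [d === nothing ? zero(eltype(x)) : d for d in dy]
    return Tuple(x), pb
end

function sum_rrule(t::Tuple)
    # Every element of the tuple receives the same cotangent as the sum.
    pb(dy) = ntuple(_ -> dy, length(t))
    return sum(t), pb
end

x = [2.2, 3.3]
t, tuple_pb = tuple_rrule(x)
s, sum_pb = sum_rrule(t)
dt = sum_pb(1.0)    # shaped like t: a Tuple
dx = tuple_pb(dt)   # shaped like x: a Vector
```

Here `dt` is `(1.0, 1.0)` and `dx` is `[1.0, 1.0]`, matching the final answer above but without relying on a final projection to fix up the type.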

julia> using ChainRules, ChainRulesCore

julia> rrule(Tuple, [2.2, 3.3])  # there is no rrule for this

julia> rrule(sum, (2.2, 3.3))
(5.5, ChainRules.var"#sum_pullback#1644"{Val{2}, ProjectTo{Tangent{Tuple{Float64, Float64}}, ...
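Since there is no rrule for `Tuple(::Vector)`, one possible shape for the missing rule is sketched below with ChainRulesCore. Treat it as an untested assumption, not the fix adopted upstream: the helper `_tuple_cotangent_to_vector` is hypothetical, defining the rule this way is type piracy, and I haven't checked whether Zygote would use it in place of its own handling of `Tuple(x)`.

```julia
using ChainRulesCore

# Hypothetical helper: turn a Tuple-shaped cotangent into a Vector, mapping
# zero-tangent entries to numeric zeros of the input's element type.
_tuple_cotangent_to_vector(x, dy::Tuple) =
    [d isa AbstractZero ? zero(eltype(x)) : d for d in dy]
_tuple_cotangent_to_vector(x, dy::Tangent) =
    _tuple_cotangent_to_vector(x, ChainRulesCore.backing(dy))

# Sketch of the missing rule (type piracy: we own neither Tuple nor AbstractVector).
function ChainRulesCore.rrule(::Type{Tuple}, x::AbstractVector)
    project_x = ProjectTo(x)  # remembers the element type and axes of x
    function Tuple_pullback(dy)
        dy = unthunk(dy)
        dy isa AbstractZero && return (NoTangent(), dy)
        # Project onto x so the cotangent has x's shape and element type.
        return (NoTangent(), project_x(_tuple_cotangent_to_vector(x, dy)))
    end
    return Tuple(x), Tuple_pullback
end
```

With this defined, the pullback turns either a plain tuple or a `Tangent{Tuple{...}}` into a `Vector{Float64}`, which is the type `ProjectTo` for the input accepts.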

After adding some printing, it appears that a `nothing` gradient is coming from somewhere in this function:

julia> function mwe(x)
           Zygote.@showgrad @show x
           (;dx,) = x
           Zygote.@showgrad @show dx
           t = Tuple(dx)
           Zygote.@showgrad @show t
           sum(t)
       end
mwe (generic function with 1 method)
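
Here `l` is a wrapper whose definition isn't shown. Judging from the printed `x = (dx = [2.2, 3.3],)` and from `ForwardDiff.gradient` accepting `x` below, it presumably calls `mwe` with the vector wrapped in a named tuple, `(dx = x,)`.

julia> x = [2.2, 3.3];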

julia> l(x)
x = (dx = [2.2, 3.3],)
dx = [2.2, 3.3]
t = (2.2, 3.3)
5.5

julia> gradient(l, x)
x = (dx = [2.2, 3.3],)
dx = [2.2, 3.3]
t = (2.2, 3.3)
(#= REPL[8]:6 =# @show t) = nothing
(#= REPL[8]:4 =# @show dx) = nothing
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}}}})(::Tuple{Float64, Float64})

julia> ForwardDiff.gradient(l, x)
x = (dx = ForwardDiff.Dual{ForwardDiff.Tag{typeof(l), Float64}, Float64, 2}[Dual{ForwardDiff.Tag{typeof(l), Float64}}(2.2,1.0,0.0), Dual{ForwardDiff.Tag{typeof(l), Float64}}(3.3,0.0,1.0)],)
dx = ForwardDiff.Dual{ForwardDiff.Tag{typeof(l), Float64}, Float64, 2}[Dual{ForwardDiff.Tag{typeof(l), Float64}}(2.2,1.0,0.0), Dual{ForwardDiff.Tag{typeof(l), Float64}}(3.3,0.0,1.0)]
t = (Dual{ForwardDiff.Tag{typeof(l), Float64}}(2.2,1.0,0.0), Dual{ForwardDiff.Tag{typeof(l), Float64}}(3.3,0.0,1.0))
2-element Vector{Float64}:
 1.0
 1.0
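The `MethodError` from `gradient(l, x)` comes from a `ProjectTo{AbstractArray}` (built from the `Vector` argument) being handed a `Tuple{Float64, Float64}`. The snippet below builds that projector directly with ChainRulesCore and shows that collecting the tuple into a vector first gives it an input it accepts; `dt` just stands in for the tuple cotangent Zygote produced. I've left out the failing call itself, since whether it errors may depend on the ChainRulesCore version.

```julia
using ChainRulesCore

x = [2.2, 3.3]
project_x = ProjectTo(x)  # a ProjectTo{AbstractArray} with element = ProjectTo{Float64} and axes of x
dt = (1.0, 1.0)           # stand-in for the Tuple cotangent Zygote produced for dx

# Converting the Tuple to a Vector before projecting gives an AbstractArray
# with matching axes, which the array projector handles.
dx = project_x(collect(dt))
```

The result is the same `Vector{Float64}` that `ForwardDiff.gradient` returns above.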