Autodiff support
SebastianM-C opened this issue · comments
Sebastian Micluța-Câmpeanu commented
Would it make sense to support automatic differentiation through `knn`?
I tried using ForwardDiff:
# Minimal reproducer: differentiate the sum of the k-nearest-neighbor distances with ForwardDiff.
using NearestNeighbors
using ForwardDiff
# 1000 random points in 3 dimensions (each column is one point).
data = rand(3, 10^3)
kdtree = KDTree(data)
# `knn(kdtree, x, 4)` returns (indices, distances); `[2]` selects the distances.
# This call fails: `knn` writes the Dual-valued distances into a preallocated
# `Vector{Float64}` buffer, which cannot hold `ForwardDiff.Dual` numbers.
ForwardDiff.gradient(x->sum(knn(kdtree, x, 4)[2]), [0.,0.,0.])
but because of how the distances buffer is defined,
dist = Vector{get_T(eltype(V))}(undef, k)
and the fact that a dual-number distance is then assigned to elements of that `Float64` vector,
dist_d = evaluate(tree.metric, tree.data[idx], point, do_end)
...
best_dists[1] = dist_d
I get the following error with the code above:
julia> ForwardDiff.gradient(x->sum(knn(kdtree, x, 4)[2]), [0.,0.,0.])
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3})
Closest candidates are:
Float64(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
Float64(::T) where T<:Number at boot.jl:718
Float64(::Int8) at float.jl:60
...
Stacktrace:
[1] convert(::Type{Float64}, ::ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3}) at .\number.jl:7
[2] setindex!(::Array{Float64,1}, ::ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3}, ::Int64) at .\array.jl:782
[3] add_points_knn! at C:\Users\sebastian\.julia\packages\NearestNeighbors\pb8hw\src\tree_ops.jl:104 [inlined]
[4] knn_kernel!(::KDTree{StaticArrays.SArray{Tuple{3},Float64,1,3},Euclidean,Float64}, ::Int64, ::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3},1}, ::Array{Int64,1}, ::Array{Float64,1}, ::ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3}, ::typeof(NearestNeighbors.always_false)) at C:\Users\sebastian\.julia\packages\NearestNeighbors\pb8hw\src\kd_tree.jl:174
[5] knn_kernel!(::KDTree{StaticArrays.SArray{Tuple{3},Float64,1,3},Euclidean,Float64}, ::Int64, ::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3},1}, ::Array{Int64,1}, ::Array{Float64,1}, ::ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3}, ::typeof(NearestNeighbors.always_false)) at C:\Users\sebastian\.julia\packages\NearestNeighbors\pb8hw\src\kd_tree.jl:196 (repeats 7 times)
[6] _knn(::KDTree{StaticArrays.SArray{Tuple{3},Float64,1,3},Euclidean,Float64}, ::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3},1}, ::Array{Int64,1}, ::Array{Float64,1}, ::Function) at C:\Users\sebastian\.julia\packages\NearestNeighbors\pb8hw\src\kd_tree.jl:158
[7] knn_point!(::KDTree{StaticArrays.SArray{Tuple{3},Float64,1,3},Euclidean,Float64}, ::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3},1}, ::Bool, ::Array{Float64,1}, ::Array{Int64,1}, ::Function) at C:\Users\sebastian\.julia\packages\NearestNeighbors\pb8hw\src\knn.jl:31
[8] knn at C:\Users\sebastian\.julia\packages\NearestNeighbors\pb8hw\src\knn.jl:44 [inlined]
[9] knn(::KDTree{StaticArrays.SArray{Tuple{3},Float64,1,3},Euclidean,Float64}, ::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3},1}, ::Int64) at C:\Users\sebastian\.julia\packages\NearestNeighbors\pb8hw\src\knn.jl:41
[10] (::var"#3#4")(::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3},1}) at .\REPL[7]:1
[11] vector_mode_gradient(::var"#3#4", ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#3#4",Float64},Float64,3,Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3},1}}) at C:\Users\sebastian\.julia\packages\ForwardDiff\Asf4O\src\apiutils.jl:37
[12] gradient(::Function, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#3#4",Float64},Float64,3,Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3},1}}, ::Val{true}) at C:\Users\sebastian\.julia\packages\ForwardDiff\Asf4O\src\gradient.jl:17
[13] gradient(::Function, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#3#4",Float64},Float64,3,Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4",Float64},Float64,3},1}}) at C:\Users\sebastian\.julia\packages\ForwardDiff\Asf4O\src\gradient.jl:15 (repeats 2 times)
[14] top-level scope at REPL[7]:1
Kristoffer Carlsson commented
Wouldn't you get the same result by computing `knn` on its own (without derivatives) and then differentiating only through the distance function?
Sebastian Micluța-Câmpeanu commented
I'm not sure I understand what you mean. Could you please elaborate a bit more?
Παναγιώτης Γεωργακόπουλος commented
I think this is still an issue.
I'm trying to define a loss function as:
function loss(model, x, y)
    return sum(abs2, last(nn(tree, model(x))))  # squared NN distances; `y` unused
end
(Yes, instead of `y` I'm using the nearest neighbor!)
And `Flux.train!` returns:
Mutating arrays is not supported -- called setindex!(Vector{Int64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)
stacktrace
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] _throw_mutation_error(f::Function, args::Vector{Int64})
@ Zygote /opt/julia/packages/Zygote/TSj5C/src/lib/array.jl:88
[3] (::Zygote.var"#550#551"{Vector{Int64}})(#unused#::Nothing)
@ Zygote /opt/julia/packages/Zygote/TSj5C/src/lib/array.jl:100
[4] (::Zygote.var"#2620#back#552"{Zygote.var"#550#551"{Vector{Int64}}})(Δ::Nothing)
@ Zygote /opt/julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
[5] Pullback
@ /opt/julia/packages/NearestNeighbors/huCPc/src/knn.jl:41 [inlined]
[6] (::Zygote.Pullback{Tuple{typeof(NearestNeighbors.knn_point!), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, StaticArraysCore.SVector{2, Float32}, Bool, Vector{Float32}, Vector{Int64}, typeof(NearestNeighbors.always_false)}, Any})(Δ::Nothing)
@ Zygote /opt/julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
[7] Pullback
@ /opt/julia/packages/NearestNeighbors/huCPc/src/knn.jl:24 [inlined]
[8] (::Zygote.Pullback{Tuple{typeof(knn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Vector{StaticArraysCore.SVector{2, Float32}}, Int64, Bool, typeof(NearestNeighbors.always_false)}, Any})(Δ::Tuple{Nothing, Vector{Zygote.OneElement{Float32, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}})
@ Zygote /opt/julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
[9] Pullback
@ /opt/julia/packages/NearestNeighbors/huCPc/src/knn.jl:63 [inlined]
[10] (::Zygote.Pullback{Tuple{typeof(knn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, Int64, Bool, typeof(NearestNeighbors.always_false)}, Any})(Δ::Tuple{Nothing, Vector{Zygote.OneElement{Float32, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}})
@ Zygote /opt/julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
[11] Pullback
@ /opt/julia/packages/NearestNeighbors/huCPc/src/knn.jl:70 [inlined]
[12] (::Zygote.Pullback{Tuple{typeof(NearestNeighbors._nn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, typeof(NearestNeighbors.always_false)}, Tuple{Zygote.Pullback{Tuple{typeof(knn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, Int64, Bool, typeof(NearestNeighbors.always_false)}, Any}}})(Δ::Tuple{Nothing, Vector{Zygote.OneElement{Float32, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}})
@ Zygote /opt/julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
[13] Pullback
@ /opt/julia/packages/NearestNeighbors/huCPc/src/knn.jl:68 [inlined]
[14] (::Zygote.Pullback{Tuple{typeof(nn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, typeof(NearestNeighbors.always_false)}, Tuple{Zygote.Pullback{Tuple{typeof(|>), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}, typeof(NearestNeighbors._firsteach)}, Tuple{Zygote.Pullback{Tuple{typeof(NearestNeighbors._firsteach), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(first), Vector{Vector{Float32}}}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Vector{Float32}}}, Tuple{}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4086#back#1368"{Zygote.var"#∇broadcasted#1364"{Tuple{Vector{Vector{Float32}}}, Vector{Tuple{Float32, Zygote.Pullback{Tuple{typeof(first), Vector{Float32}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Float32, Vector{Float32}, Tuple{Int64}}}, Zygote.Pullback{Tuple{typeof(first), Base.OneTo{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(oneunit), Type{Int64}}, Tuple{Zygote.ZBack{Zygote.var"#IntX_pullback#330"}, Zygote.ZBack{ChainRules.var"#one_pullback#792"}}}}}, Zygote.ZBack{ChainRules.var"#eachindex_pullback#375"{Tuple{Vector{Float32}}}}}}}}, Val{2}}}}}}}, Zygote.Pullback{Tuple{typeof(last), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}}, Tuple{Zygote.var"#2014#back#218"{Zygote.var"#back#216"{2, Zygote.Context{false}, Int64, Vector{Vector{Float32}}}}, 
Zygote.Pullback{Tuple{typeof(lastindex), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}}, Tuple{Zygote.ZBack{ChainRules.var"#length_pullback#746"}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Int64}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(first), Vector{Vector{Int64}}}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Vector{Int64}}}, Tuple{}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4086#back#1368"{Zygote.var"#∇broadcasted#1364"{Tuple{Vector{Vector{Int64}}}, Vector{Tuple{Int64, Zygote.Pullback{Tuple{typeof(first), Vector{Int64}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Int64, Vector{Int64}, Tuple{Int64}}}, Zygote.Pullback{Tuple{typeof(first), Base.OneTo{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(oneunit), Type{Int64}}, Tuple{Zygote.ZBack{Zygote.var"#IntX_pullback#330"}, Zygote.ZBack{ChainRules.var"#one_pullback#792"}}}}}, Zygote.ZBack{ChainRules.var"#eachindex_pullback#375"{Tuple{Vector{Int64}}}}}}}}, Val{2}}}}}}}, Zygote.var"#2094#back#269"{Zygote.var"#264#268"{Tuple{Nothing}}}, Zygote.var"#1982#back#200"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(NearestNeighbors._nn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, typeof(NearestNeighbors.always_false)}, Tuple{Zygote.Pullback{Tuple{typeof(knn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, Int64, Bool, typeof(NearestNeighbors.always_false)}, Any}}}}})(Δ::Tuple{Nothing, 
Vector{Float32}})
@ Zygote /opt/julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
[15] Pullback
@ /opt/julia/packages/NearestNeighbors/huCPc/src/knn.jl:68 [inlined]
[16] Pullback
@ ./In[41]:2 [inlined]
[17] (::Zygote.Pullback{Tuple{typeof(loss), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}, Vector{Float32}}, Tuple{Zygote.var"#1891#back#157"{Zygote.var"#153#156"}, Zygote.var"#1998#back#209"{Zygote.var"#back#207"{2, 2, Zygote.Context{false}, Vector{Float32}}}, Zygote.Pullback{Tuple{typeof(nn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(nn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, typeof(NearestNeighbors.always_false)}, Tuple{Zygote.Pullback{Tuple{typeof(|>), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}, typeof(NearestNeighbors._firsteach)}, Tuple{Zygote.Pullback{Tuple{typeof(NearestNeighbors._firsteach), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(first), Vector{Vector{Float32}}}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Vector{Float32}}}, Tuple{}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4086#back#1368"{Zygote.var"#∇broadcasted#1364"{Tuple{Vector{Vector{Float32}}}, Vector{Tuple{Float32, Zygote.Pullback{Tuple{typeof(first), Vector{Float32}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Float32, Vector{Float32}, Tuple{Int64}}}, Zygote.Pullback{Tuple{typeof(first), Base.OneTo{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(oneunit), Type{Int64}}, 
Tuple{Zygote.ZBack{Zygote.var"#IntX_pullback#330"}, Zygote.ZBack{ChainRules.var"#one_pullback#792"}}}}}, Zygote.ZBack{ChainRules.var"#eachindex_pullback#375"{Tuple{Vector{Float32}}}}}}}}, Val{2}}}}}}}, Zygote.Pullback{Tuple{typeof(last), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}}, Tuple{Zygote.var"#2014#back#218"{Zygote.var"#back#216"{2, Zygote.Context{false}, Int64, Vector{Vector{Float32}}}}, Zygote.Pullback{Tuple{typeof(lastindex), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}}, Tuple{Zygote.ZBack{ChainRules.var"#length_pullback#746"}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Int64}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(first), Vector{Vector{Int64}}}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Vector{Int64}}}, Tuple{}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4086#back#1368"{Zygote.var"#∇broadcasted#1364"{Tuple{Vector{Vector{Int64}}}, Vector{Tuple{Int64, Zygote.Pullback{Tuple{typeof(first), Vector{Int64}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Int64, Vector{Int64}, Tuple{Int64}}}, Zygote.Pullback{Tuple{typeof(first), Base.OneTo{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(oneunit), Type{Int64}}, Tuple{Zygote.ZBack{Zygote.var"#IntX_pullback#330"}, Zygote.ZBack{ChainRules.var"#one_pullback#792"}}}}}, Zygote.ZBack{ChainRules.var"#eachindex_pullback#375"{Tuple{Vector{Int64}}}}}}}}, Val{2}}}}}}}, Zygote.var"#2094#back#269"{Zygote.var"#264#268"{Tuple{Nothing}}}, 
Zygote.var"#1982#back#200"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(NearestNeighbors._nn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, typeof(NearestNeighbors.always_false)}, Tuple{Zygote.Pullback{Tuple{typeof(knn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, Int64, Bool, typeof(NearestNeighbors.always_false)}, Any}}}}}}}, Zygote.var"#2987#back#777"{Zygote.var"#771#775"{Vector{Float32}}}, Zygote.var"#1955#back#190"{Zygote.var"#186#189"{Zygote.Context{false}, GlobalRef, KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:weight, Zygote.Context{false}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}}, Zygote.ZBack{ChainRules.var"#size_pullback#918"}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.ZBack{Flux.var"#_size_check_pullback#204"{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}, Pair{Int64, Int64}}}}, Zygote.var"#3878#back#1250"{Zygote.var"#1246#1249"}, Zygote.ZBack{ChainRules.var"#times_pullback#1486"{Matrix{Float32}, Matrix{Float32}}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:weight, Zygote.Context{false}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}}, Zygote.ZBack{Flux.var"#177#178"}, Zygote.Pullback{Tuple{Type{Pair}, Int64, Int64}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#420"}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#420"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Int64}, Int64}, Tuple{}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Int64}, Int64}, Tuple{}}, Zygote.var"#2176#back#309"{Zygote.Jnew{Pair{Int64, Int64}, Nothing, false}}}}, 
Zygote.Pullback{Tuple{typeof(NNlib.fast_act), typeof(identity), Matrix{Float32}}, Tuple{}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:bias, Zygote.Context{false}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Vector{Float32}}}, Zygote.var"#3734#back#1184"{Zygote.var"#1178#1182"{Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:σ, Zygote.Context{false}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(identity)}}}}, Zygote.var"#3862#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float32}}}}})(Δ::Float32)
@ Zygote /opt/julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
[18] #287
@ /opt/julia/packages/Zygote/TSj5C/src/lib/lib.jl:206 [inlined]
[19] #2138#back
@ /opt/julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
[20] Pullback
@ /opt/julia/packages/Flux/Nzh8J/src/train.jl:107 [inlined]
[21] (::Zygote.Pullback{Tuple{Flux.Train.var"#4#5"{typeof(loss), Tuple{Matrix{Float32}, Vector{Float32}}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(loss), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}, Vector{Float32}}, Tuple{Zygote.var"#1891#back#157"{Zygote.var"#153#156"}, Zygote.var"#1998#back#209"{Zygote.var"#back#207"{2, 2, Zygote.Context{false}, Vector{Float32}}}, Zygote.Pullback{Tuple{typeof(nn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(nn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, typeof(NearestNeighbors.always_false)}, Tuple{Zygote.Pullback{Tuple{typeof(|>), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}, typeof(NearestNeighbors._firsteach)}, Tuple{Zygote.Pullback{Tuple{typeof(NearestNeighbors._firsteach), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(first), Vector{Vector{Float32}}}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Vector{Float32}}}, Tuple{}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4086#back#1368"{Zygote.var"#∇broadcasted#1364"{Tuple{Vector{Vector{Float32}}}, Vector{Tuple{Float32, 
Zygote.Pullback{Tuple{typeof(first), Vector{Float32}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Float32, Vector{Float32}, Tuple{Int64}}}, Zygote.Pullback{Tuple{typeof(first), Base.OneTo{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(oneunit), Type{Int64}}, Tuple{Zygote.ZBack{Zygote.var"#IntX_pullback#330"}, Zygote.ZBack{ChainRules.var"#one_pullback#792"}}}}}, Zygote.ZBack{ChainRules.var"#eachindex_pullback#375"{Tuple{Vector{Float32}}}}}}}}, Val{2}}}}}}}, Zygote.Pullback{Tuple{typeof(last), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}}, Tuple{Zygote.var"#2014#back#218"{Zygote.var"#back#216"{2, Zygote.Context{false}, Int64, Vector{Vector{Float32}}}}, Zygote.Pullback{Tuple{typeof(lastindex), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}}, Tuple{Zygote.ZBack{ChainRules.var"#length_pullback#746"}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Int64}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(first), Vector{Vector{Int64}}}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Vector{Int64}}}, Tuple{}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4086#back#1368"{Zygote.var"#∇broadcasted#1364"{Tuple{Vector{Vector{Int64}}}, Vector{Tuple{Int64, Zygote.Pullback{Tuple{typeof(first), Vector{Int64}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Int64, Vector{Int64}, Tuple{Int64}}}, Zygote.Pullback{Tuple{typeof(first), Base.OneTo{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(oneunit), Type{Int64}}, 
Tuple{Zygote.ZBack{Zygote.var"#IntX_pullback#330"}, Zygote.ZBack{ChainRules.var"#one_pullback#792"}}}}}, Zygote.ZBack{ChainRules.var"#eachindex_pullback#375"{Tuple{Vector{Int64}}}}}}}}, Val{2}}}}}}}, Zygote.var"#2094#back#269"{Zygote.var"#264#268"{Tuple{Nothing}}}, Zygote.var"#1982#back#200"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(NearestNeighbors._nn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, typeof(NearestNeighbors.always_false)}, Tuple{Zygote.Pullback{Tuple{typeof(knn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, Int64, Bool, typeof(NearestNeighbors.always_false)}, Any}}}}}}}, Zygote.var"#2987#back#777"{Zygote.var"#771#775"{Vector{Float32}}}, Zygote.var"#1955#back#190"{Zygote.var"#186#189"{Zygote.Context{false}, GlobalRef, KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:weight, Zygote.Context{false}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}}, Zygote.ZBack{ChainRules.var"#size_pullback#918"}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.ZBack{Flux.var"#_size_check_pullback#204"{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}, Pair{Int64, Int64}}}}, Zygote.var"#3878#back#1250"{Zygote.var"#1246#1249"}, Zygote.ZBack{ChainRules.var"#times_pullback#1486"{Matrix{Float32}, Matrix{Float32}}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:weight, Zygote.Context{false}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}}, Zygote.ZBack{Flux.var"#177#178"}, Zygote.Pullback{Tuple{Type{Pair}, Int64, Int64}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#420"}, 
Zygote.ZBack{ChainRules.var"#fieldtype_pullback#420"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Int64}, Int64}, Tuple{}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Int64}, Int64}, Tuple{}}, Zygote.var"#2176#back#309"{Zygote.Jnew{Pair{Int64, Int64}, Nothing, false}}}}, Zygote.Pullback{Tuple{typeof(NNlib.fast_act), typeof(identity), Matrix{Float32}}, Tuple{}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:bias, Zygote.Context{false}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Vector{Float32}}}, Zygote.var"#3734#back#1184"{Zygote.var"#1178#1182"{Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:σ, Zygote.Context{false}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(identity)}}}}, Zygote.var"#3862#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float32}}}}}}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:d_splat, Zygote.Context{false}, Flux.Train.var"#4#5"{typeof(loss), Tuple{Matrix{Float32}, Vector{Float32}}}, Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:loss, Zygote.Context{false}, Flux.Train.var"#4#5"{typeof(loss), Tuple{Matrix{Float32}, Vector{Float32}}}, typeof(loss)}}}})(Δ::Float32)
@ Zygote /opt/julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
[22] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{Flux.Train.var"#4#5"{typeof(loss), Tuple{Matrix{Float32}, Vector{Float32}}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(loss), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}, Vector{Float32}}, Tuple{Zygote.var"#1891#back#157"{Zygote.var"#153#156"}, Zygote.var"#1998#back#209"{Zygote.var"#back#207"{2, 2, Zygote.Context{false}, Vector{Float32}}}, Zygote.Pullback{Tuple{typeof(nn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(nn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, typeof(NearestNeighbors.always_false)}, Tuple{Zygote.Pullback{Tuple{typeof(|>), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}, typeof(NearestNeighbors._firsteach)}, Tuple{Zygote.Pullback{Tuple{typeof(NearestNeighbors._firsteach), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(first), Vector{Vector{Float32}}}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Vector{Float32}}}, Tuple{}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4086#back#1368"{Zygote.var"#∇broadcasted#1364"{Tuple{Vector{Vector{Float32}}}, Vector{Tuple{Float32, 
Zygote.Pullback{Tuple{typeof(first), Vector{Float32}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Float32, Vector{Float32}, Tuple{Int64}}}, Zygote.Pullback{Tuple{typeof(first), Base.OneTo{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(oneunit), Type{Int64}}, Tuple{Zygote.ZBack{Zygote.var"#IntX_pullback#330"}, Zygote.ZBack{ChainRules.var"#one_pullback#792"}}}}}, Zygote.ZBack{ChainRules.var"#eachindex_pullback#375"{Tuple{Vector{Float32}}}}}}}}, Val{2}}}}}}}, Zygote.Pullback{Tuple{typeof(last), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}}, Tuple{Zygote.var"#2014#back#218"{Zygote.var"#back#216"{2, Zygote.Context{false}, Int64, Vector{Vector{Float32}}}}, Zygote.Pullback{Tuple{typeof(lastindex), Tuple{Vector{Vector{Int64}}, Vector{Vector{Float32}}}}, Tuple{Zygote.ZBack{ChainRules.var"#length_pullback#746"}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Int64}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(first), Vector{Vector{Int64}}}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Vector{Int64}}}, Tuple{}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4086#back#1368"{Zygote.var"#∇broadcasted#1364"{Tuple{Vector{Vector{Int64}}}, Vector{Tuple{Int64, Zygote.Pullback{Tuple{typeof(first), Vector{Int64}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Int64, Vector{Int64}, Tuple{Int64}}}, Zygote.Pullback{Tuple{typeof(first), Base.OneTo{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(oneunit), Type{Int64}}, 
Tuple{Zygote.ZBack{Zygote.var"#IntX_pullback#330"}, Zygote.ZBack{ChainRules.var"#one_pullback#792"}}}}}, Zygote.ZBack{ChainRules.var"#eachindex_pullback#375"{Tuple{Vector{Int64}}}}}}}}, Val{2}}}}}}}, Zygote.var"#2094#back#269"{Zygote.var"#264#268"{Tuple{Nothing}}}, Zygote.var"#1982#back#200"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(NearestNeighbors._nn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, typeof(NearestNeighbors.always_false)}, Tuple{Zygote.Pullback{Tuple{typeof(knn), KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}, Matrix{Float32}, Int64, Bool, typeof(NearestNeighbors.always_false)}, Any}}}}}}}, Zygote.var"#2987#back#777"{Zygote.var"#771#775"{Vector{Float32}}}, Zygote.var"#1955#back#190"{Zygote.var"#186#189"{Zygote.Context{false}, GlobalRef, KDTree{StaticArraysCore.SVector{2, Float32}, Euclidean, Float32}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:weight, Zygote.Context{false}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}}, Zygote.ZBack{ChainRules.var"#size_pullback#918"}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.ZBack{Flux.var"#_size_check_pullback#204"{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}, Pair{Int64, Int64}}}}, Zygote.var"#3878#back#1250"{Zygote.var"#1246#1249"}, Zygote.ZBack{ChainRules.var"#times_pullback#1486"{Matrix{Float32}, Matrix{Float32}}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:weight, Zygote.Context{false}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}}, Zygote.ZBack{Flux.var"#177#178"}, Zygote.Pullback{Tuple{Type{Pair}, Int64, Int64}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#420"}, 
Zygote.ZBack{ChainRules.var"#fieldtype_pullback#420"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Int64}, Int64}, Tuple{}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Int64}, Int64}, Tuple{}}, Zygote.var"#2176#back#309"{Zygote.Jnew{Pair{Int64, Int64}, Nothing, false}}}}, Zygote.Pullback{Tuple{typeof(NNlib.fast_act), typeof(identity), Matrix{Float32}}, Tuple{}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:bias, Zygote.Context{false}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Vector{Float32}}}, Zygote.var"#3734#back#1184"{Zygote.var"#1178#1182"{Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:σ, Zygote.Context{false}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(identity)}}}}, Zygote.var"#3862#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float32}}}}}}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:d_splat, Zygote.Context{false}, Flux.Train.var"#4#5"{typeof(loss), Tuple{Matrix{Float32}, Vector{Float32}}}, Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:loss, Zygote.Context{false}, Flux.Train.var"#4#5"{typeof(loss), Tuple{Matrix{Float32}, Vector{Float32}}}, typeof(loss)}}}}})(Δ::Float32)
@ Zygote /opt/julia/packages/Zygote/TSj5C/src/compiler/interface.jl:45
[23] withgradient(f::Function, args::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
@ Zygote /opt/julia/packages/Zygote/TSj5C/src/compiler/interface.jl:133
[24] macro expansion
@ /opt/julia/packages/Flux/Nzh8J/src/train.jl:107 [inlined]
[25] macro expansion
@ /opt/julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
[26] train!(loss::Function, model::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, data::Vector{Tuple{Matrix{Float32}, Vector{Float32}}}, opt::NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Descent{Float64}, Nothing}, Optimisers.Leaf{Optimisers.Descent{Float64}, Nothing}, Tuple{}}}; cb::Nothing)
@ Flux.Train /opt/julia/packages/Flux/Nzh8J/src/train.jl:105
[27] train!(loss::Function, model::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, data::Vector{Tuple{Matrix{Float32}, Vector{Float32}}}, rule::Optimisers.Descent{Float64}; cb::Nothing)
@ Flux.Train /opt/julia/packages/Flux/Nzh8J/src/train.jl:118
[28] train!(loss::Function, model::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, data::Vector{Tuple{Matrix{Float32}, Vector{Float32}}}, opt::Descent; cb::Nothing)
@ Flux /opt/julia/packages/Flux/Nzh8J/src/deprecations.jl:126
[29] train!(loss::Function, model::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, data::Vector{Tuple{Matrix{Float32}, Vector{Float32}}}, opt::Descent)
@ Flux /opt/julia/packages/Flux/Nzh8J/src/deprecations.jl:126
[30] top-level scope
@ In[42]:1