Spurious "Output is complex, so the gradient is not defined" error
seadra opened this issue
(Cross-posting this DiffEqFlux issue here since this seems to be Zygote related)
With the code
using Lux, Random, DifferentialEquations, ComponentArrays, Optimization, OptimizationFlux, Zygote, LinearAlgebra, DiffEqFlux
const T = 10.0;
const ω = π/T;
ann = Lux.Chain(Lux.Dense(1,32), Lux.Dense(32,32,tanh), Lux.Dense(32,1));
rng = Random.default_rng();
ip, st = Lux.setup(rng, ann);
function f_nn(u, p, t)
a = ann([t],p,st)[1];
A = [1.0 a; a -1.0];
return -im*A*u;
end
u0 = [Complex{Float64}(1) 0; 0 1];
tspan = (0.0, T)
prob_ode = ODEProblem(f_nn, u0, tspan, ip);
utarget = [Complex{Float64}(0) im; im 0];
function predict_adjoint(p)
return solve(prob_ode, Tsit5(), p=Complex{Float64}.(p), abstol=1e-12, reltol=1e-12)
end
function loss_adjoint(p)
prediction = predict_adjoint(p)
usol = last(prediction)
loss = 1.0 - abs(tr(usol*utarget')/2)^2
return loss
end
opt_f = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), Optimization.AutoZygote());
opt_prob = Optimization.OptimizationProblem(opt_f, ComponentArray(ip));
optimized_sol_nn = Optimization.solve(opt_prob, ADAM(0.1), maxiters = 100, progress=true);
I'm getting the following error
┌ Warning: ZygoteVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add `verbose = false` to the `solve` call)
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:94
Output is complex, so the gradient is not defined.
┌ Warning: ReverseDiffVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add `verbose = false` to the `solve` call)
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:111
MethodError: no method matching gradient(::SciMLSensitivity.var"#258#264"{ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:1153, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32), NamedTuple())), bias = ViewAxis(33:33, ShapedAxis((1, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(f_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, ::Matrix{ComplexF64}, ::ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:1153, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32), NamedTuple())), bias = ViewAxis(33:33, ShapedAxis((1, 1), NamedTuple())))))}}})
Closest candidates are:
gradient(::Any, ::Any)
@ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/api/gradients.jl:21
gradient(::Any, ::Any, ::ReverseDiff.GradientConfig)
@ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/api/gradients.jl:21
┌ Warning: TrackerVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add `verbose = false` to the `solve` call)
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:129
Function output is not scalar
┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:139
MethodError: no method matching default_relstep(::Nothing, ::Type{ComplexF64})
Closest candidates are:
default_relstep(::Type, ::Any)
@ FiniteDiff ~/.julia/packages/FiniteDiff/grio1/src/epsilons.jl:25
default_relstep(::Val{fdtype}, ::Type{T}) where {fdtype, T<:Number}
@ FiniteDiff ~/.julia/packages/FiniteDiff/grio1/src/epsilons.jl:26
I also tried explicitly converting the loss to a real value with return real(loss), but that didn't help either.
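For what it's worth, plain Zygote has no trouble when complex numbers only appear as intermediates and the final output is real, which suggests the failure happens inside the solver's VJP machinery rather than in the loss itself (a minimal sketch, not from the original report):

```julia
using Zygote

# Complex intermediate, real output: Zygote differentiates this fine.
# abs2(im * x) == x^2 for real x, so the gradient at x = 2.0 is 2x = 4.0.
g, = Zygote.gradient(x -> abs2(im * x), 2.0)
# g ≈ 4.0
```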
The phrase in the error does appear in this package, at Zygote.jl/src/compiler/interface.jl, line 66 (commit f24b9b2).
Can you post the full stack trace, showing lines where any Zygote code is being called?
Or better, can you reproduce this error with just Zygote?
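For reference, a minimal reproduction with just Zygote might look like the following sketch: the message comes from Zygote refusing to define a gradient for a function whose output is complex.

```julia
using Zygote

# Differentiating a function with complex-valued output triggers the error:
Zygote.gradient(x -> im * x, 1.0)
# ERROR: Output is complex, so the gradient is not defined.
```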
It turns out that I made a couple of mistakes. First, I needed to add using DiffEqSensitivity; second,
a = ann([t],p,st)[1];
A = [1.0 a; a -1.0];
should read
local a, _ = ann([t/T],p,st);
local A = [a[1] 0.0; 0.0 -a[1]];
After these changes, it worked!
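Putting the two fixes together, the corrected drift function reads roughly as follows (assembled from the snippets above; it relies on ann, st, and the constant T defined earlier in the script):

```julia
# Requires the earlier setup: ann, st, and the constant T.
function f_nn(u, p, t)
    a, _ = ann([t/T], p, st)     # note the rescaled time input t/T; the layer state is discarded
    A = [a[1] 0.0; 0.0 -a[1]]    # diagonal generator built from the network output
    return -im*A*u
end
```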