FluxML / Zygote.jl

Home Page: https://fluxml.ai/Zygote.jl/

Spurious "Output is complex, so the gradient is not defined" error

seadra opened this issue

commented

(Cross-posting this DiffEqFlux issue here, since this seems to be Zygote-related.)

With the code

using Lux, Random, DifferentialEquations, ComponentArrays, Optimization, OptimizationFlux, Zygote, LinearAlgebra, DiffEqFlux

const T = 10.0;
const ω = π/T;

ann = Lux.Chain(Lux.Dense(1,32), Lux.Dense(32,32,tanh), Lux.Dense(32,1));
rng = Random.default_rng();
ip, st = Lux.setup(rng, ann);

function f_nn(u, p, t)
    a = ann([t],p,st)[1];      # first element of the (output, state) tuple returned by the Lux network
    A = [1.0 a; a -1.0];       # time-dependent 2×2 coefficient matrix
    return -im*A*u;            # Schrödinger-type RHS: u' = -i*A*u
end


u0 = [Complex{Float64}(1) 0; 0 1];   # initial condition: 2×2 identity
tspan = (0.0, T);
prob_ode = ODEProblem(f_nn, u0, tspan, ip);

utarget = [Complex{Float64}(0) im; im 0];   # target: i times the Pauli-X gate

function predict_adjoint(p)
  return solve(prob_ode, Tsit5(), p=Complex{Float64}.(p), abstol=1e-12, reltol=1e-12)
end

function loss_adjoint(p)
    prediction = predict_adjoint(p)
    usol = last(prediction)
    loss = 1.0 - abs(tr(usol*utarget')/2)^2   # gate infidelity: 1 - |tr(utarget' * usol)/2|^2
    return loss
end

opt_f = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), Optimization.AutoZygote());
opt_prob = Optimization.OptimizationProblem(opt_f, ComponentArray(ip));
optimized_sol_nn = Optimization.solve(opt_prob, ADAM(0.1), maxiters = 100, progress=true);

I'm getting the following error:

┌ Warning: ZygoteVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add `verbose = false` to the `solve` call)
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:94
Output is complex, so the gradient is not defined.

┌ Warning: ReverseDiffVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add `verbose = false` to the `solve` call)
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:111
MethodError: no method matching gradient(::SciMLSensitivity.var"#258#264"{ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:1153, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32), NamedTuple())), bias = ViewAxis(33:33, ShapedAxis((1, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(f_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, ::Matrix{ComplexF64}, ::ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:1153, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32), NamedTuple())), bias = ViewAxis(33:33, ShapedAxis((1, 1), NamedTuple())))))}}})

Closest candidates are:
  gradient(::Any, ::Any)
   @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/api/gradients.jl:21
  gradient(::Any, ::Any, ::ReverseDiff.GradientConfig)
   @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/api/gradients.jl:21
┌ Warning: TrackerVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add `verbose = false` to the `solve` call)
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:129
Function output is not scalar

┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:139

MethodError: no method matching default_relstep(::Nothing, ::Type{ComplexF64})

Closest candidates are:
  default_relstep(::Type, ::Any)
   @ FiniteDiff ~/.julia/packages/FiniteDiff/grio1/src/epsilons.jl:25
  default_relstep(::Val{fdtype}, ::Type{T}) where {fdtype, T<:Number}
   @ FiniteDiff ~/.julia/packages/FiniteDiff/grio1/src/epsilons.jl:26

I tried explicitly casting the loss to a real value with return real(loss), but that didn't help either.
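
For reference, the attempted workaround looked roughly like this (a sketch; note that abs(...)^2 already produces a Float64, so the loss is real-valued even without the cast, which is consistent with it changing nothing):

function loss_adjoint(p)
    prediction = predict_adjoint(p)
    usol = last(prediction)
    loss = 1.0 - abs(tr(usol*utarget')/2)^2   # already a real scalar
    return real(loss)                         # no-op cast; the error persisted
end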

commented

A phrase in the error does appear in this package:

sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.")

Can you post the full stack trace, showing lines where any Zygote code is being called?

Or better, can you reproduce this error with just Zygote?
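
For context, sensitivity(y) builds the seed for the return value inside Zygote.gradient, so this error fires whenever the differentiated function itself returns a complex number. A minimal sketch of that failure mode, independent of any ODE solve:

using Zygote

Zygote.gradient(x -> im * x, 1.0)         # ERROR: Output is complex, so the gradient is not defined.
Zygote.gradient(x -> abs2(im * x), 1.0)   # fine: the output is a real scalar, returns (2.0,)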

commented

It turns out that I made a couple of mistakes. First, I needed to add using DiffEqSensitivity; second,

    a = ann([t],p,st)[1];
    A = [1.0 a; a -1.0];

should read

    local a, _ = ann([t/T],p,st);
    local A = [a[1] 0.0; 0.0 -a[1]];

After these changes, it worked!
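
Putting both fixes together, the working right-hand side looked roughly like this (a sketch assuming the same setup as above):

using DiffEqSensitivity   # loads the sensitivity algorithms used by the adjoint solve

function f_nn(u, p, t)
    local a, _ = ann([t/T], p, st);    # destructure (output, state) from the Lux network
    local A = [a[1] 0.0; 0.0 -a[1]];   # build A from the scalar a[1]
    return -im*A*u;
end

Relative to the original, the network is evaluated at the rescaled time t/T, its output is indexed with a[1], and A is built as a real diagonal matrix.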