QuantumBFS / Yao.jl

Extensible, Efficient Quantum Algorithm Design for Humans.

Home Page: https://yaoquantum.org

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Gradient for the real part of a quantum state

ChenZhao44 opened this issue · comments

using Yao
include("zygote_patch.jl")
# One-qubit circuit: an Rx rotation placed on qubit 1, with initial angle 0.
circ = put(1, 1=> Rx(0))
# Loss used to reproduce the AD issue.
#   circ : the parameterized circuit above; its rotation angle is overwritten
#   α, β : scalars whose sum becomes the rotation angle θ
# Returns imag(z * im) elementwise over the final amplitudes z, which equals real(z).
# Differentiating this form with `jacobian` succeeds.
function loss(circ, α, β)
    θ = α + β
    # Set the circuit's rotation angle to θ.
    circ = dispatch!(circ, [θ]) # this works
    # Fresh one-qubit register in the zero state.
    ψ = zero_state(1)
    ψ = apply!(ψ, circ)
    # imag(z * im) == real(z) for complex z; written this way because it differentiates fine.
    return imag.(state(ψ)*im)
end

Calling `jacobian((a, b) -> loss(circ, a, b), 1, 2)` returns the gradient as expected.

However, if I change the loss function to

# Same loss, but returning real.(state(ψ)) directly. Mathematically this equals
# imag.(state(ψ)*im), yet `jacobian` fails on it (see the trace below).
# The adjoint that reaches `apply!` is a Float64 register. Applying the Rx gate
# to it writes a complex amplitude into a Float64 array, which raises `InexactError`.
function loss(circ, α, β)
    θ = α + β
    circ = dispatch!(circ, [θ]) # this works
    ψ = zero_state(1)
    ψ = apply!(ψ, circ)
    return real.(state(ψ))
end

It throws the error below, even though the two functions are mathematically equivalent: for any complex `z`, `imag(z * im) == real(z)`.

┌ Warning: Input type of `ArrayReg` is not Complex, got Float64
└ @ YaoArrayRegister /Users/chenzhao/.julia/packages/YaoArrayRegister/UxWWn/src/register.jl:58
InexactError: Float64(0.0 + 0.9974949866040544im)

Stacktrace:
  [1] Real
    @ ./complex.jl:37 [inlined]
  [2] convert
    @ ./number.jl:7 [inlined]
  [3] setindex!
    @ ./array.jl:841 [inlined]
  [4] u1rows!
    @ ~/.julia/packages/YaoArrayRegister/UxWWn/src/utils.jl:122 [inlined]
  [5] instruct_kernel
    @ ~/.julia/packages/YaoArrayRegister/UxWWn/src/instruct.jl:195 [inlined]
  [6] instruct!(state::Matrix{Float64}, #unused#::Val{:Rx}, ::Tuple{Int64}, theta::Float64)
    @ YaoArrayRegister ~/.julia/packages/YaoArrayRegister/UxWWn/src/instruct.jl:379
  [7] instruct!(r::ArrayReg{1, Float64, Matrix{Float64}}, op::Val{:Rx}, locs::Tuple{Int64}, theta::Float64)
    @ YaoArrayRegister ~/.julia/packages/YaoArrayRegister/UxWWn/src/instruct.jl:52
  [8] _apply!
    @ ~/.julia/packages/YaoBlocks/TIlDJ/src/composite/put_block.jl:170 [inlined]
  [9] apply!
    @ ~/.julia/packages/YaoBlocks/TIlDJ/src/abstract_block.jl:10 [inlined]
 [10] apply_back!(st::Tuple{ArrayReg{1, ComplexF64, Matrix{ComplexF64}}, ArrayReg{1, Float64, Matrix{Float64}}}, block::PutBlock{1, 1, RotationGate{1, Float64, XGate}}, collector::Vector{Any})
    @ YaoBlocks.AD ~/.julia/packages/YaoBlocks/TIlDJ/src/autodiff/apply_back.jl:44
 [11] #apply_back#14
    @ ~/.julia/packages/YaoBlocks/TIlDJ/src/autodiff/apply_back.jl:151 [inlined]
 [12] apply_back
    @ ~/.julia/packages/YaoBlocks/TIlDJ/src/autodiff/apply_back.jl:150 [inlined]
 [13] #81
    @ ~/Desktop/YaoAD/zygote_patch.jl:8 [inlined]
 [14] (::var"#326#back#83"{var"#81#82"{PutBlock{1, 1, RotationGate{1, Float64, XGate}}, ArrayReg{1, ComplexF64, Matrix{ComplexF64}}}})(Δ::ArrayReg{1, Float64, Matrix{Float64}})
    @ Main ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [15] Pullback
    @ ./In[67]:5 [inlined]
 [16] (::typeof(∂(loss)))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [17] Pullback
    @ ./In[68]:2 [inlined]
 [18] #180
    @ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194 [inlined]
 [19] #1689#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [20] Pullback
    @ ./operators.jl:938 [inlined]
 [21] (::Zygote.var"#41#42"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), var"#179#180"}(Zygote._jvec, var"#179#180"())))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [22] jacobian(::Function, ::Int64, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/grad.jl:148
 [23] top-level scope
    @ In[68]:2
 [24] eval
    @ ./boot.jl:360 [inlined]
 [25] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1116

Thanks for the issue. I just ported the AD rules to ChainRulesCore. Can you try the patch in the latest master of QuAlgorithmZoo?

using Yao
include("chainrules_patch.jl")
using Zygote
# One-qubit circuit with an Rx rotation on qubit 1 (angle 0 until dispatched).
circ = put(1, 1=> Rx(0))
# Maintainer's check of the ChainRulesCore patch.
# NOTE(review): real(z * im) == -imag(z), so this returns -imag of the amplitudes.
# That is not the same quantity as the reporter's imag.(state(ψ)*im), which equals real(z).
function loss(circ, α, β)
    θ = α + β
    circ = dispatch!(circ, [θ]) # this works
    ψ = zero_state(1)
    ψ = apply!(ψ, circ)
    return real.(state(ψ)*im)
end
# Jacobian of the loss with respect to α and β, evaluated at α = 1, β = 2.
jacobian((a, b) -> loss(circ, a, b), 1, 2)

I still get an error when trying:

# Reproducer after switching to the ChainRulesCore patch.
#   circ : one-qubit parameterized circuit (assumes the put(1, 1=> Rx(0)) circuit
#          from the earlier snippet is in scope)
#   α, β : scalars whose sum is the rotation angle θ
# Returns real.(state(ψ)), which is mathematically equal to imag.(state(ψ) * im).
function loss(circ, α, β)
    θ = α + β
    circ = dispatch!(circ, [θ]) # this works
    ψ = zero_state(1)
    ψ = apply!(ψ, circ)
    # Equivalent form that differentiates fine: return imag.(state(ψ) * im)
    return real.(state(ψ)) # errors under `jacobian`: a Float64 adjoint reaches apply!
end
# Call the function defined above; the name must match the definition.
jacobian((a, b) -> loss(circ, a, b), 1, 2)

`imag.(state(ψ) * im)` should be equivalent to `real.(state(ψ))`.

┌ Warning: Input type of `ArrayReg` is not Complex, got Float64
└ @ YaoArrayRegister /Users/chenzhao/.julia/packages/YaoArrayRegister/UxWWn/src/register.jl:58
InexactError: Float64(0.0 + 0.9974949866040544im)

Stacktrace:
  [1] Real
    @ ./complex.jl:37 [inlined]
  [2] convert
    @ ./number.jl:7 [inlined]
  [3] setindex!
    @ ./array.jl:841 [inlined]
  [4] u1rows!
    @ ~/.julia/packages/YaoArrayRegister/UxWWn/src/utils.jl:122 [inlined]
  [5] instruct_kernel
    @ ~/.julia/packages/YaoArrayRegister/UxWWn/src/instruct.jl:195 [inlined]
  [6] instruct!(state::Matrix{Float64}, #unused#::Val{:Rx}, ::Tuple{Int64}, theta::Float64)
    @ YaoArrayRegister ~/.julia/packages/YaoArrayRegister/UxWWn/src/instruct.jl:379
  [7] instruct!(r::ArrayReg{1, Float64, Matrix{Float64}}, op::Val{:Rx}, locs::Tuple{Int64}, theta::Float64)
    @ YaoArrayRegister ~/.julia/packages/YaoArrayRegister/UxWWn/src/instruct.jl:52
  [8] _apply!
    @ ~/.julia/packages/YaoBlocks/TIlDJ/src/composite/put_block.jl:170 [inlined]
  [9] apply!
    @ ~/.julia/packages/YaoBlocks/TIlDJ/src/abstract_block.jl:10 [inlined]
 [10] apply_back!(st::Tuple{ArrayReg{1, ComplexF64, Matrix{ComplexF64}}, ArrayReg{1, Float64, Matrix{Float64}}}, block::PutBlock{1, 1, RotationGate{1, Float64, XGate}}, collector::Vector{Any})
    @ YaoBlocks.AD ~/.julia/packages/YaoBlocks/TIlDJ/src/autodiff/apply_back.jl:44
 [11] #apply_back#14
    @ ~/.julia/packages/YaoBlocks/TIlDJ/src/autodiff/apply_back.jl:151 [inlined]
 [12] apply_back
    @ ~/.julia/packages/YaoBlocks/TIlDJ/src/autodiff/apply_back.jl:150 [inlined]
 [13] (::var"#1#2"{PutBlock{1, 1, RotationGate{1, Float64, XGate}}, ArrayReg{1, ComplexF64, Matrix{ComplexF64}}})(outδ::ArrayReg{1, Float64, Matrix{Float64}})
    @ Main ~/Desktop/YaoAD/chainrules_patch.jl:7
 [14] ZBack
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:77 [inlined]
 [15] Pullback
    @ ./In[3]:5 [inlined]
 [16] (::typeof((loss2)))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [17] Pullback
    @ ./In[3]:8 [inlined]
 [18] #180
    @ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194 [inlined]
 [19] #1729#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [20] Pullback
    @ ./operators.jl:938 [inlined]
 [21] (::Zygote.var"#41#42"{typeof((ComposedFunction{typeof(Zygote._jvec), var"#62#63"}(Zygote._jvec, var"#62#63"())))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [22] jacobian(::Function, ::Int64, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/grad.jl:148
 [23] top-level scope
    @ In[3]:8
 [24] eval
    @ ./boot.jl:360 [inlined]
 [25] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1116

Maybe the problem is a Zygote rule. Can you check your Zygote version?
You should not be seeing this warning:

┌ Warning: Input type of `ArrayReg` is not Complex, got Float64

My version is

(@v1.7) pkg> st Zygote
      Status `~/.julia/environments/v1.7/Project.toml`
  [e88e6eb3] Zygote v0.6.19

I still have this warning after updating Zygote to v0.6.19.

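I see. The problem is the gradient of the `real` function. For a complex input, Zygote returns a real-valued gradient (example below). As a result, the adjoint passed back through the circuit is a `Float64` register instead of a complex one, and applying the gate to it fails.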

julia> gradient(x -> real.(x)[1], [1+2.0im])
([1.0],)

I made a quick fix to `chainrules_patch.jl`; it should work now.