Gradient for the real part of a quantum state
ChenZhao44 opened this issue · comments
using Yao
include("zygote_patch.jl")
# One-qubit circuit holding a single Rx rotation on qubit 1, created with
# angle 0; `loss` later sets the real angle through `dispatch!`.
circ = put(1, 1 => Rx(0))
# Repro loss, working variant: writes the angle θ = α + β into `circ`, applies
# the circuit to the register from `zero_state(1)`, and returns
# imag.(amplitudes * im). That is mathematically equal to the real part of the
# amplitudes.
function loss(circ, α, β)
θ = α + β
# NOTE(review): keep the rebinding to dispatch!'s return value — presumably the
# gradient w.r.t. θ reaches α and β only through that value; confirm.
circ = dispatch!(circ, [θ]) # this works
ψ = zero_state(1)
ψ = apply!(ψ, circ)
# Multiplying by im keeps the cotangent complex; per this report, jacobian
# succeeds on this form.
return imag.(state(ψ)*im)
end
By using `jacobian((a, b) -> loss(circ, a, b), 1, 2)`, I can get the gradient.
However, if I change the loss function to
# Repro loss, failing variant: same as the working version but returns
# real.(amplitudes) directly. Backpropagating through it throws
# `InexactError: Float64(...)`: the stack trace shows apply!'s backward pass
# receiving a Float64 `ArrayReg` and applying the Rx rotation to a Float64 matrix.
function loss(circ, α, β)
θ = α + β
# NOTE(review): keep the rebinding to dispatch!'s return value — presumably the
# gradient w.r.t. θ reaches α and β only through that value; confirm.
circ = dispatch!(circ, [θ]) # this works
ψ = zero_state(1)
ψ = apply!(ψ, circ)
return real.(state(ψ))
end
It throws an error, although the two loss functions should be equivalent: `imag.(z * im)` equals `real.(z)` for any complex `z`.
┌ Warning: Input type of `ArrayReg` is not Complex, got Float64
└ @ YaoArrayRegister /Users/chenzhao/.julia/packages/YaoArrayRegister/UxWWn/src/register.jl:58
InexactError: Float64(0.0 + 0.9974949866040544im)
Stacktrace:
[1] Real
@ ./complex.jl:37 [inlined]
[2] convert
@ ./number.jl:7 [inlined]
[3] setindex!
@ ./array.jl:841 [inlined]
[4] u1rows!
@ ~/.julia/packages/YaoArrayRegister/UxWWn/src/utils.jl:122 [inlined]
[5] instruct_kernel
@ ~/.julia/packages/YaoArrayRegister/UxWWn/src/instruct.jl:195 [inlined]
[6] instruct!(state::Matrix{Float64}, #unused#::Val{:Rx}, ::Tuple{Int64}, theta::Float64)
@ YaoArrayRegister ~/.julia/packages/YaoArrayRegister/UxWWn/src/instruct.jl:379
[7] instruct!(r::ArrayReg{1, Float64, Matrix{Float64}}, op::Val{:Rx}, locs::Tuple{Int64}, theta::Float64)
@ YaoArrayRegister ~/.julia/packages/YaoArrayRegister/UxWWn/src/instruct.jl:52
[8] _apply!
@ ~/.julia/packages/YaoBlocks/TIlDJ/src/composite/put_block.jl:170 [inlined]
[9] apply!
@ ~/.julia/packages/YaoBlocks/TIlDJ/src/abstract_block.jl:10 [inlined]
[10] apply_back!(st::Tuple{ArrayReg{1, ComplexF64, Matrix{ComplexF64}}, ArrayReg{1, Float64, Matrix{Float64}}}, block::PutBlock{1, 1, RotationGate{1, Float64, XGate}}, collector::Vector{Any})
@ YaoBlocks.AD ~/.julia/packages/YaoBlocks/TIlDJ/src/autodiff/apply_back.jl:44
[11] #apply_back#14
@ ~/.julia/packages/YaoBlocks/TIlDJ/src/autodiff/apply_back.jl:151 [inlined]
[12] apply_back
@ ~/.julia/packages/YaoBlocks/TIlDJ/src/autodiff/apply_back.jl:150 [inlined]
[13] #81
@ ~/Desktop/YaoAD/zygote_patch.jl:8 [inlined]
[14] (::var"#326#back#83"{var"#81#82"{PutBlock{1, 1, RotationGate{1, Float64, XGate}}, ArrayReg{1, ComplexF64, Matrix{ComplexF64}}}})(Δ::ArrayReg{1, Float64, Matrix{Float64}})
@ Main ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[15] Pullback
@ ./In[67]:5 [inlined]
[16] (::typeof(∂(loss)))(Δ::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[17] Pullback
@ ./In[68]:2 [inlined]
[18] #180
@ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194 [inlined]
[19] #1689#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[20] Pullback
@ ./operators.jl:938 [inlined]
[21] (::Zygote.var"#41#42"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), var"#179#180"}(Zygote._jvec, var"#179#180"())))})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
[22] jacobian(::Function, ::Int64, ::Int64)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/grad.jl:148
[23] top-level scope
@ In[68]:2
[24] eval
@ ./boot.jl:360 [inlined]
[25] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:1116
Thanks for the issue. I just ported the AD rules to ChainRulesCore; can you try the patch in the latest master of QuAlgorithmZoo?
using Yao
include("chainrules_patch.jl")
using Zygote
# One-qubit circuit holding a single Rx rotation on qubit 1, created with
# angle 0; `loss` later sets the real angle through `dispatch!`.
circ = put(1, 1 => Rx(0))
# Check against the ChainRulesCore-based patch: writes θ = α + β into `circ`,
# applies it to the register from `zero_state(1)`, and returns
# real.(amplitudes * im). Mathematically that is -imag.(amplitudes).
# NOTE(review): this differs from the issue's working variant, which returned
# imag.(state(ψ)*im); both keep the cotangent complex.
function loss(circ, α, β)
θ = α + β
# NOTE(review): keep the rebinding to dispatch!'s return value — presumably the
# gradient w.r.t. θ reaches α and β only through that value; confirm.
circ = dispatch!(circ, [θ]) # this works
ψ = zero_state(1)
ψ = apply!(ψ, circ)
return real.(state(ψ)*im)
end
# Differentiate `loss` with respect to both angle inputs, evaluated at α = 1 and β = 2.
jacobian((α, β) -> loss(circ, α, β), 1, 2)
I still get errors when trying:
# Repro loss, failing variant: identical to the working version except that it
# returns real.(state(ψ)) directly. Returning imag.(state(ψ) * im) instead works,
# although the two expressions are mathematically equal.
function loss(circ, α, β)
    θ = α + β
    # NOTE(review): keep the rebinding to dispatch!'s return value — presumably
    # the gradient w.r.t. θ reaches α and β only through that value; confirm.
    circ = dispatch!(circ, [θ]) # this works
    ψ = zero_state(1)
    ψ = apply!(ψ, circ)
    return real.(state(ψ)) # error
end
# Differentiate the `loss` defined just above with respect to α and β.
jacobian((a, b) -> loss(circ, a, b), 1, 2)
`imag.(state(ψ) * im)` should be equivalent to `real.(state(ψ))`.
┌ Warning: Input type of `ArrayReg` is not Complex, got Float64
└ @ YaoArrayRegister /Users/chenzhao/.julia/packages/YaoArrayRegister/UxWWn/src/register.jl:58
InexactError: Float64(0.0 + 0.9974949866040544im)
Stacktrace:
[1] Real
@ ./complex.jl:37 [inlined]
[2] convert
@ ./number.jl:7 [inlined]
[3] setindex!
@ ./array.jl:841 [inlined]
[4] u1rows!
@ ~/.julia/packages/YaoArrayRegister/UxWWn/src/utils.jl:122 [inlined]
[5] instruct_kernel
@ ~/.julia/packages/YaoArrayRegister/UxWWn/src/instruct.jl:195 [inlined]
[6] instruct!(state::Matrix{Float64}, #unused#::Val{:Rx}, ::Tuple{Int64}, theta::Float64)
@ YaoArrayRegister ~/.julia/packages/YaoArrayRegister/UxWWn/src/instruct.jl:379
[7] instruct!(r::ArrayReg{1, Float64, Matrix{Float64}}, op::Val{:Rx}, locs::Tuple{Int64}, theta::Float64)
@ YaoArrayRegister ~/.julia/packages/YaoArrayRegister/UxWWn/src/instruct.jl:52
[8] _apply!
@ ~/.julia/packages/YaoBlocks/TIlDJ/src/composite/put_block.jl:170 [inlined]
[9] apply!
@ ~/.julia/packages/YaoBlocks/TIlDJ/src/abstract_block.jl:10 [inlined]
[10] apply_back!(st::Tuple{ArrayReg{1, ComplexF64, Matrix{ComplexF64}}, ArrayReg{1, Float64, Matrix{Float64}}}, block::PutBlock{1, 1, RotationGate{1, Float64, XGate}}, collector::Vector{Any})
@ YaoBlocks.AD ~/.julia/packages/YaoBlocks/TIlDJ/src/autodiff/apply_back.jl:44
[11] #apply_back#14
@ ~/.julia/packages/YaoBlocks/TIlDJ/src/autodiff/apply_back.jl:151 [inlined]
[12] apply_back
@ ~/.julia/packages/YaoBlocks/TIlDJ/src/autodiff/apply_back.jl:150 [inlined]
[13] (::var"#1#2"{PutBlock{1, 1, RotationGate{1, Float64, XGate}}, ArrayReg{1, ComplexF64, Matrix{ComplexF64}}})(outδ::ArrayReg{1, Float64, Matrix{Float64}})
@ Main ~/Desktop/YaoAD/chainrules_patch.jl:7
[14] ZBack
@ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:77 [inlined]
[15] Pullback
@ ./In[3]:5 [inlined]
[16] (::typeof(∂(loss2)))(Δ::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[17] Pullback
@ ./In[3]:8 [inlined]
[18] #180
@ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194 [inlined]
[19] #1729#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[20] Pullback
@ ./operators.jl:938 [inlined]
[21] (::Zygote.var"#41#42"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), var"#62#63"}(Zygote._jvec, var"#62#63"())))})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
[22] jacobian(::Function, ::Int64, ::Int64)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/grad.jl:148
[23] top-level scope
@ In[3]:8
[24] eval
@ ./boot.jl:360 [inlined]
[25] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:1116
Maybe it is a problem with a Zygote rule; can you check your Zygote version?
You should not get this warning:
┌ Warning: Input type of `ArrayReg` is not Complex, got Float64
My version is
(@v1.7) pkg> st Zygote
Status `~/.julia/environments/v1.7/Project.toml`
[e88e6eb3] Zygote v0.6.19
I still have this warning after updating Zygote to v0.6.19.
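I see: the problem is the gradient of the `real` function. It returns a real (`Float64`) gradient even for complex input, so the backward pass ends up applying the complex `Rx` rotation to a `Float64` register: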
julia> gradient(x -> real.(x)[1], [1+2.0im])
([1.0],)
I made a quick fix to `chainrules_patch.jl`; it should work now.