Type unstable gradients
irisallevi commented
I copy-pasted these lines from the documentation (https://fluxml.ai/Zygote.jl/stable/#Explicit-and-Implicit-Parameters-1):
struct Linear
  W
  b
end
(l::Linear)(x) = l.W * x .+ l.b
model = Linear(rand(2, 5), rand(2))
x = rand(5)
Why does the gradient seem type unstable?
@code_warntype gradient(model -> sum(model(x)), model)
MethodInstance for Zygote.gradient(::var"#11#12", ::Linear)
  from gradient(f, args...) in Zygote at [...]
Arguments
  #self#::Core.Const(Zygote.gradient)
  f::Core.Const(var"#11#12"())
  args::Tuple{Linear}
Locals
  @_4::Int64
  grad::Union{Nothing, Tuple}
  back::Zygote.var"#75#76"
  y::Any
Body::Union{Nothing, Tuple{Any}}
1 ─ %1 = Core.tuple(f)::Core.Const((var"#11#12"(),))
│   %2 = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Tuple{Any, Zygote.var"#75#76"}
│   %3 = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│        (y = Core.getfield(%3, 1))
│        (@_4 = Core.getfield(%3, 2))
│   %6 = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#75#76", Int64}, Any[Zygote.var"#75#76", Core.Const(3)])
│        (back = Core.getfield(%6, 1))
│   %8 = Zygote.sensitivity(y)::Any
│        (grad = (back)(%8))
│   %10 = Zygote.isnothing(grad)::Bool
└──       goto #3 if not %10
2 ─       return Zygote.nothing
3 ─ %13 = Zygote.map(Zygote._project, args, grad::Tuple)::Tuple{Any}
└──       return %13
Brian Chen commented
Closing for the same reason as #1476: the question is answered on Discourse.
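For readers who land here without the Discourse thread: the answer given there is not reproduced in this issue, so what follows is only a sketch of two general Julia facts that are consistent with the `y::Any` in the output above. Fields declared without type annotations (`W`, `b`) are typed `Any`, and a non-const global such as `x` captured by a closure has no inferable type. The names `TypedLinear`, `typed_model`, `x_local`, and `loss` are hypothetical and do not come from the Zygote documentation. The sketch only checks inference of the forward pass with the standard-library `Test.@inferred`; it does not verify that `gradient` itself then infers.

```julia
using Test  # standard library; provides @inferred

# Hypothetical variant of the struct from the question, with parametric
# field types so that W and b have concrete types for each instance.
struct TypedLinear{M<:AbstractMatrix,V<:AbstractVector}
    W::M
    b::V
end

(l::TypedLinear)(x) = l.W * x .+ l.b

typed_model = TypedLinear(rand(2, 5), rand(2))
x_local = rand(5)

# Take x as an argument instead of capturing a non-const global in a closure.
loss(m, x) = sum(m(x))

# @inferred throws an error if the inferred return type of the call
# does not match the type of the value actually returned.
@inferred loss(typed_model, x_local)
```

If Zygote is installed, running `@code_warntype` on `gradient(loss, typed_model, x_local)` would be one way to compare against the output above.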