dfdx / Ghost.jl

The Code Tracer

Key error with NNlib.DenseConvDims

jw3126 opened this issue · comments

julia> using Ghost

julia> using NNlib: DenseConvDims

julia> function f(x)
           DenseConvDims(zeros(2,2), zeros(2,3), stride=1)
           return x
       end
f (generic function with 2 methods)

julia> Ghost.trace(f, 1.0)
ERROR: KeyError: key 19 not found
Stacktrace:
  [1] getindex
    @ ./dict.jl:482 [inlined]
  [2] get_tape_vars(t::Ghost.IRTracer, farg_irvars::Vector{IRTools.Inner.Variable})
    @ Ghost ~/.julia/packages/Ghost/EcBCA/src/trace.jl:136
  [3] set_return!(t::Ghost.IRTracer, arg_sid_ref::Base.RefValue{IRTools.Inner.Variable})
    @ Ghost ~/.julia/packages/Ghost/EcBCA/src/trace.jl:187
  [4] IRTracer
    @ ~/.julia/packages/NNlib/zo8Ev/src/dim_helpers/DenseConvDims.jl:52 [inlined]
  [5] (::Ghost.IRTracer)(::NNlib.var"##DenseConvDims#7", ::Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:stride,), Tuple{Int64}}}, ::Type{DenseConvDims}, ::Matrix{Float64}, ::Matrix{Float64})
    @ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
  [6] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::Function, ::Vararg{Any, N} where N)
    @ Ghost ~/.julia/packages/Ghost/EcBCA/src/trace.jl:454
  [7] IRTracer
    @ ~/.julia/packages/NNlib/zo8Ev/src/dim_helpers/DenseConvDims.jl:49 [inlined]
  [8] (::Ghost.IRTracer)(::Core.var"#Type##kw", ::NamedTuple{(:stride,), Tuple{Int64}}, ::Type{DenseConvDims}, ::Matrix{Float64}, ::Matrix{Float64})
    @ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
  [9] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::Function, ::Vararg{Any, N} where N)
    @ Ghost ~/.julia/packages/Ghost/EcBCA/src/trace.jl:454
 [10] IRTracer
    @ ./REPL[9]:2 [inlined]
 [11] (::Ghost.IRTracer)(::typeof(f), ::Float64)
    @ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
 [12] trace(f::Function, args::Float64; is_primitive::Function, primitives::Nothing, ctx::Dict{Any, Any})
    @ Ghost ~/.julia/packages/Ghost/EcBCA/src/trace.jl:613
 [13] trace(f::Function, args::Float64)
    @ Ghost ~/.julia/packages/Ghost/EcBCA/src/trace.jl:603
 [14] top-level scope
    @ REPL[10]:1
 [15] top-level scope
    @ ~/.julia/packages/CUDA/Ozu5O/src/initialization.jl:52

Working on it.

Meanwhile, I wonder what you expect from tracing this code. Currently it recursively traces invocations down to functions like Core.apply_type, Base.indexed_iterate, etc., and generates a tape with ~100 operations. For example, in Avalon I block this recursion at the call to DenseConvDims(), getting a short differentiable tape. Would you mind sharing your use case?
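The "block the recursion" idea can be sketched like this. It assumes Ghost's `trace` accepts a `primitives` keyword (a keyword with that name is visible in the `trace` signature in the stacktrace above, though the exact container type it expects may differ between versions):

```julia
using Ghost
using NNlib: DenseConvDims

function f(x)
    DenseConvDims(zeros(2, 2), zeros(2, 3); stride=1)
    return x
end

# Treat the DenseConvDims constructor as a primitive, so the tracer
# records a single call on the tape instead of recursing into its
# internals (Core.apply_type, Base.indexed_iterate, etc.).
val, tape = Ghost.trace(f, 1.0; primitives=Set([DenseConvDims]))
```

With the constructor marked primitive, the resulting tape stays short instead of growing to ~100 operations.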

Fixed in #9

It is a long story. The short version is that I have a Flux model that eats too much GPU memory (see here for a toy example). I think it is very hard to fix this on the Zygote side, so I want to explore whether I can record a tape and add some ad hoc optimizations for my model that are hard to do in general (adding lots of unsafe_free calls). Currently I get all kinds of errors when I try to trace my model; what I report here came up while tracing parts of it.

Perhaps you don't need to trace things like DenseConvDims then. Instead, I'd pass all the NNlib functions as additional primitives to trace() and see what's generated. Another option is to differentiate the model with Zygote first and then trace the transformed function (with or without primitives from NNlib). I expect the closures created by Zygote to get "opened" and a nice linear tape to appear. This tape may then be optimizable to reduce the memory pressure.
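The differentiate-then-trace option might look roughly like the sketch below. The model here is a stand-in, and the `primitives` keyword and the `(value, tape)` return shape are assumptions based on the `trace` signature shown in the stacktrace above, not a verified API:

```julia
using Ghost
using Zygote
using NNlib

# Hypothetical stand-in for the real Flux model.
model(x) = sum(NNlib.relu.(x))

# Differentiate with Zygote first...
grad(x) = Zygote.gradient(model, x)[1]

# ...then trace the transformed function. Tracing should "open"
# Zygote's pullback closures into a linear tape, which can then be
# rewritten, e.g. to insert unsafe_free calls and lower memory use.
val, tape = Ghost.trace(grad, rand(3); primitives=Set([NNlib.relu]))
```

The appeal of this order is that the tape captures the full backward pass as flat, inspectable operations rather than nested closures.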

I will try all four combinations (Zygote / Yota, with and without custom NNlib primitives) and open issues when I hit problems. If you feel I am opening too many issues, or that a certain kind of issue is out of scope, please tell me. And thanks a lot for all the quick responses and fixes you have provided so far, and for creating Ghost/Yota in the first place. Being able to trace function calls is so powerful.

So far all the issues are relevant and very much appreciated!