dfdx / Umlaut.jl

The Code Tracer

Tracing static parameters

cscherrer opened this issue

Lowered code sometimes contains :($(Expr(:static_parameter, 1))), which seems to confuse Umlaut. The compiler's use of :static_parameter is new to me, so I don't yet have any suggestions for the right way to handle this.

There are probably even simpler examples where this problem comes up, but I'm seeing it when tracing a Tilde.jl model. This works fine:

julia> m = @model begin
           x ~ Normal()
       end;

julia> r = rand(m())
(x = -0.8876255812527766,)

julia> logdensityof(m(), r)
-1.3128781194518375

But tracing this last call with Umlaut fails with:

julia> trace(logdensityof, m(), r)

ERROR: MethodError: no method matching getproperty(::NamedTuple{(:x,), Tuple{Float64}}, ::Expr)
Closest candidates are:
  getproperty(::Any, ::Symbol) at ~/julia/julia-1.8.0-beta1/share/julia/base/Base.jl:38
  getproperty(::Any, ::Symbol, ::Symbol) at ~/julia/julia-1.8.0-beta1/share/julia/base/Base.jl:50
Stacktrace:
  [1] mkcall(::Function, ::Variable, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/git/Umlaut.jl/src/tape.jl:196
  [2] mkcall(::Function, ::Variable, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/tape.jl:180
  [3] record_primitive!(::Tape{Umlaut.BaseCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:137
  [4] record_or_recurse!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:194
  [5] trace!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Core.CodeInfo, ::Variable, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:212
  [6] record_or_recurse!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Variable, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:196
  [7] trace!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:212
  [8] record_or_recurse!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:196
  [9] trace!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:212
 [10] record_or_recurse!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:196
 [11] trace!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:212
 [12] record_or_recurse!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:196
 [13] trace!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:212
 [14] record_or_recurse!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:196
 [15] trace!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Core.CodeInfo, ::Variable, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:212
 [16] trace(::Function, ::Tilde.ModelClosure{Model{NamedTuple{()}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{40}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(), Tuple{}}}, ::Vararg{Any}; ctx::Umlaut.BaseCtx, fargtypes::Nothing, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:344
 [17] trace(::Function, ::Tilde.ModelClosure{Model{NamedTuple{()}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{40}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(), Tuple{}}}, ::NamedTuple{(:x,), Tuple{Float64}})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:334
 [18] top-level scope
    @ REPL[91]:1

If I change mkcall to

function mkcall(fn, args...; val=missing, kwargs...)
    kwargs = NamedTuple(kwargs)
    if !isempty(kwargs)
        args = (kwargs, fn, args...)
        fn = Core.kwfunc(fn)
    end
    fargs = (fn, args...)
    calculable = all(
        a -> !isa(a, Variable) ||                      # not variable
        (a._op !== nothing && a._op.val !== missing),  # bound variable
        fargs
    )
    if val === missing && calculable
        fargs_ = map_vars(v -> v._op.val, fargs)
        fn_, args_ = fargs_[1], fargs_[2:end]
        @show fn_
        @show args_
        val_ = fn_(args_...)
    else
        val_ = val
    end
    return Call(0, val_, fn, [args...])
end

I see that, when the error is thrown, we have

fn_ = getproperty
args_ = ((x = -0.8876255812527766,), :($(Expr(:static_parameter, 1))))

This makes me think Umlaut needs some extra code to handle :($(Expr(:static_parameter, 1))) as a special case.
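As a rough illustration of where these expressions come from, the sketch below (helper names `h` and `find_static_params` are hypothetical) lowers a tiny method with a `where T` parameter using Base's `code_lowered` and collects every `Expr(:static_parameter, i)` in its statements. The index `i` refers to the method's `i`-th `where` parameter, so here it should point at `T`.

```julia
# A method with one static parameter, T
h(x::T) where {T} = T(2) * x

# Recursively collect every Expr(:static_parameter, i) inside a lowered statement;
# non-Expr statements (slots, SSA values, return nodes) are skipped
function find_static_params(ex, acc = Expr[])
    if ex isa Expr
        ex.head === :static_parameter && push!(acc, ex)
        for a in ex.args
            find_static_params(a, acc)
        end
    end
    return acc
end

ci = only(code_lowered(h, (Float64,)))
found = reduce(vcat, [find_static_params(st) for st in ci.code])
println(found)   # the static_parameter expressions found in h's lowered body
```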

Yeah, it's a known issue that often comes up in low-level built-in functions, e.g. try trace(sin, 2.0). Such expressions are part of the lowered code that Umlaut parses, for example:

julia> @code_lowered sin(2.0)
CodeInfo(
1 ─      Core.NewvarNode(:(@_3))
│          Core.NewvarNode(:(y))
│          Core.NewvarNode(:(n))
│          absx = Base.Math.abs(x)
│    %5  = absx
│    %6  = ($(Expr(:static_parameter, 1)))(Base.Math.pi)
│    %7  = %6 / 4
│    ...
│    %11 = Base.Math.eps($(Expr(:static_parameter, 1)))
│    %12 = Base.Math.sqrt(%11)
│    %13 = %10 < %12
...
)

The best doc on :static_parameter that I found is this issue, from which I infer it's an accessor to some static field of a MethodInstance, but I'll need to ask the devs for more details. Usually, I don't dive that deep and simply make the low-level function like sin a primitive so that Umlaut doesn't try to parse it. Is it an option for you in this case?
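Assuming the concrete values of a method's `where` parameters are known when a frame is traced (presumably from the `sparam_vals` field of the `Core.MethodInstance` in question), one way to special-case these expressions is to substitute the values before evaluating a statement. The helper `subst_static_params` below is a hypothetical sketch of that idea, not Umlaut's actual fix:

```julia
# Replace every Expr(:static_parameter, i) in `ex` with sparam_vals[i].
# `sparam_vals` is assumed to hold the concrete values of the method's
# `where` parameters, in declaration order.
subst_static_params(ex, sparam_vals) = ex   # leaves (symbols, literals) pass through
function subst_static_params(ex::Expr, sparam_vals)
    if ex.head === :static_parameter
        return sparam_vals[ex.args[1]]
    end
    return Expr(ex.head, (subst_static_params(a, sparam_vals) for a in ex.args)...)
end

# Roughly the body of f(::Val{N}) where N = N + 1, with N as static parameter 1
stmt = Expr(:call, :+, Expr(:static_parameter, 1), 1)
resolved = subst_static_params(stmt, (1,))   # N = 1 when called with Val(1)
println(resolved)
println(eval(resolved))
```

For `f(Val(1))` the parameter tuple is `(1,)`, so the statement becomes `1 + 1` and evaluates to `2` instead of reaching `+` with an `Expr` argument.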

Thanks @dfdx, I just asked about this here:
https://discourse.julialang.org/t/what-does-expr-static-parameter-1-do/79302

> Usually, I don't dive that deep and simply make the low-level function like sin a primitive so that Umlaut doesn't try to parse it. Is it an option for you in this case?

Good question. I'm really not sure. If Umlaut tracing throws an error, is there a way to get the stack it was working with? I don't mean the one with all the Umlaut calls, but the stack as Umlaut sees it.

> If Umlaut tracing throws an error, is there a way to get the stack it was working with?

No, but we should certainly have a mechanism for that :) Give me a day to think through a proper design that won't turn the code into a mess.

In the call-stack branch we now have a get_latest_tracer() function that, in case of an error, returns the latest state of the internal tracer, including the whole call stack so far. Here's a related example:

julia> using Umlaut
[ Info: Precompiling Umlaut [92992a2b-8ce5-4a9c-bb9d-58be9a7dc841]

julia> f(::Val{N}) where N = N + 1
f (generic function with 1 method)

julia> g(x) = f(Val(x))
g (generic function with 1 method)

julia> trace(g, 1)
ERROR: MethodError: no method matching +(::Expr, ::Int64)
  ...

julia> t = Umlaut.get_latest_tracer()
Umlaut.Tracer{Umlaut.BaseCtx}(Tape{Umlaut.BaseCtx}
  inp %1::typeof(g)
  inp %2::Int64
  %3 = Val(%2)::Val{1}
, Umlaut.Frame[Frame(
  %1 => %3
  _1 => %1
  _2 => %2
), Frame(
  _1 => f
  _2 => %3
)])

julia> t.tape                       # tape at the moment of the error
Tape{Umlaut.BaseCtx}
  inp %1::typeof(g)
  inp %2::Int64
  %3 = Val(%2)::Val{1}


julia> t.stack[end].v_fargs   # function and arguments being traced as constants or variables on the tape
(f, %3)                                  # e.g. `f` is recorded as is

julia> t.stack[end].ci             # CodeInfo object, roughly equal to @code_lowered f(args...)
CodeInfo(
1 ─ %1 = $(Expr(:static_parameter, 1)) + 1
└──      return %1
)
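For readers curious how a get_latest_tracer()-style mechanism can work in principle, here is a minimal sketch with hypothetical names (`MiniTracer`, `mini_trace`, `mini_call!`, `LATEST_TRACER`); it is not Umlaut's implementation. The idea is that the tracer is stored in a global slot before tracing starts, and frames are popped only on success, so after an exception the stack still shows where tracing failed.

```julia
# A frame records the function and arguments being traced
struct MiniFrame
    fargs::Tuple
end

mutable struct MiniTracer
    stack::Vector{MiniFrame}
end

# Global slot holding the most recent tracer, so it survives an exception
const LATEST_TRACER = Ref{Union{MiniTracer, Nothing}}(nothing)
get_latest_mini_tracer() = LATEST_TRACER[]

# Start a fresh trace of f(args...)
function mini_trace(f, args...)
    t = MiniTracer(MiniFrame[])
    LATEST_TRACER[] = t
    return mini_call!(t, f, args...)
end

# Push a frame, run the call, and pop the frame only on success:
# if f throws, the failing frame stays on the stack for inspection
function mini_call!(t::MiniTracer, f, args...)
    push!(t.stack, MiniFrame((f, args...)))
    result = f(args...)
    pop!(t.stack)
    return result
end

# Usage: nested calls go through mini_call! so they show up as frames
inner(x) = error("failed on $x")
outer(x) = mini_call!(get_latest_mini_tracer(), inner, x + 1)

try
    mini_trace(outer, 1)
catch err
    println("caught: ", err)
end
for fr in get_latest_mini_tracer().stack
    println(fr.fargs)   # frames for outer(1) and then inner(2)
end
```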

I haven't tried it on Tilde.jl yet, but you should be able to walk through the call stack and see whether tracing dives too deep:

for frame in t.stack
    v_f = frame.v_fargs[1]
    fn = v_f isa Variable ? t.tape[v_f].val : v_f
    println(fn)
end

A fix for static_parameter itself is on the way.

Sounds great, I'll try it on Tilde and let you know how it goes. Thanks for the quick update!

It works, but I don't understand Umlaut tapes well enough to use the result effectively. Can I get to something that's like a typical stack trace?

Or maybe a better question... how would you recommend determining which functions to tell Umlaut to treat as primitive? I was thinking that nested call information you'd get in a stack trace would make it easier to judge the right depth to use as a cut-off point. But maybe there's a better way to go about this?

t.stack is exactly the stack trace so far, just with the ability to see concrete values collected to the tape :) Let me expand the previous example a little bit:

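using Umlaut

# Print the call stack of the latest (failed) trace, innermost frame first
function print_stack_trace()
    t = Umlaut.get_latest_tracer()
    for (i, frame) in enumerate(reverse(t.stack))
        # resolve tape variables to the values recorded for them
        fn, args... = [v isa Umlaut.Variable ? t.tape[v].val : v for v in frame.v_fargs]
        meth = which(fn, map(typeof, args))
        println("[$i] $meth")
        # println("  @ $(meth.module) $(meth.file)")
    end
end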

f(::Val{N}) where N = N + 1
g(x) = f(Val(x))
trace(g, 1)

print_stack_trace()

which prints:

[1] f(::Val{N}) where N in Main at REPL[40]:1
[2] g(x) in Main at REPL[41]:1

meaning that the error happened while tracing f(::Val{N}). If you feel like f() is too low-level for your use case, you can go through this stack and choose the needed level. Otherwise, report an issue and I'll try to make it work for this function too :)
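One possible heuristic for choosing that level, sketched with a hypothetical helper `primitive_candidate` (not an Umlaut feature): walk the functions on the stack from the outermost frame inward and pick the first one not defined in your own modules, since that is where tracing leaves your code. Judging by print_stack_trace, which reverses `t.stack` to print the innermost frame first, `t.stack` itself is ordered outermost first, so the functions could be collected from it in order.

```julia
# Given the functions on a traced call stack, ordered outermost first,
# return the first one whose parent module is not in `own_modules`.
# Heuristic (an assumption): that function is where tracing leaves your
# own code, so it is a natural candidate to mark as primitive.
function primitive_candidate(fns, own_modules)
    for fn in fns
        parentmodule(fn) in own_modules || return fn
    end
    return nothing   # every frame belongs to own_modules
end

# Hypothetical stack: user_fn (user code) calls sin (Base), which calls abs
user_fn(x) = sin(x)
println(primitive_candidate([user_fn, sin, abs], (@__MODULE__,)))
```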

(I tried to run this code on Tilde.jl, but logdensityof(m(), r) gives me UndefVarError: partialstatic not defined at the moment)

I've just pushed a fix for static parameters to the call-stack branch. It works on the example above as well as on trace(sin, 2.0), but I'd be grateful if you could confirm it also works for your example in Tilde.jl.

It's getting stuck at this line:

val_ = fn_(args_...)

For one run I'm getting

fn_ = Core._apply_iterate
args_ = (iterate, CompositionsBase.opcompose, (@optic _[1]))    # `@optic` is from Accessors.jl

So calling fn_(args_...) crashes
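`Core._apply_iterate` is the builtin that lowering emits for splatted calls: `f(xs...)` becomes `Core._apply_iterate(iterate, f, xs)`, which iterates each trailing argument and passes all the elements to `f`. Read that way, the failing call above amounts to `opcompose((@optic _[1])...)`, i.e. splatting the optic itself. A small check of these semantics in plain Julia (`splat_sum` is a hypothetical example function):

```julia
# Trailing arguments are iterated and their elements passed to the function:
# this is equivalent to +(1, 2, 3)
println(Core._apply_iterate(iterate, +, (1, 2), (3,)))

# A splatted call in ordinary code lowers to the same builtin
splat_sum(xs) = +(xs...)
ci = only(code_lowered(splat_sum, (Tuple{Int, Int},)))
println(ci)   # the lowered body contains a Core._apply_iterate call
```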

Do you have a git hash / reproducible example that I can debug?

No, but I'll get one together

Ok, just sent it to you on Zulip. It has the toml files pointing to specific commits, and a test.jl file

I see that this new error is not related to static parameters, but instead is a new independent bug. Simpler way to reproduce it (for myself mostly):

record_or_recurse!(Tracer(Tape(BaseCtx())), Accessors.opticcompose, Accessors.@optic _[1])

The bug is not an easy one, so I'll need some time to fix it. Meanwhile, you may want to add some of the functions in Tilde to the list of primitives. E.g.:

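using Umlaut, Accessors, Tilde

struct MyCtx end

# Treat everything BaseCtx treats as primitive, plus a few Accessors/Tilde functions
function Umlaut.isprimitive(::MyCtx, f, args...)
    return Umlaut.isprimitive(Umlaut.BaseCtx(), f, args...) ||
        f in [
            Accessors.opticcompose,
            Tilde.tilde,
            Tilde.known
        ]
end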

...

trace(logdensityof, m(x), r; ctx=MyCtx())

This doesn't excuse the bugs when tracing these functions, but maybe it will unblock some of your activities while I'm fixing them.

I'm finding that trace calls string, and for some values this fails with a MethodError. Should this be sprint?

julia> string(True())
ERROR: MethodError: no method matching True(::Int64)
Closest candidates are:
  (::Type{T})(::T) where T<:Number at ~/julia/julia-1.8.0-beta1/share/julia/base/boot.jl:772
  True() at ~/.julia/packages/Static/8hh0B/src/bool.jl:9
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at ~/julia/julia-1.8.0-beta1/share/julia/base/char.jl:50
  ...
Stacktrace:
 [1] convert(#unused#::Type{True}, x::Int64)
   @ Base ./number.jl:7
 [2] oftype(x::True, y::Int64)
   @ Base ./essentials.jl:391
 [3] zero(x::True)
   @ Base ./number.jl:296
 [4] unsigned(x::True)
   @ Base ./int.jl:208
 [5] split_sign(n::True)
   @ Base ./intfuncs.jl:771
 [6] string(n::True; base::Int64, pad::Int64)
   @ Base ./intfuncs.jl:800
 [7] string(n::True)
   @ Base ./intfuncs.jl:792
 [8] top-level scope
   @ REPL[24]:1

julia> sprint(show, True())
"static(true)"
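To illustrate why sprint is the safer choice for displaying arbitrary traced values, here is a sketch with a hypothetical helper `value_repr` (the actual fix isn't shown in this thread). `sprint(show, x)`, equivalent to `repr(x)`, goes through `show`, which has a generic fallback for every type, whereas `string(x)` can dispatch to type-specific methods such as the integer-formatting path that fails for Static.jl's `True` above.

```julia
# Display helper based on `show` rather than `string`
value_repr(x) = sprint(show, x)

println(value_repr(1.5))      # 1.5
println(value_repr("a"))      # "a" (strings are shown quoted, unlike string("a"))
println(value_repr([1, 2]))   # [1, 2]
```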

Think I have a fix, I'll PR soon

I've just pushed the fix to call-stack branch, the following now works:

m = @model begin
      x ~ Normal()
end;

r = rand(m())
logdensityof(m(), r)
trace(logdensityof, m(), r)

Great! It's working for me too. Thanks again for this. Would you like to close this now, or wait until it's merged into master?

I will merge it later today and close the issue.

Thanks for providing such a great test case - I've fixed a major bug which would be much harder to detect otherwise!

Published in Umlaut 0.2.4.