dfdx / Ghost.jl

The Code Tracer

Correctness Issue with Loop Tracing

mharradon opened this issue

Hello,

Nice project! I'm currently exploring building off of it. I wanted to raise the following issue with using should_trace_loops!(true) with non-fixed inner traces:

julia> should_trace_loops!(true)
true
julia> function looper(x)
         for i in 1:4
           if i%2 ==0
             x += 1
           else
             x -= 1
           end
         end
         return x
       end
looper (generic function with 1 method)
julia> looper(3)
3
julia> play!(trace(looper, 3)[2], 3)
-1

This may be a documentation issue - or perhaps an exception could be thrown if this can be detected?

I think the issue is not with the loop but rather with the if, which is indeed not (dynamically) supported. It looks like x -= 1 is recorded during the first loop iteration, so the generated code looks more like this:

for i=1:4
   x = x - 1
end

which explains the result. Sorry if the documentation isn't clear about this detail of the tracing, I'll think how to make it more explicit.
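To make the mismatch concrete, here is a minimal sketch in plain Julia that does not use Ghost at all. `looper_as_traced` is a hypothetical hand-written stand-in for what the traced loop effectively computes, assuming only the first iteration's branch is kept as the loop body:

```julia
# Original function: the branch taken depends on the loop counter.
function looper(x)
    for i in 1:4
        if i % 2 == 0
            x += 1
        else
            x -= 1
        end
    end
    return x
end

# Hypothetical equivalent of the traced loop: only the operation seen in the
# first iteration (i = 1, so `x -= 1`) survives as the loop body.
function looper_as_traced(x)
    for i in 1:4
        x = x - 1
    end
    return x
end

looper(3)            # 3
looper_as_traced(3)  # -1
```

The two results match the `looper(3)` and `play!` outputs in the report above.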

Yes, that was what I figured. I like the capability - in practice it can reduce trace size dramatically. It just seems like a bit of a footgun. I haven't reviewed the implementation in detail yet, but I have three options in mind:

  1. Documentation
  2. Detection and exception/warning on control flow detection
  3. More general control flow support (e.g. if you had a similar representation for if, then this would work). I know that gets very hard, as discussed in the other issues. This is also probably related to plans to expose various IR processing from the compiler team - if you could get SSA IR with Phi nodes, that might be preferable.

If there's something that I could use I will certainly contribute back. It would be nice if there's some clever 4th option with intermediate properties (e.g. I was musing if it might be useful to keep the full inner block around inside that while loop and leave control flow / tracing to the user).

Detection of control flow can be pretty tricky because the compiler and intermediate libraries (e.g. IRTools) are free to rearrange code blocks in any way as long as the flow is preserved. Loop detection by itself is essentially a hack based on the assumption that a loop ends at a backward goto, but even this simple logic has its flaws (e.g. last loop instruction vs. loop exit point). As for detecting ifs from forward gotos, I'm not sure it's even possible without explicit Phi nodes or something similar. So option (2) sounds almost as hard as (3), but the benefit is much smaller.

Fully supporting loops and conditions is a great end goal. Unfortunately, the current tracing-based approach can't do it by definition, e.g. in:

function foo(x)
    if x > 0
        return f(x)
    else
        return g(x)
    end
end

If execution takes the 1st path during tracing, we will never hit the 2nd path - neither its original code nor the code added by the IR transformation.

Instead we need a way to do everything fully statically, without even calling the function. In the example above, it means that we need to recurse into both f() and g(). As far as I know, there are currently no convenient tools to do it, not to mention that the compiler tooling landscape is very unstable at the moment. Hopefully, late Julia 1.7 or 1.8 will bring enough tooling to deal with these issues.
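The one-path limitation described above can be illustrated with a toy recorder. This is only a sketch: `TAPE` is a plain vector of symbols, not Ghost's `Tape` type, and the bodies of `f` and `g` are made up for the example.

```julia
# Toy "tape": the names of the functions that actually ran.
const TAPE = Symbol[]

# Hypothetical stand-ins for f and g that log themselves when called.
f(x) = (push!(TAPE, :f); 2x)
g(x) = (push!(TAPE, :g); -x)

function foo(x)
    if x > 0
        return f(x)
    else
        return g(x)
    end
end

empty!(TAPE)
foo(3)                  # execution takes the first path
recorded = copy(TAPE)   # [:f], so g never reaches the tape
```

Anything that observes only the execution, as this recorder does, has no way to learn what `g` would have done.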

Looks like at the moment (1) is the only real option.

I think (2) might be possible in a robust way. On every jump statement you can check the condition and record the destination block to the tape. Then when entering a block you can assert that the last destination block (written statically to the tape) matches the destination block calculated on the tape. I'm thinking first for branching, but I think the same thing could be done for loops.
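A rough sketch of that idea in plain Julia, under my own assumptions: the names `State`, `check_block!`, `set_branch!` and `replay`, as well as the block numbering, are hypothetical and are not the prototype's actual functions.

```julia
# Toy model: a tape is a list of steps run against a mutable state that
# carries the jump target computed during replay.
mutable struct State
    x::Float64
    next_block::Union{Nothing,Int}
end

# On entering a block: the target computed during replay must match the
# block that was recorded on the tape. The stored target is then cleared.
function check_block!(s::State, block_id::Int)
    if s.next_block !== nothing && s.next_block != block_id
        error("branch mismatch: replay wants block $(s.next_block), tape recorded block $block_id")
    end
    s.next_block = nothing
end

# At a jump: recompute the destination from the replayed condition.
function set_branch!(s::State, cond::Bool, if_true::Int, if_false::Int)
    s.next_block = cond ? if_true : if_false
end

# Tape as it might be recorded for an abs-like function on the x >= 0 path:
# block 1 holds the branch, block 2 is the x >= 0 body, block 3 was never seen.
tape = [
    s -> check_block!(s, 1),
    s -> set_branch!(s, s.x >= 0, 2, 3),
    s -> check_block!(s, 2),
    s -> (s.x = 1.0 * s.x),
]

function replay(tape, x)
    s = State(x, nothing)
    foreach(step -> step(s), tape)
    return s.x
end

replay(tape, 2.0)     # same path as recorded, returns 2.0
# replay(tape, -2.0)  # throws: the replayed condition selects block 3
```

The jump target is data on the tape, so the check costs one comparison per block entry, which is presumably where the overhead of such a scheme would come from.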

I created a quick prototype, see https://github.com/mharradon/Ghost.jl/blob/mharradon/BranchesAndLoops/test/test_trace.jl#L235

There's a fair bit of overhead, but it could potentially be useful to detect incompatible branch structure and recompile a trace for a different structure.

julia> should_assert_branches!(true)
true

julia> v, tape = trace(myabs, 3)
(3.0, Tape{Dict{Any, Any}}
  inp %1::typeof(myabs)
  inp %2::Int64
  %3 = _check_block(1, nothing)::Nothing
  %4 = >=(%2, 0.0)::Bool
  %5 = _check_and_set_branch!(%4, 3, %3)::Nothing
  %6 = _check_and_set_branch!(nothing, 2, %5)::Int64
  %7 = _check_block(2, %6)::Nothing
  %8 = *(1.0, %2)::Float64
  %9 = _check_and_set_branch!(nothing, 0, %7)::Int64
)

julia> play!(tape, myabs, 2)
2.0

julia> play!(tape, myabs, -2)
ERROR: AssertionError: if next_block !== nothing
    next_block == block_id
else
    true
end
Stacktrace:
 [1] _check_block(block_id::Int64, next_block::Int64)
   @ Ghost ~/Ghost.jl/src/trace.jl:222
 [2] exec!(tape::Ghost.Tape{Dict{Any, Any}}, op::Ghost.Call{typeof(Ghost._check_block)})
   @ Ghost ~/Ghost.jl/src/tape.jl:476
 [3] play!(::Ghost.Tape{Dict{Any, Any}}, ::Function, ::Vararg{Any, N} where N; debug::Bool)
   @ Ghost ~/Ghost.jl/src/tape.jl:553
 [4] play!(::Ghost.Tape{Dict{Any, Any}}, ::Function, ::Vararg{Any, N} where N)
   @ Ghost ~/Ghost.jl/src/tape.jl:545
 [5] top-level scope
   @ REPL[7]:1

Looks interesting! Would you mind converting it to a PR? I think I also need some explanation of last_block semantics.

Happy to PR - see #24.

I renamed last_block to last_block_jmp_target. Essentially this replicates the branch logic at the end of each block to determine the destination block based on the played-forward values on the tape. The resulting value is what's written to the tape at last_block_jmp_target. Then, on entering any block, an assertion is written to the tape that checks that the tape value for last_block_jmp_target matches the entered block id (and then wipes the value).

A few issues with the current implementation:

  1. Config is checked at IR mod time rather than IR execution time (as trace_loops does now). Maybe those should be consistent? If so, I think it would be worth confirming that these calls can be optimized away.
  2. Perhaps a value should be returned / written to the Tape rather than an Exception thrown on bad branches? That would improve the performance of adaptive tracing systems using this.
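For the second point, here is a sketch of what a non-throwing interface could look like for an adaptive caller. Everything in it is hypothetical (`ToyTape`, `toy_trace`, `try_play`, `adaptive_abs!`); it models a tape as a closure specialised to one branch path and is not Ghost's API.

```julia
# Toy tape: straight-line code for one branch, plus the path it was recorded on.
struct ToyTape
    path::Symbol      # which branch was taken while recording
    body::Function    # straight-line code for that branch
end

branch_of(x) = x >= 0 ? :nonneg : :neg

# Stand-in for tracing an abs-like function on the path selected by x.
function toy_trace(x)
    if branch_of(x) == :nonneg
        return ToyTape(:nonneg, y -> 1.0 * y)
    else
        return ToyTape(:neg, y -> -1.0 * y)
    end
end

# Return (ok, value) instead of throwing when the branch does not match.
function try_play(t::ToyTape, x)
    branch_of(x) == t.path || return (false, nothing)
    return (true, t.body(x))
end

# Adaptive caller: keep one tape per observed path and re-trace on a miss.
function adaptive_abs!(cache::Dict{Symbol,ToyTape}, x)
    for t in values(cache)
        ok, v = try_play(t, x)
        ok && return v
    end
    t = toy_trace(x)
    cache[t.path] = t
    return t.body(x)
end

cache = Dict{Symbol,ToyTape}()
adaptive_abs!(cache, 3)    # 3.0, records the :nonneg tape
adaptive_abs!(cache, -2)   # 2.0, misses, then records the :neg tape
```

Returning a flag keeps the miss on the normal control path, so the caller can fall back to re-tracing without paying for exception handling on every mismatch.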