dfdx / Umlaut.jl

The Code Tracer

Using Umlaut with Tilde.jl

cscherrer opened this issue

Hi, I'm wondering if Umlaut could be a nice foundation for some work I'm doing in Tilde.jl. Say we have a model for a random walk:

m = @model x begin
    σ ~ Exponential()
    x[1] ~ Normal() 
    for j in 2:length(x)
        x[j] ~ Normal(μ = x[j-1], σ = σ)
    end
    return x
end;

A for loop will need to be converted into a canonical form, something like this:

m = @model x begin
    σ ~ Exponential()
    x[1] ~ Normal()
    iter = 2:length(x)

    temp = iterate(iter)
    if temp === nothing
        @goto done
    end

    @label loop
    j, state = temp
    x[j] ~ Normal(μ = x[j-1], σ = σ)
    temp = iterate(iter, state)
    if temp !== nothing
        @goto loop
    end

    @label done
    return x
end;
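To check the desugaring in isolation, here is a sketch in plain Julia (no Tilde.jl, no Distributions.jl) that computes the random walk's log-density, up to an additive constant and with σ held fixed, once with a `for` loop and once in the iterate/`@goto` form above. The names `walk_logdens_for` and `walk_logdens_goto` are hypothetical, and the hand-written density only stands in for what `~` would compute. Note that `@goto`/`@label` have to live inside a function, since they cannot jump between top-level statements.

```julia
# Random-walk log-density up to an additive constant, with σ fixed:
#   x[1] ~ Normal()  and  x[j] ~ Normal(μ = x[j-1], σ = σ) for j ≥ 2.

# Version 1: ordinary for loop.
function walk_logdens_for(x::AbstractVector{<:Real}, σ::Real)
    ℓ = -x[1]^2 / 2
    for j in 2:length(x)
        ℓ += -((x[j] - x[j-1]) / σ)^2 / 2 - log(σ)
    end
    return ℓ
end

# Version 2: the same loop hand-desugared into iterate + @goto/@label.
function walk_logdens_goto(x::AbstractVector{<:Real}, σ::Real)
    ℓ = -x[1]^2 / 2
    iter = 2:length(x)

    temp = iterate(iter)
    if temp === nothing          # empty range: skip the body entirely
        @goto done
    end

    @label loop
    j, state = temp
    ℓ += -((x[j] - x[j-1]) / σ)^2 / 2 - log(σ)
    temp = iterate(iter, state)
    if temp !== nothing
        @goto loop
    end

    @label done
    return ℓ
end
```

Both versions perform the same floating-point operations in the same order, so they should agree exactly, including the length-1 case where `2:1` is empty.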

Depending on the model structure, there can be lots of ways to go about inference. One popular approach is Metropolis-Hastings, where we have a function that starts with some existing trace (a choice of sampled values) and proposes a new one, which may or may not be accepted. I'm sweeping details under the rug here, because they're a little distracting and you're likely familiar with this anyway.

Naively, each proposal could "run the model" from beginning to end. But this leads to lots of redundant computation, so it's much better if we have the option to jump to a given ~ statement and run from there. This might run to the end, or stop early.

For example, say we want to start at x[3]. Starting there and going until the next sample would look like this:

x[3] ~ Normal(μ = x[2], σ = σ)
temp = iterate(iter, state)
j, state = temp
x[4] ~ Normal(μ = x[3], σ = σ)

From this point, we might decide to keep the previous value for x[4], or draw a new sample and keep going.
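For this particular model the payoff of jumping in is easy to see: only the statements for x[k] and x[k+1] read x[k], so a proposal that changes x[k] can start at statement k and stop after statement k+1. A plain-Julia sketch, again with a hand-written density up to an additive constant and σ held fixed; `term`, `full_logdens` and `delta_logdens!` are hypothetical helpers, not Tilde.jl API.

```julia
# Log-density contribution of one ~ statement, up to an additive constant:
#   j == 1:  x[1] ~ Normal()
#   j >= 2:  x[j] ~ Normal(μ = x[j-1], σ = σ)
function term(x, σ, j)
    j == 1 && return -x[1]^2 / 2
    return -((x[j] - x[j-1]) / σ)^2 / 2 - log(σ)
end

# Naive approach: run every ~ statement from beginning to end.
full_logdens(x, σ) = sum(term(x, σ, j) for j in eachindex(x))

# Change in log-density if x[k] were replaced by `new`. Only statements
# k and k+1 read x[k], so start at statement k and stop early after k+1.
function delta_logdens!(x, σ, k, new)
    stop = min(k + 1, length(x))
    before = sum(term(x, σ, j) for j in k:stop)
    old = x[k]
    x[k] = new
    after = sum(term(x, σ, j) for j in k:stop)
    x[k] = old   # restore; accepting or rejecting is left to the caller
    return after - before
end
```

A Metropolis-Hastings step would compare `delta_logdens!(x, σ, k, proposal)` (plus any proposal-density correction) against a log-uniform draw, touching two statements instead of all of them.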

Because we're jumping into the middle of the loop, we'll need a view of the whole loop at once, rather than having to iterate through it from the start. Your Loop construct, and the fact that Umlaut is still in development, made me wonder whether it could be adapted to the kinds of problems that come up in probabilistic programming languages (PPLs).

My first thought here is to associate your Loop with a TupleVector (https://github.com/cscherrer/TupleVectors.jl). Then the jump would just be accessing a given element of the TupleVector.
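One way to picture "a view of the whole loop at once": record one entry per iteration, including the iterator state, so that resuming at iteration i is an index lookup followed by ordinary `iterate` calls. This sketch uses a plain `Vector` of `NamedTuple`s as a stand-in for a TupleVector; it does not use TupleVectors.jl's API, and `record_loop` and `resume_after` are hypothetical names.

```julia
# Run the loop once, recording per iteration: the loop variable j,
# the iterator state needed to resume, and the sampled value x[j].
function record_loop(x)
    iter = 2:length(x)
    trace = NamedTuple[]
    temp = iterate(iter)
    while temp !== nothing
        j, state = temp
        push!(trace, (j = j, state = state, value = x[j]))
        temp = iterate(iter, state)
    end
    return iter, trace
end

# Jump to iteration i via its stored state and continue from there,
# without replaying iterations 1..i-1. Returns the loop indices visited.
function resume_after(iter, trace, i)
    visited = Int[]
    temp = iterate(iter, trace[i].state)
    while temp !== nothing
        j, state = temp
        push!(visited, j)
        temp = iterate(iter, state)
    end
    return visited
end
```

For example, with five observations, `resume_after(iter, trace, 2)` picks up after the iteration for x[3] and visits only indices 4 and 5.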

Tilde is still at the prototype stage, so I could say more, but it quickly gets pretty fuzzy. Does what I'm saying make sense? Any thoughts on whether this could work?

Maybe! But there are a few details to keep in mind.

First and most important: Umlaut.trace() cannot create Loops at the moment. Umlaut reuses the implementation of the Tape and all its operations from Ghost.jl, and aims to make tracing, especially loop tracing, more transparent and robust. So far, though, I only have an approximate plan and no concrete code for that. If you trace a loop, Umlaut will just unroll all its iterations into a flat tape.
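To make "unroll all its iterations into a flat tape" concrete, here is a toy operator-overloading tracer in plain Julia. It is hypothetical and is neither Umlaut's API nor its implementation; the only point is that the recorded tape gets one entry per iteration and nothing that represents the loop itself.

```julia
# Toy tracer (hypothetical). One recorded operation: name, inputs, output.
struct Op
    f::Symbol
    args::Tuple{Float64, Float64}
    val::Float64
end

# A value that appends every operation applied to it onto a shared tape.
struct Traced
    val::Float64
    tape::Vector{Op}
end

function Base.:+(a::Traced, b::Real)
    v = a.val + b
    push!(a.tape, Op(:+, (a.val, Float64(b)), v))
    return Traced(v, a.tape)
end

# Trace a loop: the tape ends up with one :+ entry per iteration and no
# record that these entries came from a loop.
function trace_loop(x0::Float64, n::Int)
    tape = Op[]
    x = Traced(x0, tape)
    for j in 1:n
        x = x + j
    end
    return x.val, tape
end
```

Tracing with `n = 3` and with `n = 4` gives tapes of different lengths, which is why an unrolled tape has to be re-traced whenever the number of iterations changes.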

If this is not a blocker (e.g. you want to construct a tape manually, or you're ready to wait a couple of months until I get to it), then the next thing to check is performance. What you described sounds like executing operations on the tape during the simulation, and, honestly, I've never tested its speed. Typically, I use Tape to transform code and compile it back to a normal Julia function with all the optimizations. Not that I expect tape execution to be terribly slow, but, you know, it's better to check such things from the very beginning.
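The performance question is essentially interpretation versus compilation. In this hypothetical sketch (not Umlaut's Tape type), a tape is a vector of (function, constant) pairs with an abstract element type, so `play` has to dispatch dynamically on every entry, while `compiled` is the same computation written as an ordinary function that Julia can specialize. No timings are claimed here; comparing the two with BenchmarkTools.jl is the kind of early check suggested above.

```julia
# Hypothetical minimal tape: each entry applies x = f(x, c).
# The element type is abstract (Function), so calls are dynamically dispatched.
const TAPE = Tuple{Function, Float64}[(+, 1.0), (*, 2.0), (-, 0.5)]

# Interpret the tape entry by entry.
function play(tape, x)
    for (f, c) in tape
        x = f(x, c)
    end
    return x
end

# The same computation as a normal Julia function, i.e. what compiling
# the tape back into code would produce.
compiled(x) = (x + 1.0) * 2.0 - 0.5
```

Both return 7.5 for an input of 3.0; the interesting number is how much slower `play` is per call on a realistically sized tape.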

The rest of the plan sounds reasonable. You might be interested to learn about tape context, which sounds like the easiest way to keep the MCMC trace together with the tape. The rest depends on your vision; I'll be happy to answer any questions.

Sounds good. I'm still figuring out what kind of structures I need in place to cover inference. I'm not doing anything with stochastic control flow yet, but I want to get to it. I'm not in any hurry, though; mostly I want to get it right :)

Once there's a Loop in Umlaut, I should be able to do this and avoid recompiling the tape each time the control flow changes. At the very least, being able to transform code to a canonical form that's closer to lowered code is already useful.
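For reference, Julia's own lowering already turns a `for` loop into `iterate` calls plus conditional jumps, which is essentially the canonical form sketched at the top of the issue. `Meta.@lower` shows this; the exact printed output varies between Julia versions.

```julia
# Print the lowered form of a simple for loop. Look for calls to
# Base.iterate and for goto / conditional-goto statements in the output.
lowered = Meta.@lower for j in 2:4
    println(j)
end
println(lowered)
```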

Thanks for the pointer, I'll read more about tape context.