probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference

Home Page: https://gen.dev

Automated construction of diffeomorphism/bijector trace translators

bgroenks96 opened this issue

One feature of Turing (and similar PPLs) that I am sorely missing in Gen is the ability to automatically reparameterize a probabilistic program so that variables (or, in the language of Gen, "choices") with bounded support are mapped into a transformed space with unbounded support. This means automatically applying transforms such as log and logit for singly and doubly bounded support, respectively. As far as I can tell, it's possible to do this manually for any particular program with @transform, but there does not appear to be any kind of automatic transformation based on the support of the prior distributions.
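For concreteness, here is a minimal sketch in plain Julia of the two transforms mentioned above. It does not use Gen or Turing; the function names (`to_unconstrained`, `from_unconstrained`, `log_jacobian`) are made up for illustration. The density in the unconstrained space picks up a Jacobian term: log p_Y(y) = log p_X(x(y)) + log|dx/dy|.

```julia
# Logistic function and its inverse (logit).
logistic(y) = 1 / (1 + exp(-y))
logit(u) = log(u / (1 - u))

# Singly bounded support (lower, Inf): use a shifted log transform.
to_unconstrained(x, lower) = log(x - lower)
from_unconstrained(y, lower) = lower + exp(y)
# x = lower + exp(y), so dx/dy = exp(y) and log|dx/dy| = y.
log_jacobian(y, lower) = y

# Doubly bounded support (lower, upper): rescale to (0, 1), then apply logit.
to_unconstrained(x, lower, upper) = logit((x - lower) / (upper - lower))
from_unconstrained(y, lower, upper) = lower + (upper - lower) * logistic(y)
# dx/dy = (upper - lower) * logistic(y) * (1 - logistic(y)).
log_jacobian(y, lower, upper) =
    log(upper - lower) + log(logistic(y)) + log(1 - logistic(y))

# Round trip for a positive variable and for a variable on (0, 1).
y_pos = to_unconstrained(2.5, 0.0)
x_pos = from_unconstrained(y_pos, 0.0)
y_unit = to_unconstrained(0.25, 0.0, 1.0)
x_unit = from_unconstrained(y_unit, 0.0, 1.0)
```

An automatic scheme would have to pick one of these transforms per choice, based on the support of the distribution at that address, and add the corresponding `log_jacobian` term to the log density.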

My question is, therefore, how hard would it be to implement a function (or macro?) that would automatically generate a trace translator alongside the generative function which applies such a reparameterization? Being somewhat new to Gen and its internals, it's a bit hard for me to judge this. I can see where it might be tricky, especially for dynamic generative functions.

I would be curious to hear the opinions of the primary maintainers on what it would take to implement this. I would imagine that this has been discussed before.

Hi @bgroenks96, I discussed this briefly with @alex-lew, and he suggested the following alternative solution that makes use of the @dist DSL to define custom distributions:

  1. Defining a @dist on the transformed space, e.g. @dist loggamma(…) = log(gamma(…))
  2. Defining a generative function @gen gamma(…) = exp({:loggamma} ~ loggamma(…))

Then use the generative function to define your model. The trace’s variables are in an unconstrained space, but your program looks to you like you are working in the constrained space, and given a trace, you can still extract the gamma value as tr[addr] (it’s the return value of the gamma call).
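The two steps above can be sketched as follows. This is an illustration under assumptions, not tested against a specific Gen version: the names `positive_gamma` and `model`, the addresses, and the parameter values are hypothetical, and the wrapper is deliberately not called `gamma` so that it does not shadow Gen's built-in `gamma` distribution. It also assumes the @dist DSL accepts `log` applied to a continuous distribution.

```julia
using Gen

# Step 1: a distribution over log(x), where x ~ gamma(shape, scale).
@dist loggamma(shape, scale) = log(gamma(shape, scale))

# Step 2: a wrapper that samples in log space and returns the positive value.
@gen function positive_gamma(shape, scale)
    z = {:loggamma} ~ loggamma(shape, scale)
    return exp(z)
end

# The model reads as if it were written in the constrained space.
@gen function model()
    rate = {:rate} ~ positive_gamma(2.0, 1.0)
    {:y} ~ normal(0.0, 1.0 / sqrt(rate))
    return rate
end

trace = simulate(model, ())
z = trace[:rate => :loggamma]  # the traced choice, unconstrained
x = trace[:rate]               # return value of the wrapper call, positive
```

An inference move would then target the unconstrained address, for example with a selection such as `select(:rate => :loggamma)`.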

This way, it should be possible to use unconstrained proposal distributions that target the value of the unconstrained random variable, without having to use trace translators as well. I believe gradient-based moves like MALA and HMC should also work (though it's possible that gradient support in the @dist DSL isn't fully debugged).

Of course, the use of a wrapper @dist and @gen function adds a bit of overhead, but it should also be possible to write a custom generative function that does the same thing with less overhead.

Hope this helps!

Ok, thanks. That sounds like a workable solution for many cases.

However, I think it is a bit unfortunate to place the burden on the user of writing their generative function already reparameterized in this way. In my opinion, a (really good) PPL should let you define the model as close as possible to its formal specification and then automatically apply such reparameterizations, which are often only really necessary for algorithmic purposes, with minimal effort from the user.

Do you think it would be at all possible to design such a scheme? Or is the necessary information lost when the generative function is constructed?

I think this would be possible if Gen.jl had stronger support for trace typing, allowing us to determine the distribution associated with a particular random variable at compile time rather than run time. It should then be possible to write, e.g., an HMC kernel that looks at the type of the input trace, figures out which random variables need to be transformed based on the support of their associated distributions, and performs those transformations before applying gradient-based updates.
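To make the idea concrete, here is a rough sketch in plain Julia of choosing a transform from the support of a distribution via multiple dispatch. Nothing here is Gen's API: the `Support` types, the distribution tags (`NormalDist`, `GammaDist`, `UniformDist`), and the `Dict`-based stand-in for a trace are all hypothetical. In a real kernel, this information would have to come from the trace type.

```julia
# Hypothetical support tags.
abstract type Support end
struct RealLine <: Support end
struct Positive <: Support end
struct Interval <: Support
    lower::Float64
    upper::Float64
end

# Hypothetical distribution tags standing in for what a typed trace would expose.
struct NormalDist end
struct GammaDist end
struct UniformDist
    lower::Float64
    upper::Float64
end

support(::NormalDist) = RealLine()
support(::GammaDist) = Positive()
support(d::UniformDist) = Interval(d.lower, d.upper)

# The transform to the unconstrained space is selected by the support type.
unconstrain(::RealLine, x) = x
unconstrain(::Positive, x) = log(x)
unconstrain(s::Interval, x) = log((x - s.lower) / (s.upper - x))  # scaled logit

# Inverse transforms, applied after the gradient-based update.
constrain(::RealLine, y) = y
constrain(::Positive, y) = exp(y)
constrain(s::Interval, y) = s.lower + (s.upper - s.lower) / (1 + exp(-y))

# `choices` maps an address to a (distribution, value) pair.
function unconstrain_choices(choices)
    out = Dict{Symbol,Float64}()
    for (addr, (dist, value)) in choices
        out[addr] = unconstrain(support(dist), value)
    end
    return out
end

choices = Dict(
    :mu => (NormalDist(), 0.3),
    :rate => (GammaDist(), 2.0),
    :p => (UniformDist(0.0, 1.0), 0.5),
)
unconstrained = unconstrain_choices(choices)
```

Because the transform is picked by the type of the distribution, the dispatch can be resolved at compile time whenever those types are known statically. That is the property the trace-typing argument above relies on.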

Note that you can already sort of do this with Gen's static modeling language (since each trace of a static generative function does store information about the trace's structure), but it requires relying on implementation-specific features of that modeling language which are not part of the generative function interface. Pyro uses a similar strategy, but I think they're able to do that because they don't aim to implement inference algorithms like HMC using a standard interface like GFI.

That makes sense. I guess Turing can do this as well because the type information of the distributions is known at compile time rather than run time. This of course comes at the cost of less flexibility and a lot of compile-time overhead.