cscherrer / Soss.jl

Probabilistic programming via source rewriting

Home Page: https://cscherrer.github.io/Soss.jl/stable/

Help fitting a simple t distribution

cwoode opened this issue

using MeasureTheory
using Soss
using SampleChainsDynamicHMC
using Random
import Distributions as Dist

model_test_t = @model T begin
    ν ~ Exponential(10.0)
    y ~ StudentT(ν) |> iid(T)
end;

y_test = rand(Dist.TDist(2.5),1000);

post_test = sample(model_test_t(T=length(y_test))|(y=y_test,), dynamichmc())

Gives the following:

4000-element MultiChain with 4 chains and schema (ν = Float64,)
(ν = 155.0±51.0,)

I can't seem to get this to work: the data are simulated with ν = 2.5, but the posterior for ν sits around 155.
Also, any hints on getting AdvancedHMC to work with the current release? DynamicHMC tends to abort on any sampling error.

Thanks for letting me know about this. The problem is that Soss currently uses logdensity, which for StudentT is defined as

function logdensity(d::StudentT{(:ν,)}, x)
    ν = d.ν
    # Unnormalized kernel: -(ν + 1)/2 * log(1 + x^2/ν)
    return xlog1py((ν + 1) / (-2), x^2 / ν)
end

function basemeasure(d::StudentT{(:ν,)})
    inbounds(x) = true
    constℓ = 0.0
    # ν-dependent normalizing constant, which logdensity above leaves out
    varℓ() = loggamma((d.ν + 1) / 2) - loggamma(d.ν / 2) - log(π * d.ν) / 2
    base = Lebesgue(ℝ)
    FactoredBase(inbounds, constℓ, varℓ, base)
end

If ν is constant, there's no need to compute the normalizing constant loggamma((d.ν+1)/2) - loggamma(d.ν/2) - log(π * d.ν) / 2. But here ν is the parameter being inferred, so that constant changes from one proposal to the next and has to be included, or the result is wrong.
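To make this concrete, here is a minimal Julia sketch that does not touch Soss or MeasureTheory at all. It rebuilds the Student-t log-density from the two pieces above (the kernel and the ν-dependent constant), checks the sum against Distributions.jl's `logpdf`, and shows that the "constant" really does move with ν. The helper names `kernel`, `lognorm`, and `full_logpdf` are hypothetical, and it assumes SpecialFunctions.jl and Distributions.jl are installed.

```julia
using SpecialFunctions: loggamma
using Distributions: TDist, logpdf

# Hypothetical helper: the unnormalized kernel, same expression as
# xlog1py((ν + 1) / (-2), x^2 / ν) in logdensity above.
kernel(ν, x) = -(ν + 1) / 2 * log1p(x^2 / ν)

# Hypothetical helper: the ν-dependent normalizing constant (the varℓ term).
lognorm(ν) = loggamma((ν + 1) / 2) - loggamma(ν / 2) - log(π * ν) / 2

# Full log-density = kernel + normalizing constant.
full_logpdf(ν, x) = kernel(ν, x) + lognorm(ν)

# Cross-check against Distributions.jl for a few (ν, x) pairs.
for ν in (0.5, 2.5, 10.0), x in (-3.0, 0.0, 1.7)
    @assert isapprox(full_logpdf(ν, x), logpdf(TDist(ν), x); atol=1e-8)
end

# The "constant" depends on ν, so dropping it reshapes the likelihood in ν.
println(lognorm(1.0))    # should equal -log(π) (the Cauchy case, ν = 1)
println(lognorm(10.0))
```

With 1000 observations the missing term is added 1000 times per evaluation, so even a small difference between lognorm at two values of ν becomes a large shift in the log-likelihood surface the sampler sees.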

I need to make Soss more intelligent about tracking this sort of thing, but for now "correct" is more important than "fast". So for a quick fix, I'll tag a new release that uses logpdf instead of logdensity.

Also, I just noticed that I'm using a different default parameterization than Distributions. I'll keep Exponential{(:λ,)} as the "rate" parameterization, but I need to add a "scale" parameterization too, and... maybe that should be the default? I mostly try to match Distributions.jl for defaults where it makes sense.
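The two conventions differ by a reciprocal. This small sketch uses only Distributions.jl (not MeasureTheory) to show that `Exponential(10.0)` there means scale θ = 10, i.e. rate λ = 0.1 and mean 10. That is presumably what the original `Exponential(10.0)` was meant to express, and why the explicit `Exponential(λ=0.1)` below matches it.

```julia
using Distributions: Exponential, mean, scale, rate, logpdf

d = Exponential(10.0)      # Distributions.jl: the positional argument is the scale θ

@assert scale(d) == 10.0
@assert rate(d) == 0.1     # rate λ = 1/θ
@assert mean(d) == 10.0    # mean = θ = 1/λ

# Rate form of the log-density: log p(x) = log(λ) - λx for x ≥ 0
λ = rate(d)
x = 2.0
@assert isapprox(logpdf(d, x), log(λ) - λ * x)
```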

For now, if I change your ν ~ Exponential(10.0) to ν ~ Exponential(λ=0.1) (making the rate parameter explicit) and make Soss use logpdf, I can do

julia> function f(ν, T=1000)
           # m is the model above, with ν ~ Exponential(λ=0.1)
           y_test = rand(Dist.TDist(ν), T)
           post_test = sample(m(T=T) | (y=y_test,), dynamichmc())
       end
f (generic function with 2 methods)

julia> using TupleVectors: summarize

julia> for ν ∈ (0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0)
           println(ν, " => ", summarize(f(ν)))
       end
0.1 => (ν = 0.05077±0.0024,)
0.2 => (ν = 0.1084±0.0054,)
0.5 => (ν = 0.2514±0.014,)
1.0 => (ν = 0.496±0.032,)
2.0 => (ν = 0.951±0.08,)
5.0 => (ν = 1.91±0.22,)
10.0 => (ν = 3.42±0.61,)

Those numbers look suspicious to me: each posterior mean is roughly half the true ν, or less for the larger values, so I'm not sure yet that the fix is complete. I'll need to do some more checking.

I'd love to get AdvancedHMC working, and it's one of those things that's "not hard in principle", and would probably just take a few days. It's just that there are so many of those :)

Thanks for the quick response. Yes, it doesn't look right. This is an incredible project btw.

OK, I think I've got it:

julia> for ν ∈ (0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0)
           println(ν, " => ", summarize(f(ν)))
       end
0.1 => (ν = 0.1015±0.0033,)
0.2 => (ν = 0.1965±0.0069,)
0.5 => (ν = 0.4939±0.019,)
1.0 => (ν = 1.008±0.047,)
2.0 => (ν = 1.99±0.13,)
5.0 => (ν = 5.06±0.6,)
10.0 => (ν = 11.9±2.8,)

and thanks!
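Dividing each posterior mean by the true ν makes the before/after comparison explicit. This sketch uses only the posterior means printed in the two runs above and ignores the ± spreads; the variable names are just for illustration.

```julia
truth  = [0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]
before = [0.05077, 0.1084, 0.2514, 0.496, 0.951, 1.91, 3.42]   # first run
after  = [0.1015, 0.1965, 0.4939, 1.008, 1.99, 5.06, 11.9]     # run after the fix

ratio_before = before ./ truth   # about 0.5 for small ν, lower for larger ν
ratio_after  = after ./ truth    # close to 1 throughout

for (ν, rb, ra) in zip(truth, ratio_before, ratio_after)
    println(rpad(string(ν), 5), "  before: ", round(rb; digits=2),
            "  after: ", round(ra; digits=2))
end
```

The largest remaining deviation is at ν = 10 (ratio about 1.19), where the reported posterior spread (±2.8) is also the widest, which is expected since heavy-tailed data carry less information about ν as ν grows.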

This is working now. Thanks!

4000-element MultiChain with 4 chains and schema (ν = Float64,)
(ν = 2.183±0.15,)