tpapp / DynamicHMC.jl

Implementation of robust dynamic Hamiltonian Monte Carlo methods (NUTS) in Julia.

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

"Estimating parameters of a mixture"

jariji opened this issue · comments

Hello @tpapp,

I'm trying to write the "Estimating parameters of a mixture" model from the Stan docs.

I get an error about a bad initial value. I'm assuming I constructed the model or transformation incorrectly, but I don't have enough experience to narrow it down very well. The error message is included at the bottom. I wonder if you might be able to straighten me out. Cheers.

using Random, StatsFuns, TransformVariables, LogDensityProblems, DynamicHMC
using Distributions: Categorical, loglikelihood, logpdf
import Distributions as Ds


struct LatentDiscreteProblem
    y::Vector{Float64}
    K::Int
end

logprior(params) = let
    K = length(params.μ)
    ℓ = 0.
    θ = params.θ
    for i in 1:K
        μ, σ = params.μ[i], params.σ[i]
        ℓ += logpdf(Ds.Normal(0., 10.), μ)
        ℓ += logpdf(Ds.LogNormal(0., 2.), σ)
        ℓ += sum(logpdf(Ds.Dirichlet(ones(K)), θ))
    endend

loglik(params, problem) = let
    (;y) = problem
    (;μ,σ,θ) = params
    K = length(μ)
    N = length(θ)
    ℓ = 0.
    log_theta = log.(θ)
    for n in 1:N
        lps = copy(log_theta)
        for k in 1:K
            lps[k] += logpdf(Ds.Normal(μ[k], σ[k]), y[n])
        end+= logsumexp(lps)
    endend

(problem::LatentDiscreteProblem)(θ) = logprior(θ) + loglik(θ, problem)


problem_transformation(p::LatentDiscreteProblem) = as((;μ=as(Vector, asℝ, p.K), σ=as(Vector, asℝ₊, p.K), θ=as(Vector, asℝ₊, p.K)))

let N = 13
    K = 3
    μ=rand(K) # TODO Ensure the values are in increasing order for identification.
    σ=rand(K)
    θ=rand(Ds.Dirichlet(ones(K)))
    y = rand(N)
    params = (;μ,σ,θ)
    p = LatentDiscreteProblem(y, K)
    p(params)

    P = TransformedLogDensity(problem_transformation(p), p)
    ∇P = ADgradient(:ForwardDiff, P)

    results = mcmc_with_warmup(Random.GLOBAL_RNG, ∇P, 1000; reporter = NoProgressReport())
    # DomainError with [-0.750034, -0.901617, -0.250397, 0.681036, 0.600515, 2.69998, -0.971549, 1.65164, -1.93758]:
    # Starting point has non-finite density.

end

DynamicHMC v3.1.1