`logpdf_grad` errors for `HomogeneousMixture`
fzaiser opened this issue · comments
The following example crashes:
using Gen
@gen function test()
mix = HomogeneousMixture(broadcasted_normal, [1, 0])
means = hcat([0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0])
@trace(mix([0.25, 0.25, 0.25, 0.25], means, [0.1, 0.1, 0.1, 0.1]), :x)
end
trace = Gen.simulate(test, ())
result = Gen.hmc(trace, selectall())
It throws the following error:
ERROR: LoadError: DimensionMismatch("new dimensions (1, 2) must be consistent with array size 4")
[1] (::Base.var"#throw_dmrsa#196")(::Tuple{Int64,Int64}, ::Int64) at ./reshapedarray.jl:41
[2] reshape at ./reshapedarray.jl:45 [inlined]
[3] reshape(::Array{Float64,1}, ::Int64, ::Int64) at ./reshapedarray.jl:116
[4] logpdf_grad(::HomogeneousMixture{Array{Float64,N} where N}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,2}, ::Array{Float64,1}) at [...]/packages/Gen/[...]/src/modeling_library/mixture.jl:115
...
I believe the reason is that in src/modeling_library/mixture.jl, line 117 (at commit fa759d3), length(dist.dims) should be replaced by K. This removes the exception, but I don't understand the code well enough to be sure that this is the correct fix, or whether other parts of the code have to be fixed too.

Haven't looked in depth, but I suspect this is indeed due to an assumption somewhere that args to distributions will be flat, i.e. cannot be array-valued. The use of length instead of size/axes looks suspect to me.
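To make the failure mode concrete, here is a hypothetical reduction of the failing reshape (not the actual mixture.jl code): with K = 4 components and dist.dims = [1, 0], using length(dist.dims) as a reshape dimension asks for the wrong element count, while using K is consistent. The variable names here are illustrative only.

```julia
# Hypothetical reduction of the failing call; not the actual mixture.jl code.
K = 4                                 # number of mixture components
per_component_grads = [0.1, 0.2, 0.3, 0.4]  # one gradient entry per component
dims = [1, 0]                         # stand-in for dist.dims

# Reshaping with length(dims) requests a 1x2 array from 4 elements,
# reproducing "DimensionMismatch: new dimensions (1, 2) must be
# consistent with array size 4":
failed = try
    reshape(per_component_grads, 1, length(dims))
    false
catch e
    e isa DimensionMismatch
end
@assert failed

# Reshaping with K requests a 1x4 array, which matches the 4 elements:
reshaped = reshape(per_component_grads, 1, K)
@assert size(reshaped) == (1, 4)
```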
In general, the args to a distribution could be arrays of different shapes. I'm not aware of us having general machinery for flattening and unflattening arrays in the gradient operations (nor am I sure that flattening and unflattening is the right thing to do, necessarily).
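For what it's worth, one simple way to flatten and unflatten differently-shaped array arguments is to record each argument's size, concatenate the vec'ed arrays, and reshape slices back. This is an illustrative sketch only, not machinery that exists in Gen:

```julia
# Illustrative sketch: flatten arrays of different shapes into one vector,
# and unflatten using the recorded sizes.
function flatten_args(args::Vector{<:AbstractArray})
    sizes = map(size, args)                 # remember each argument's shape
    flat = reduce(vcat, map(vec, args))     # concatenate column-major flattenings
    return flat, sizes
end

function unflatten_args(flat::AbstractVector, sizes)
    args = Any[]
    offset = 0
    for sz in sizes
        n = prod(sz)                        # element count of this argument
        push!(args, reshape(flat[offset+1:offset+n], sz))
        offset += n
    end
    return args
end

mu = [0.0, 1.0]               # a vector-shaped argument
cov = [1.0 0.0; 0.0 1.0]      # a matrix-shaped argument
flat, sizes = flatten_args([mu, cov])
args = unflatten_args(flat, sizes)
@assert args[1] == mu && args[2] == cov
```

Whether flattening is the right interface for gradient operations is a separate design question, as noted above.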
(Oops, misread a doc. Deleted comment.)
@bzinberg Thanks for the quick reply! In the documentation for HomogeneousMixture, there is an example with a multivariate normal distribution, which takes a mean vector and a covariance matrix (i.e. arguments of different shapes). Therefore, I thought it was supported. Do you think this functionality would be difficult to implement?
Hi @fzaiser! I think a lot of us were on winter break when you posted this and it fell through the cracks -- sorry about that!
I think you're right that the length(dist.dims) on that line should be replaced by K, the number of components. Thanks for tracking this down and filing the bug!
(As an aside, HMC will struggle to explore multiple modes in this target — but I think that may be the point of the experiment :).)
As Ben mentioned, there are parts of Gen (including the @dist DSL) that make certain restrictive assumptions about data shapes, but I don't think you should run into them in this example.
Hi @alex-lew, no problem and thanks for the fix! I hope to have some time to experiment with it soon. Indeed, I'm aware of HMC struggling with such a multi-modal distribution. :) I was just playing around with gradient-based inference methods when I hit the bug and HMC was the simplest way to reproduce it.