Bug in `logpdf_grad` implementation of `@dist` DSL distributions.
ztangent opened this issue.
This happens with both `WithLabelArg` distributions (labels passed as an argument) and `RelabeledDistribution`s (fixed labels).
Minimal example:
```julia
using Gen

@dist labeled_uniform(labels) = labels[uniform_discrete(1, length(labels))]
logpdf_grad(labeled_uniform, :a, [:a, :b, :c])

@dist relabeled_uniform() = [:c, :d, :f][uniform_discrete(1, 3)]
logpdf_grad(relabeled_uniform, :c)
```
Associated errors. For `relabeled_uniform`:

```
ERROR: BoundsError: attempt to access 2-element Vector{Nothing} at index [3]
Stacktrace:
[1] getindex
@ .\array.jl:861 [inlined]
[2] logpdf_grad(::Gen.RelabeledDistribution{Symbol, Int64}, ::Symbol, ::Int64, ::Int64)
@ Gen ...\.julia\packages\Gen\Dne3u\src\modeling_library\dist_dsl\relabeled_distribution.jl:83
[3] logpdf_grad(::Gen.CompiledDistWithArgs{Symbol}, ::Symbol)
@ Gen ...\.julia\packages\Gen\Dne3u\src\modeling_library\dist_dsl\dist_dsl.jl:73
[4] top-level scope
```

For `labeled_uniform`:

```
ERROR: BoundsError: attempt to access 2-element Vector{Nothing} at index [3]
[1] getindex
@ .\array.jl:861 [inlined]
[2] logpdf_grad(::Gen.WithLabelArg{Any, Int64}, ::Symbol, ::Vector{Symbol}, ::Int64, ::Int64)
@ Gen ...\.julia\packages\Gen\Dne3u\src\modeling_library\dist_dsl\relabeled_distribution.jl:29
[3] logpdf_grad(d::Gen.CompiledDistWithArgs{Any}, x::Symbol, args::Vector{Symbol})
@ Gen ...\.julia\packages\Gen\Dne3u\src\modeling_library\dist_dsl\dist_dsl.jl:73
[4] top-level scope
```
The fix should be relatively straightforward, since it's just an indexing error. I can get to it sometime soon.
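A minimal sketch of the index bookkeeping involved, assuming Gen's convention that `logpdf_grad(d, x, args...)` returns a tuple with one entry for the value `x` followed by one entry per argument. The helper `wrapper_logpdf_grad` is hypothetical and is not the code in `relabeled_distribution.jl`; it only shows how a base distribution's gradient tuple could be mapped onto the wrapper's argument list, where `n_extra` counts leading non-differentiable arguments such as the `labels` vector.

```julia
# Hypothetical helper (not Gen's actual code). Given the gradient tuple
# returned by the base distribution, i.e. (grad wrt base value, grad wrt each
# base arg...), build the gradient tuple for a wrapper whose value is a label
# and which takes `n_extra` extra leading arguments (e.g. a `labels` vector).
function wrapper_logpdf_grad(base_grads::Tuple, n_extra::Int)
    # The label value is discrete, so it has no gradient.
    value_grad = nothing
    # Extra leading arguments such as `labels` are not differentiable.
    extra_grads = ntuple(_ -> nothing, n_extra)
    # Drop base_grads[1] (gradient wrt the base index) and keep exactly one
    # entry per base argument.
    arg_grads = base_grads[2:end]
    return (value_grad, extra_grads..., arg_grads...)
end
```

Under this assumption, `relabeled_uniform` corresponds to `n_extra = 0` and yields three entries (value, `1`, `3`), while `labeled_uniform` corresponds to `n_extra = 1` and yields four (value, `labels`, `1`, `length(labels)`), matching the argument lists in the two stack traces.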
Things appear to work fine for `TransformedDistribution`s, except when there are no arguments:
```julia
using Gen

@dist shifted_normal(mu, sigma) = Gen.normal(mu, sigma) + 1.0
logpdf_grad(shifted_normal, 0.0, 0.0, 1.0)

@dist shifted_std_normal() = Gen.normal(0.0, 1.0) + 1.0
logpdf_grad(shifted_std_normal, 0.0)
```
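For reference, the gradient a correct `logpdf_grad` should return in the working case can be derived by hand. If `x = z + 1` with `z ~ normal(mu, sigma)`, then `log p(x) = -log(sigma) - log(2π)/2 - r^2 / (2 sigma^2)` with residual `r = x - 1 - mu`. The function below is a hypothetical helper, not part of Gen; it just evaluates those partial derivatives in the assumed order (value, `mu`, `sigma`).

```julia
# Hypothetical helper (not part of Gen): closed-form gradient of the log
# density of x = z + shift, where z ~ Normal(mu, sigma).
function shifted_normal_logpdf_grad(x, mu, sigma; shift=1.0)
    r = x - shift - mu                  # residual of the underlying normal draw
    dx = -r / sigma^2                   # d/dx of -r^2 / (2 sigma^2)
    dmu = r / sigma^2                   # d/dmu of -r^2 / (2 sigma^2)
    dsigma = -1 / sigma + r^2 / sigma^3 # d/dsigma of -log(sigma) - r^2 / (2 sigma^2)
    return (dx, dmu, dsigma)
end

shifted_normal_logpdf_grad(0.0, 0.0, 1.0)  # (1.0, -1.0, 0.0)
```

So `logpdf_grad(shifted_normal, 0.0, 0.0, 1.0)` should agree with `(1.0, -1.0, 0.0)`. For `shifted_std_normal` there are no arguments, so only the value entry (`1.0` at `x = 0.0`) would be expected.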
Calling `logpdf_grad` on `shifted_std_normal` leads to the following error and backtrace:
```
ERROR: MethodError: zero(::Type{Union{}}) is ambiguous. Candidates:
...
Stacktrace:
[1] track(x::Vector{Union{}}, ::Type{Union{}}, tp::Vector{ReverseDiff.AbstractInstruction})
@ ReverseDiff ...\.julia\packages\ReverseDiff\YkVxM\src\tracked.jl:473
[2] ReverseDiff.GradientConfig(input::Vector{Union{}}, ::Type{Union{}}, tp::Vector{ReverseDiff.AbstractInstruction})
@ ReverseDiff ...\.julia\packages\ReverseDiff\YkVxM\src\api\Config.jl:50
[3] ReverseDiff.GradientConfig(input::Vector{Union{}}, tp::Vector{ReverseDiff.AbstractInstruction}) (repeats 2 times)
@ ReverseDiff ...\.julia\packages\ReverseDiff\YkVxM\src\api\Config.jl:35
[4] gradient(f::Function, input::Vector{Union{}})
@ ReverseDiff ...\.julia\packages\ReverseDiff\YkVxM\src\api\gradients.jl:22
[5] (::Gen.var"#46#52"{Vector{Union{}}})(::Tuple{Int64, Float64})
@ Gen .\none:0
[6] iterate
@ .\generator.jl:47 [inlined]
[7] grow_to!
@ .\array.jl:797 [inlined]
[8] collect(itr::Base.Generator{Base.Iterators.Filter{Gen.var"#48#54"{Vector{Bool}}, Base.Iterators.Enumerate{Tuple{Float64, Float64}}}, Gen.var"#46#52"{Vector{Union{}}}})
@ Base .\array.jl:721
[9] logpdf_grad(::Gen.CompiledDistWithArgs{Float64}, ::Float64)
@ Gen ...\.julia\packages\Gen\Dne3u\src\modeling_library\dist_dsl\dist_dsl.jl:78
[10] top-level scope
```
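It's less obvious to me how to fix this one. From the trace, `ReverseDiff.gradient` ends up being called on a `Vector{Union{}}` (presumably the empty vector of arguments), and `zero(Union{})` is ambiguous. I presume we can just write a special case for when there are zero arguments.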
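A minimal sketch of such a special case, assuming the gradient is computed by a call shaped like `ReverseDiff.gradient(f, input)`, as in frame [4] of the trace. Both `grad_or_empty` and `fd_gradient` are hypothetical helpers: the first skips the AD call when there is nothing to differentiate, and the second is a central-difference stand-in for `ReverseDiff.gradient` so the sketch runs without extra packages.

```julia
# Hypothetical guard (not Gen's code): only call the gradient backend when
# there is at least one argument to differentiate with respect to.
function grad_or_empty(grad, f, input::AbstractVector)
    if isempty(input)
        # No arguments, so the gradient is empty. Returning early avoids
        # handing an empty, possibly Vector{Union{}}-typed input to the backend.
        return Float64[]
    end
    return grad(f, input)
end

# Hypothetical stand-in for ReverseDiff.gradient: central finite differences.
function fd_gradient(f, x::AbstractVector{<:AbstractFloat}; h=1e-6)
    g = similar(x, Float64)
    for i in eachindex(x)
        xp = copy(x)
        xm = copy(x)
        xp[i] += h
        xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

sq(v) = sum(abs2, v)
g_nonempty = grad_or_empty(fd_gradient, sq, [1.0, 2.0])        # ≈ [2.0, 4.0]
g_empty = grad_or_empty(fd_gradient, sq, Vector{Union{}}())    # Float64[]
```

In Gen itself, `grad` would presumably be `ReverseDiff.gradient`; the point of the guard is that the empty case never reaches `GradientConfig`, where the `zero(Union{})` ambiguity is raised.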