probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference

Home Page: https://gen.dev

Involution DSL appears unable to read :continuous and :discrete annotations passed through variables

fsaad opened this issue

julia> @transform f (model_in, aux_in) to (model_out, aux_out) begin
           z = :continuous
           @write(model_out[:x], 1, z)
       end

ERROR: LoadError: MethodError: no method matching var"@write"(::LineNumberNode, ::Module, ::Expr, ::Int64, ::Symbol)
Closest candidates are:
  var"@write"(::LineNumberNode, ::Module, ::Any, ::Any, ::QuoteNode) at ~/.julia/dev/Gen/src/inference/trace_translators.jl:202

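The issue seems to be that the macros hardcode how the annotation is parsed: they only have a method whose annotation argument is a QuoteNode (a literal such as :continuous), so a variable, which the macro receives as a Symbol, matches no method. Here is @read (per the error above, @write has the same ::QuoteNode restriction):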

"""
@read(<source>, <annotation>)
Macro for reading the value of a random choice from an input trace in the [Trace Transform DSL](@ref).
<source> is of the form <trace>[<addr>] where <trace> is an input trace, and <annotation>
is either :discrete or :continuous.
"""
macro read(src, ann::QuoteNode)
    return quote read($(esc(bij_state)), $(esc(src)), $(esc(typed(ann.value)))) end
end

const DISCRETE = [:discrete, :disc]
const CONTINUOUS = [:continuous, :cont]
function typed(annotation::Symbol)
    if annotation in DISCRETE
        return DiscreteAnn()
    elseif annotation in CONTINUOUS
        return ContinuousAnn()
    else
        error("error")
    end
end
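The mismatch comes from how Julia passes arguments to macros. A macro receives unevaluated syntax: a literal :continuous arrives as a QuoteNode, while a variable such as z arrives as a bare Symbol whose value is unknown at expansion time. A minimal, Gen-independent check (the macro name @argtype is hypothetical, used only for illustration):

```julia
# Hypothetical helper: returns the type of the syntax object the macro receives.
macro argtype(x)
    return QuoteNode(typeof(x))
end

z = :continuous
@argtype(:continuous)   # evaluates to QuoteNode, which matches `ann::QuoteNode`
@argtype(z)             # evaluates to Symbol, which matches no `@read`/`@write` method
```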

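A fixed definition, which defers the conversion to run time, would be something like:

macro read(src, ann)
    _typed = GlobalRef(Gen, :typed)
    # `ann` may be a QuoteNode (:continuous) or any other expression (e.g. a variable `z`),
    # so escape the expression itself rather than `ann.value`, which fails for a Symbol.
    return quote read($(esc(bij_state)), $(esc(src)), $(_typed)($(esc(ann)))) end
end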

The downside is that typed would then be called dynamically at run time, on every execution of the transform, instead of being resolved once at macro-expansion time. Because typed's return type depends on the value of its argument rather than its type, the call to read that consumes it may also be dispatched dynamically. I'm not sure how much slower this is, but perhaps it's better not to make the switch and instead clearly document that the annotation must be written literally as :continuous or :discrete.
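One possible middle ground, offered as an assumption rather than Gen's actual code, is to keep a ::QuoteNode method that resolves literal annotations at expansion time and add a fallback method that emits a run-time call for any other expression. The sketch below is self-contained: DiscreteAnn, ContinuousAnn, and typed are local stand-ins mirroring the snippet above (not imported from Gen), and the macro name @ann is hypothetical.

```julia
# Local stand-ins mirroring Gen's internal annotation types (not Gen's API).
struct DiscreteAnn end
struct ContinuousAnn end

const DISCRETE = [:discrete, :disc]
const CONTINUOUS = [:continuous, :cont]

function typed(annotation::Symbol)
    if annotation in DISCRETE
        return DiscreteAnn()
    elseif annotation in CONTINUOUS
        return ContinuousAnn()
    else
        error("unknown annotation $(repr(annotation)); expected :discrete or :continuous")
    end
end

# Literal annotation such as `:continuous`: the parser hands the macro a QuoteNode,
# so the conversion happens once, at macro-expansion time.
macro ann(x::QuoteNode)
    return typed(x.value)
end

# Any other expression (e.g. a variable): emit a call that converts at run time.
macro ann(x)
    return :($typed($(esc(x))))
end

# Usage (hypothetical):
@ann(:continuous)   # evaluates to ContinuousAnn(), resolved when the macro expands
kind = :disc
@ann(kind)          # evaluates to DiscreteAnn(), via a run-time call to typed(kind)
```

With this split, code that already writes annotations literally keeps the compile-time behavior; only transforms that pass annotations through variables pay the dynamic-call cost.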