Cannot infer term type in egraph
vitrun opened this issue
Something goes wrong when matching against rules made up purely of literals, such as pi + 3 --> cos(4) or im + (pi + 3) --> sin(4). The following demo reproduces the issue. I've tried several versions, including master and v1.3.3.
using Metatheory
using Metatheory.EGraphs
using TermInterface
struct Term{T}
    f::Any
    args::Vector{Any}
end

function Term(f, args)
    T = if length(args) == 0
        Any
    elseif length(args) == 1
        promote_type(symtype(args[1]))
    else
        promote_type(symtype(args[1]), symtype(args[2]))
    end
    Term{T}(f, args)
end
Base.promote_type(::Type{Irrational{:π}}, ::Type{Int64}) = Real
TermInterface.exprhead(e::Term) = :call
TermInterface.operation(e::Term) = e.f
TermInterface.arguments(e::Term) = e.args
TermInterface.istree(e::Term) = true
TermInterface.symtype(::Term{T}) where {T} = T
TermInterface.symtype(::T) where {T} = T
function TermInterface.similarterm(x::Term, head, args; metadata = nothing, exprhead = :call)
    Term(head, args)
end

function EGraphs.egraph_reconstruct_expression(
    T::Type{Term{S}},
    op,
    args;
    metadata = nothing,
    exprhead = nothing,
) where {S}
    Term(op, args)
end
pt = @theory a b c begin
    im + (pi + 3) --> sin(4)
    # pi + 3 --> cos(4)
end
# let's create an egraph
ex = Term(+, [im, Term(+, [pi, 3])])
g = EGraph(ex)
settermtype!(g, Term{symtype(ex)})
# settermtype!(g, :+, 2, Term{Real})
saturate!(g, pt)
r = extract!(g, astsize)
println(r)
I dug into the code and found the following function in ematch.jl, which I believe is to blame.
function lookup_pat(g::EGraph, p::PatTerm)
    @assert isground(p)
    eh = exprhead(p)
    op = operation(p)
    args = arguments(p)
    ar = arity(p)
    T = gettermtype(g, op, ar)
    ids = [lookup_pat(g, pp) for pp in args]
    if all(i -> i isa EClassId, ids)
        n = ENodeTerm{T}(eh, op isa Symbol ? eval(op) : op, ids)
        ec = lookup(g, n)
        mn = ENodeTerm{T}(eh, +, [1, 2])
        ec2 = lookup(g, mn)
        println("T: $T, op: $(typeof(op)), n: $n, ec: $ec, mn:$mn, ec2: $ec2")
        return ec
    else
        return nothing
    end
end
In the demo above, the overloaded promote_type is used to decide the type parameter of Term{T}, and the egraph has no knowledge of it. Meanwhile, is T = gettermtype(g, op, ar) sufficient to decide the type of an enode term? I doubt it. Consider pi + 3 and im + 3: they have the same op (+) and the same ar (2), but the resulting term types are totally different.
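To make the last point concrete, here is a minimal sketch in plain Julia, with no Metatheory involved. The `termtypes` table and the `key` helper are hypothetical stand-ins for a lookup keyed only by operation and arity; they are not Metatheory's actual data structures. The sketch also uses Base's default promotion rules, without the promote_type override from the demo.

```julia
# Hypothetical stand-in for a term-type table keyed only by (operation, arity).
# This is NOT Metatheory's real implementation, only an illustration.
termtypes = Dict{Tuple{Any,Int},Type}()

# Hypothetical helper: build the lookup key from an operation and its arguments.
key(op, args) = (op, length(args))

a = (pi, 3)   # arguments of pi + 3
b = (im, 3)   # arguments of im + 3

# Result types under Base's default promotion rules.
Ta = promote_type(map(typeof, a)...)   # Float64
Tb = promote_type(map(typeof, b)...)   # Complex{Int}

termtypes[key(+, a)] = Ta
termtypes[key(+, b)] = Tb   # same key, so this overwrites the previous entry

println("same key:   ", key(+, a) == key(+, b))
println("types:      ", Ta, " vs ", Tb)
println("table size: ", length(termtypes))
```

Both terms collapse onto the single key (+, 2), so a table of this shape can remember only one of the two types. If the term type is a type parameter of the enode, as with ENodeTerm{T} above, a ground pattern built with the wrong T may then fail to match an enode that was inserted with a different one.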