pymc-devs / nutpie

Python wrapper for nuts-rs



Use pymc's dlogp with rewrites

dehorsley opened this issue

With pymc-devs/pymc#6736 merged, pymc rewrites the logp graph before taking the gradient. This removes some Ops from the graph, which should bring some general performance benefits but is mainly targeted at pymc-devs/pymc#6717. Currently, nutpie calculates dlogp itself here:

grads = pytensor.gradient.grad(logp, value_vars)
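To make the "rewrite, then differentiate" idea concrete, here is a toy, pure-Python sketch. It is not pytensor: the Var/Const/Op classes, the single log(exp(a)) -> a rule and the naive grad function are hypothetical stand-ins for pytensor's graph, its rewrite passes and pytensor.gradient.grad. It only shows that removing a redundant pair of Ops before differentiating leaves fewer Ops in the gradient graph, while the gradient value stays the same.

```python
import math
from dataclasses import dataclass
from typing import Union


# Hypothetical toy expression nodes; stand-ins for pytensor variables and Ops.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: float


@dataclass(frozen=True)
class Op:
    name: str    # one of "add", "mul", "div", "exp", "log"
    args: tuple  # child expressions


Expr = Union[Var, Const, Op]


def rewrite(e: Expr) -> Expr:
    """Bottom-up rewrite pass with a single rule: log(exp(a)) -> a."""
    if not isinstance(e, Op):
        return e
    args = tuple(rewrite(a) for a in e.args)
    if e.name == "log" and isinstance(args[0], Op) and args[0].name == "exp":
        return args[0].args[0]
    return Op(e.name, args)


def grad(e: Expr, x: Var) -> Expr:
    """Symbolic d(e)/dx, built naively with no simplification."""
    if isinstance(e, Var):
        return Const(1.0 if e == x else 0.0)
    if isinstance(e, Const):
        return Const(0.0)
    if e.name == "add":
        a, b = e.args
        return Op("add", (grad(a, x), grad(b, x)))
    if e.name == "mul":  # product rule
        a, b = e.args
        return Op("add", (Op("mul", (grad(a, x), b)), Op("mul", (a, grad(b, x)))))
    if e.name == "exp":  # d exp(a) = exp(a) * da
        (a,) = e.args
        return Op("mul", (e, grad(a, x)))
    if e.name == "log":  # d log(a) = (1 / a) * da
        (a,) = e.args
        return Op("mul", (Op("div", (Const(1.0), a)), grad(a, x)))
    raise ValueError(f"no derivative rule for {e.name}")


def evaluate(e: Expr, env: dict) -> float:
    """Numerically evaluate an expression given values for its variables."""
    if isinstance(e, Var):
        return env[e.name]
    if isinstance(e, Const):
        return e.value
    v = [evaluate(a, env) for a in e.args]
    ops = {
        "add": lambda: v[0] + v[1],
        "mul": lambda: v[0] * v[1],
        "div": lambda: v[0] / v[1],
        "exp": lambda: math.exp(v[0]),
        "log": lambda: math.log(v[0]),
    }
    return ops[e.name]()


def count_ops(e: Expr) -> int:
    """Count Op nodes in the tree (repeated subtrees are counted each time)."""
    if isinstance(e, Op):
        return 1 + sum(count_ops(a) for a in e.args)
    return 0


x = Var("x")
u = Op("mul", (Const(-0.5), Op("mul", (x, x))))  # -0.5 * x * x
logp = Op("log", (Op("exp", (u,)),))             # log(exp(u)): a redundant pair

raw_grad = grad(logp, x)                 # differentiate as written
rewritten_grad = grad(rewrite(logp), x)  # rewrite first, then differentiate

print(count_ops(raw_grad), count_ops(rewritten_grad))  # 16 7
print(evaluate(rewritten_grad, {"x": 2.0}))            # -2.0
```

Both gradients evaluate to -x; only the size of the graph differs. pytensor shares common subgraphs and also rewrites the compiled gradient graph, so real savings will differ from this tree count.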

I think this can be replaced with a simple call:

grads = model.dlogp(value_vars) 
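One detail worth checking before swapping the line: as far as I can tell from pymc's source, Model.dlogp builds its result with pymc.pytensorf.gradient, which flattens each per-variable gradient and concatenates them into a single vector, whereas pytensor.gradient.grad(logp, value_vars) returns a list with one gradient per value variable. If downstream code relies on that list, the flat vector would need to be split back up. Below is a minimal stdlib sketch of the split; split_flat_gradient is a hypothetical helper, plain Python lists stand in for arrays, and the shapes are assumed to be given in the same order as value_vars.

```python
from math import prod


def split_flat_gradient(flat, shapes):
    """Split a concatenated, flattened gradient into per-variable chunks.

    Hypothetical helper: `shapes` lists each value variable's shape, in the
    same order the variables were passed to the gradient call. Each chunk
    stays flat; reshaping it (row-major) is left to the caller.
    """
    chunks, start = [], 0
    for shape in shapes:
        size = prod(shape)  # prod(()) == 1, so a scalar takes one slot
        chunks.append(flat[start:start + size])
        start += size
    if start != len(flat):
        raise ValueError(f"shapes account for {start} entries, got {len(flat)}")
    return chunks


# A scalar variable followed by a length-2 vector variable.
print(split_flat_gradient([0.5, 1.0, -1.0], [(), (2,)]))  # [[0.5], [1.0, -1.0]]
```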

I think we can also drop value_vars, since it only seems to be used for computing the gradient.

Thanks for the suggestion :-)

I just included this in #42 (dc3cf2b)
It would probably be cleaner to reuse some of the compilation machinery in pymc instead of building this again...