aesara-devs / aehmc

An HMC/NUTS implementation in Aesara

Initial Aesara HMC conversion

brandonwillard opened this issue

We want to start this Aesara HMC implementation by converting a simple yet representative function from blackjax. I believe we decided to try the velocity_verlet function.

From what I can tell, the following will need to be done:

  • Clarify the inputs and return value of velocity_verlet.
    • Aesara isn't based on the compilation/tracing of functions, so, instead of potential and kinetic energy functions, those arguments will need to be potential and kinetic energy graphs. Such graphs are essentially the body of the original functions, so there's not much of a change.
  • Create an Aesara Type for IntegratorState and, eventually, a corresponding Numba conversion.
    • This is probably a worthwhile convenience, but it's not absolutely necessary. I believe that ParamsType is essentially what we want, so there might not be much to do anyway.
  • Replace jax.tree_util.tree_multimap with Aesara's Scan Op.
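
For reference, here is a minimal, library-free sketch of a single velocity Verlet step, to make the inputs and return value under discussion concrete. It is not the blackjax code and not the Aesara conversion: it works on plain Python floats, the field names of IntegratorState are assumed for illustration, and the potential_fn, potential_grad_fn, and kinetic_grad_fn arguments are hypothetical stand-ins for what would become energy graphs (and their gradients) in Aesara.

```python
import math
from typing import Callable, NamedTuple


class IntegratorState(NamedTuple):
    # Field names are assumptions made for this sketch only.
    position: float
    momentum: float
    potential_energy: float
    potential_energy_grad: float


def velocity_verlet_step(
    state: IntegratorState,
    step_size: float,
    potential_fn: Callable[[float], float],
    potential_grad_fn: Callable[[float], float],
    kinetic_grad_fn: Callable[[float], float],
) -> IntegratorState:
    # Half step on the momentum using the cached potential gradient.
    momentum = state.momentum - 0.5 * step_size * state.potential_energy_grad
    # Full step on the position using the kinetic energy gradient.
    position = state.position + step_size * kinetic_grad_fn(momentum)
    # Second half step on the momentum with the gradient at the new position.
    grad = potential_grad_fn(position)
    momentum = momentum - 0.5 * step_size * grad
    return IntegratorState(position, momentum, potential_fn(position), grad)


def simulate(n_steps: int, step_size: float) -> IntegratorState:
    # Harmonic oscillator: U(q) = q^2 / 2 and K(p) = p^2 / 2.
    potential_fn = lambda q: 0.5 * q * q
    potential_grad_fn = lambda q: q
    kinetic_grad_fn = lambda p: p

    # Start at rest at q = 1, so the exact trajectory is q(t) = cos(t).
    state = IntegratorState(1.0, 0.0, potential_fn(1.0), potential_grad_fn(1.0))
    for _ in range(n_steps):
        state = velocity_verlet_step(
            state, step_size, potential_fn, potential_grad_fn, kinetic_grad_fn
        )
    return state


if __name__ == "__main__":
    final = simulate(n_steps=1000, step_size=0.01)
    total_energy = final.potential_energy + 0.5 * final.momentum**2
    print(final.position, math.cos(10.0), total_energy)
```

In this sketch the Python loop in simulate plays the role that a Scan Op would play in an Aesara graph, and the three callables mark the places where the Aesara version would take graphs instead of functions.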

@rlouf, I've sketched out a simple conversion of velocity_verlet in Aesara to get you started: aesara_hmc/integrators.py.