Initial Aesara HMC conversion
brandonwillard opened this issue
Brandon T. Willard commented
We want to start this Aesara HMC implementation by converting a simple yet representative function from blackjax. I believe we decided to try the `velocity_verlet` function.
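For reference while converting, here is a minimal plain-Python sketch of a single velocity Verlet step for a separable Hamiltonian H(q, p) = U(q) + K(p). It is not blackjax's code. It assumes scalar position and momentum, and it takes the gradients as explicit functions because plain Python has no autodiff (blackjax derives them with JAX). The `IntegratorState` field names are assumed to mirror blackjax's, and `velocity_verlet_step` is a hypothetical name.

```python
from typing import Callable, NamedTuple


class IntegratorState(NamedTuple):
    # Hypothetical scalar stand-in; field names assumed to mirror blackjax.
    position: float
    momentum: float
    potential_energy: float
    potential_energy_grad: float


def velocity_verlet_step(
    state: IntegratorState,
    step_size: float,
    potential_fn: Callable[[float], float],
    potential_grad_fn: Callable[[float], float],
    kinetic_grad_fn: Callable[[float], float],
) -> IntegratorState:
    """One velocity Verlet step for H(q, p) = U(q) + K(p)."""
    # Half step for the momentum, using the gradient stored in the state.
    momentum = state.momentum - 0.5 * step_size * state.potential_energy_grad
    # Full step for the position, driven by the kinetic energy gradient.
    position = state.position + step_size * kinetic_grad_fn(momentum)
    # Recompute the potential energy and its gradient at the new position.
    potential_energy = potential_fn(position)
    potential_energy_grad = potential_grad_fn(position)
    # Second half step for the momentum, using the new gradient.
    momentum = momentum - 0.5 * step_size * potential_energy_grad
    return IntegratorState(position, momentum, potential_energy, potential_energy_grad)
```

A harmonic oscillator (U(q) = q²/2, K(p) = p²/2) makes a quick sanity check: over many steps the total energy should stay close to its initial value rather than drift, which is the property that makes this integrator suitable for HMC.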
From what I can tell, the following needs to be done:
- Clarify the inputs and return value of `velocity_verlet`.
  - Aesara isn't based on the compilation/tracing of functions, so, instead of potential and kinetic energy functions, those arguments will need to be potential and kinetic energy graphs. Such graphs are essentially the body of the original functions, so there's not much of a change.
- Create an Aesara `Type` for `IntegratorState` and, eventually, a corresponding Numba conversion.
  - This is probably a worthwhile convenience, but it's not absolutely necessary. I believe that `ParamsType` is essentially what we want, so there might not be much to do anyway.
- Replace `jax.tree_util.tree_multimap` with Aesara's `Scan` `Op`.
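To pin down what the `IntegratorState` and `tree_multimap` items replace, here is a plain-Python sketch of a state container and of the leafwise update that `tree_multimap` performs inside a JAX integrator. It assumes position, momentum, and gradient are flat lists of float leaves; the field names are assumed to mirror blackjax's `IntegratorState`, and `leafwise` and `half_step_momentum` are hypothetical helpers, not blackjax or Aesara APIs. The Python loop in `leafwise` is the per-leaf iteration that the list proposes to express with Aesara's `Scan` `Op`.

```python
from typing import Callable, List, NamedTuple


class IntegratorState(NamedTuple):
    # Hypothetical container; field names assumed to mirror blackjax.
    position: List[float]
    momentum: List[float]
    potential_energy: float
    potential_energy_grad: List[float]


def leafwise(
    fn: Callable[[float, float], float], xs: List[float], ys: List[float]
) -> List[float]:
    """Apply fn to matching leaves of two flat 'trees', like tree_multimap."""
    if len(xs) != len(ys):
        raise ValueError("both trees must have the same structure")
    return [fn(x, y) for x, y in zip(xs, ys)]


def half_step_momentum(state: IntegratorState, step_size: float) -> IntegratorState:
    # p <- p - (step_size / 2) * dU/dq, applied leaf by leaf.
    momentum = leafwise(
        lambda p, g: p - 0.5 * step_size * g,
        state.momentum,
        state.potential_energy_grad,
    )
    # _replace returns a new NamedTuple; the input state is left unchanged.
    return state._replace(momentum=momentum)
```

A `NamedTuple` keeps the state immutable, and `_replace` returns an updated copy, which matches the functional style of the JAX original. An Aesara `Type` (or `ParamsType`) would presumably play the same bundling role for graph variables.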
Brandon T. Willard commented
@rlouf, I've sketched out a simple conversion of `velocity_verlet` in Aesara to get you started: `aesara_hmc/integrators.py`.