aesara-devs / aehmc

An HMC/NUTS implementation in Aesara


Add function to create a new HMC state

rlouf opened this issue

It is cumbersome to have to systematically write:

import aesara
import aesara.tensor as aet

q = aet.vector('q')
potential_energy = -logprob_fn(q)
potential_energy_grad = aesara.grad(potential_energy, wrt=q)

to initialize the chain's state (HMC or NUTS), when we could instead simply write:

init_state = hmc.new_state(q, logprob_fn)

We should implement this new_state helper function, which would work for any algorithm in the HMC family.