Add function to create a new HMC state
rlouf opened this issue
Rémi Louf commented
It is cumbersome to systematically write:
import aesara
import aesara.tensor as aet
q = aet.vector('q')
potential_energy = -logprob_fn(q)  # logprob_fn: the model's log-density
potential_energy_grad = aesara.grad(potential_energy, wrt=q)
every time we initialize the chain's state (HMC or NUTS), whereas we could simply write:
init_state = hmc.new_state(q, logprob_fn)
We should implement this new_state helper function so that it works for any algorithm in the HMC family.