Solution to the ODE model in chapter 16 (Hare/Lynx)
rasoolianbehnam opened this issue · comments
Just FYI. I have written the following model for the ODE part in chapter 16 (Hare/Lynx). It's working but very slow.
def Ind(d, reinterpreted_batch_ndims=1, **kwargs):
    """Wrap ``d`` in ``tfd.Independent``, defaulting to one reinterpreted dim.

    Args:
        d: The distribution to wrap.
        reinterpreted_batch_ndims: Number of rightmost batch dimensions of
            ``d`` that ``tfd.Independent`` should treat as event dimensions
            (default 1).
        **kwargs: Forwarded unchanged to ``tfd.Independent`` (the model
            below passes ``name`` this way).

    Returns:
        The ``tfd.Independent`` distribution.
    """
    independent_kwargs = dict(kwargs)
    independent_kwargs["reinterpreted_batch_ndims"] = reinterpreted_batch_ndims
    return tfd.Independent(d, **independent_kwargs)
# Short alias used in the model below to wrap the prior distributions, none of
# which depend on any earlier `yield` in the coroutine.
root = tfd.JointDistributionCoroutine.Root
# Number of observations; get_HL solves the ODE at this many time points.
# NOTE(review): `data` is defined outside this snippet -- confirm it holds one
# row per time point.
N = len(data)
@tf.function
def get_HL(b_h, m_h, b_l, m_l, H1, L1):
    """Solve the hare/lynx Lotka-Volterra ODE and return both trajectories.

    The system integrated is::

        dH/dt = H * (b_h - m_h * L)
        dL/dt = L * (b_l * H - m_l)

    It starts from ``(H1, L1)`` at t = 0 and is reported at t = 0, 1, ...,
    N - 1, where ``N`` is the module-level number of observations.

    Args:
        b_h: Constant per-capita growth term of H.
        m_h: Coefficient of L in H's per-capita rate (subtracted).
        b_l: Coefficient of H in L's per-capita rate.
        m_l: Constant per-capita loss term of L (subtracted).
        H1: Initial value of H.
        L1: Initial value of L.

        NOTE(review): all six are assumed to share one batch shape (e.g. a
        batch of prior/posterior draws). ``y_init`` takes the shape of
        ``H1``/``L1``, so those must already carry the full batch shape.

    Returns:
        ``(H, L)``: tensors with the time axis last, i.e. ``batch + [N]``.
    """
    def ode_fn(t, y):
        # `t` is unused (autonomous system) but required by the solver API.
        # No @tf.function here: the enclosing tf.function already traces this
        # closure, and decorating it would create a new Function per trace.
        H = y[..., 0]
        L = y[..., 1]
        # Per-species per-capita rates, multiplied elementwise by the state.
        rates = tf.stack([b_h - m_h * L, b_l * H - m_l], axis=-1)
        return rates * y

    y_init = tf.stack([H1, L1], axis=-1)
    # Build the time grid in the state's float dtype instead of Python int /
    # int32, so times and states agree in dtype (same values as before).
    t_init = tf.constant(0, dtype=y_init.dtype)
    solution_times = tf.range(N, dtype=y_init.dtype)
    # NOTE(review): BDF is an implicit (stiff) solver and is comparatively
    # expensive; if this system is not stiff for the sampled parameters,
    # tfp.math.ode.DormandPrince may be much faster -- worth benchmarking.
    solver = tfp.math.ode.BDF(rtol=1e-3, atol=1e-3, max_num_steps=500)
    results = solver.solve(ode_fn, t_init, y_init, solution_times=solution_times)
    # results.states is time-major ([N, ...batch, 2]); move time next to the
    # species axis -> [...batch, N, 2].
    HL = tf.einsum("t...k->...tk", results.states)
    H = HL[..., 0]
    L = HL[..., 1]
    return H, L
@tfd.JointDistributionCoroutine
def m03():
    """Hare/lynx model: ODE-driven latent populations with LogNormal observations.

    The order of the yields below defines the order of the model's components:

    1. ``m_l``, ``m_h``, ``b_l``, ``b_h`` -- ODE coefficients passed to
       get_HL, each a Normal truncated to ``[0, mx]`` (non-negative).
    2. ``sigma_h``, ``sigma_l`` -- observation log-scales, Exponential(1).
    3. ``H1``, ``L1`` -- initial populations, LogNormal(log(10), 1).
    4. ``p_h``, ``p_l`` -- Beta(40, 200) factors multiplying the latent
       populations in the observation model.
    5. ``H_obs``, ``L_obs`` -- LogNormal observations located at
       ``log(p * population)`` for each of the N time points from get_HL,
       wrapped with ``Ind`` so the time axis is treated as one event.
    """
    # Effectively-infinite upper bound: the coefficient priors are only
    # constrained to be non-negative.
    mx = tf.float32.max
    m_l = yield root(tfd.TruncatedNormal(1, .5, 0, mx, name='m_l'))
    m_h = yield root(tfd.TruncatedNormal(.05, .05, 0, mx, name='m_h'))
    b_l = yield root(tfd.TruncatedNormal(.05, .05, 0, mx, name='b_l'))
    b_h = yield root(tfd.TruncatedNormal(1, .5, 0, mx, name='b_h'))
    sigma_h = yield root(tfd.Exponential(1, name="sigma_h"))
    sigma_l = yield root(tfd.Exponential(1, name="sigma_l"))
    H1 = yield root(tfd.LogNormal(tf.math.log(10.), 1, name='H1'))
    L1 = yield root(tfd.LogNormal(tf.math.log(10.), 1, name='L1'))
    p_h = yield root(tfd.Beta(40, 200, name='p_h'))
    p_l = yield root(tfd.Beta(40, 200, name='p_l'))
    H, L = get_HL(b_h, m_h, b_l, m_l, H1, L1)
    # `[..., None]` appends a time axis so each draw's scalar parameters
    # broadcast across the N solution times returned by get_HL.
    yield Ind(tfd.LogNormal(tf.math.log(p_h[..., None]*H), sigma_h[..., None]), name="H_obs")
    yield Ind(tfd.LogNormal(tf.math.log(p_l[..., None]*L), sigma_l[..., None]), name="L_obs")