ksachdeva / rethinking-tensorflow-probability

Statistical Rethinking (2nd Ed) with Tensorflow Probability

Home Page: https://ksachdeva.github.io/rethinking-tensorflow-probability/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Solution to the ODE model in chapter 16

rasoolianbehnam opened this issue · comments

FYI: I have written the following model for the ODE (hare/lynx) example in chapter 16. It works, but it is very slow.

def Ind(d, reinterpreted_batch_ndims=1, **kwargs):
    """Shorthand for ``tfd.Independent(d, reinterpreted_batch_ndims=...)``.

    Reinterprets ``reinterpreted_batch_ndims`` (default 1) of ``d``'s batch
    dimensions as event dimensions. Any extra keyword arguments (e.g.
    ``name``) are forwarded to ``tfd.Independent`` unchanged.
    """
    options = dict(kwargs, reinterpreted_batch_ndims=reinterpreted_batch_ndims)
    return tfd.Independent(d, **options)
    
# Alias used to wrap coroutine yields that do not depend on any earlier
# yielded variable (the priors in m03 below).
root = tfd.JointDistributionCoroutine.Root

# Number of solution times for the ODE (get_HL evaluates at 0, 1, ..., N - 1).
# NOTE(review): presumably one per row/observation of ``data`` — confirm.
N = len(data)
@tf.function
def get_HL(b_h, m_h, b_l, m_l, H1, L1):
    """Solve the hare/lynx Lotka-Volterra ODE and return both trajectories.

    The state is ``y = [H, L]`` with dynamics

        dH/dt = H * (b_h - m_h * L)
        dL/dt = L * (b_l * H - m_l)

    integrated from ``t = 0`` with initial state ``(H1, L1)``. The solution
    is evaluated at the ``N`` times ``0, 1, ..., N - 1``, where ``N`` is the
    module-level global read when this function is traced.

    Args:
        b_h, m_h, b_l, m_l: ODE rate parameters. They must broadcast with the
            state components for ``tf.stack`` below. NOTE(review): presumably
            all share one batch shape — confirm.
        H1, L1: Initial hare and lynx populations at ``t = 0``.

    Returns:
        ``(H, L)``: the hare and lynx trajectories, each with the time axis
        placed after the batch dimensions (``batch_shape + [N]``).
    """
    def ode_fn(t, y):
        # No separate @tf.function: this is traced as part of get_HL anyway.
        del t  # autonomous system: the right-hand side does not depend on t
        H = y[..., 0]
        L = y[..., 1]
        rates = tf.stack([b_h - m_h * L, b_l * H - m_l], axis=-1)
        return rates * y

    y_init = tf.stack([H1, L1], axis=-1)
    # Explicit floating-point times in the same dtype as the state.
    t_init = tf.constant(0, dtype=y_init.dtype)
    solution_times = tf.range(N, dtype=y_init.dtype)

    # NOTE(review): BDF is an implicit solver aimed at stiff problems and is
    # costly per step; if these dynamics are non-stiff, an explicit solver
    # such as tfp.math.ode.DormandPrince may be much faster — benchmark first.
    solver = tfp.math.ode.BDF(rtol=1e-3, atol=1e-3, max_num_steps=500)
    results = solver.solve(ode_fn, t_init, y_init, solution_times=solution_times)

    # results.states is time-major ([N, *batch, 2]); move time after the
    # batch dims -> [*batch, N, 2]. Must be tf.einsum to work on graph tensors.
    HL = tf.einsum("t...k->...tk", results.states)

    H = HL[..., 0]
    L = HL[..., 1]
    return H, L

@tfd.JointDistributionCoroutine
def m03():
    """Joint model for the hare/lynx Lotka-Volterra ODE example.

    Priors:
        m_l, b_h ~ TruncatedNormal(1, 0.5) on [0, float32 max]
        m_h, b_l ~ TruncatedNormal(0.05, 0.05) on [0, float32 max]
        sigma_h, sigma_l ~ Exponential(1)
        H1, L1 ~ LogNormal(log(10), 1)        (initial populations)
        p_h, p_l ~ Beta(40, 200)              (observed fraction)

    Likelihood:
        The latent trajectories H, L come from solving the ODE with get_HL.
        Observations are LogNormal with location log(p * population) and
        scale sigma. The time axis is treated as a single event via Ind.

    The yield order defines the order of the joint's components; do not
    reorder the yields.
    """
    # Upper truncation bound: effectively "no upper bound" in float32.
    mx = tf.float32.max

    # ODE rate parameters, truncated below at 0 to keep them non-negative.
    m_l = yield root(tfd.TruncatedNormal(1, .5, 0, mx, name='m_l'))
    m_h = yield root(tfd.TruncatedNormal(.05, .05, 0, mx, name='m_h'))
    b_l = yield root(tfd.TruncatedNormal(.05, .05, 0, mx, name='b_l'))
    b_h = yield root(tfd.TruncatedNormal(1, .5, 0, mx, name='b_h'))

    # Observation noise scales.
    sigma_h = yield root(tfd.Exponential(1, name="sigma_h"))
    sigma_l = yield root(tfd.Exponential(1, name="sigma_l"))

    # Initial populations at t = 0.
    H1 = yield root(tfd.LogNormal(tf.math.log(10.), 1, name='H1'))
    L1 = yield root(tfd.LogNormal(tf.math.log(10.), 1, name='L1'))

    # Fraction of each population that is observed.
    p_h = yield root(tfd.Beta(40, 200, name='p_h'))
    p_l = yield root(tfd.Beta(40, 200, name='p_l'))

    H, L = get_HL(b_h, m_h, b_l, m_l, H1, L1)

    # [..., None] appends a trailing axis so per-draw scalars broadcast
    # against the trajectories' last (time) axis.
    yield Ind(tfd.LogNormal(tf.math.log(p_h[..., None] * H), sigma_h[..., None]), name="H_obs")
    yield Ind(tfd.LogNormal(tf.math.log(p_l[..., None] * L), sigma_l[..., None]), name="L_obs")