Dirichlet distribution sampling issue when jit_compile=True
LorenzoRimella opened this issue
Lorenzo Rimella commented
It seems that some seeds produce NaNs when sampling from a Dirichlet distribution. Any idea why? Below is an example script, tested on Google Colab.
```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# One concentration parameter is exactly zero (index 2).
dirichlet_lambda = tf.convert_to_tensor(
    [2., 5., 0., 10., 10., 12., 10., 10., 1., 1.], dtype=tf.float32)
# Stateless seeds (shape [2], int32) that differ only in the last element.
seed_s2 = tf.convert_to_tensor([-1012227931, -757448172], dtype=tf.int32)
seed_s3 = tf.convert_to_tensor([-1012227931, -757448170], dtype=tf.int32)

@tf.function(jit_compile=True)
def jitwhat(concentration, seed):
    # Output shape (13, 10, 10): a (13, 10) batch of length-10 probability vectors.
    theta_j_k = tfp.distributions.Dirichlet(concentration=concentration).sample(
        (13, 10), seed=seed)
    return theta_j_k

foo = jitwhat(dirichlet_lambda, seed_s2)
# Print the indices of any NaN entries.
print(np.where(np.isnan(foo)))
```
Note that this Dirichlet distribution is "degenerate", since one of its concentration parameters is zero. Usually the sampler simply returns a zero in the corresponding position, but with this specific seed it returns NaN.
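For context, a common way to sample a Dirichlet is to draw independent Gamma(α_i, 1) variates and divide each by their sum. The sketch below uses that construction with Python's standard library only; it is not TFP's implementation, and treating α = 0 as a point mass at zero is an assumption made here (`random.gammavariate` rejects α ≤ 0). It illustrates why a zero concentration is expected to give an exact zero: 0 divided by a positive total is 0. In this sketch the only 0/0 path is an all-zero denominator, which is practically ruled out for the concentration vector above because its other entries are positive.

```python
import math
import random

def dirichlet_sample(concentration, rng):
    """Draw one Dirichlet sample by normalizing independent Gamma(alpha, 1) draws.

    Assumption for this sketch: alpha == 0 is treated as a point mass at 0,
    so that component is exactly 0.0 in the result.
    """
    gammas = []
    for alpha in concentration:
        if alpha < 0:
            raise ValueError("concentration must be non-negative")
        # random.gammavariate requires alpha > 0, so alpha == 0 is handled explicitly.
        gammas.append(0.0 if alpha == 0 else rng.gammavariate(alpha, 1.0))
    total = sum(gammas)
    if total == 0.0:
        # Normalizing here would be 0/0, the classic source of a NaN.
        raise ValueError("cannot normalize: all Gamma draws are zero")
    return [g / total for g in gammas]

rng = random.Random(0)
concentration = [2., 5., 0., 10., 10., 12., 10., 10., 1., 1.]
theta = dirichlet_sample(concentration, rng)
print(theta[2])                            # 0.0 (the zero-concentration slot)
print(any(math.isnan(x) for x in theta))   # False
print(round(sum(theta), 6))                # 1.0
```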
Chris Jewell commented
Verified as a potential bug. Colab here.