tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow

Home Page: https://www.tensorflow.org/probability/

Dirichlet distribution sampling issue when jit_compile=True

LorenzoRimella opened this issue

It seems that some seeds produce NaNs when sampling from a Dirichlet distribution inside a function compiled with jit_compile=True. Any idea why? Below is an example script, tested on Google Colab.

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# Concentration vector; the third entry is zero, so the distribution is degenerate.
dirichlet_lambda = tf.convert_to_tensor([2., 5., 0., 10., 10., 12., 10., 10., 1., 1.], dtype=tf.float32)
# Two stateless seeds that differ only in their last element.
seed_s2 = tf.convert_to_tensor([-1012227931, -757448172], dtype=tf.int32)
seed_s3 = tf.convert_to_tensor([-1012227931, -757448170], dtype=tf.int32)

@tf.function(jit_compile=True)
def jitwhat(concentration, seed):
    # Draw a (13, 10) batch of samples; each sample is a 10-dimensional point on the simplex.
    theta_j_k = tfp.distributions.Dirichlet(concentration=concentration).sample((13, 10), seed=seed)
    return theta_j_k

foo = jitwhat(dirichlet_lambda, seed_s2)
print(np.where(np.isnan(foo)))  # indices of NaN entries; non-empty for seed_s2
```
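Beyond locating NaNs, it can help to check every draw against what a Dirichlet sample should satisfy: all entries finite and non-negative, summing to 1, and exactly 0 wherever the concentration is 0. The helper below is a hypothetical, pure-Python sketch (the name `check_dirichlet_draws` and the tolerance `atol` are illustrative assumptions, not part of TFP); it expects the samples as a list of rows, e.g. `foo.numpy().reshape(-1, 10).tolist()`.

```python
import math

def check_dirichlet_draws(samples, concentration, atol=1e-5):
    """Hypothetical helper: return (row_index, reason) for each invalid draw.

    `samples` is a list of rows (each a list of k floats) and
    `concentration` is the list of k concentration parameters.
    """
    problems = []
    for i, row in enumerate(samples):
        if any(math.isnan(x) for x in row):
            problems.append((i, "nan"))
            continue
        if any(x < 0 for x in row):
            problems.append((i, "negative entry"))
            continue
        if abs(sum(row) - 1.0) > atol:
            problems.append((i, "does not sum to 1"))
            continue
        # A zero concentration should put all of its mass at exactly 0.
        if any(a == 0 and x != 0 for a, x in zip(concentration, row)):
            problems.append((i, "nonzero entry for zero concentration"))
    return problems
```

For the script above, a call would look like `check_dirichlet_draws(foo.numpy().reshape(-1, 10).tolist(), dirichlet_lambda.numpy().tolist())`; an empty list means every draw looks like a valid simplex point.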

Note that this Dirichlet distribution is degenerate, because one of its concentration parameters is zero. Usually the sampler simply returns a zero in the corresponding position, but with this particular seed it returns NaN.
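To make the expected behaviour concrete: a standard way to sample a Dirichlet is to draw X_i ~ Gamma(α_i, 1) independently and return X / ΣX_j. Under that construction a zero concentration is a point mass at 0, so its coordinate is exactly 0, and the normalization never divides by zero as long as some α_j > 0. The sketch below assumes that construction; `sample_dirichlet` is a hypothetical helper, and TFP's own sampler may be implemented differently (for example in log space), so this illustrates what the output should look like rather than where the XLA NaN comes from.

```python
import random

def sample_dirichlet(concentration, rng=None):
    """Hypothetical helper: one Dirichlet draw via normalized Gamma variates.

    Components with concentration 0 are treated as a point mass at 0;
    random.gammavariate rejects alpha <= 0, so they are handled explicitly.
    """
    rng = rng or random.Random()
    if any(a < 0 for a in concentration):
        raise ValueError("concentrations must be non-negative")
    if not any(a > 0 for a in concentration):
        raise ValueError("at least one concentration must be positive")
    # X_i ~ Gamma(alpha_i, 1) for alpha_i > 0; X_i = 0 exactly for alpha_i == 0.
    draws = [rng.gammavariate(a, 1.0) if a > 0 else 0.0 for a in concentration]
    total = sum(draws)
    # total > 0 because some alpha_i > 0 (barring floating-point underflow),
    # so the division below cannot be 0/0 and cannot yield NaN.
    return [x / total for x in draws]
```

With the concentration vector from the issue, every draw from this sketch has an exact 0 in the third position and no NaNs, which matches the "usual" output described above.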

Verified as a potential bug. Colab here.