kl_divergence(Bernoulli(0), Bernoulli(0)) == NAN
fritzo opened this issue
It looks like we could improve the numerical stability of kl_divergence():

- kl_divergence(Bernoulli(0), Bernoulli(0)) == 0 (currently NAN)
- kl_divergence(Bernoulli(0), Bernoulli(1)) == float('inf') (currently NAN)
- Categorical
- OneHotCategorical
- Poisson
@vishwakftw Do you want to look into this?
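For reference, here is a minimal sketch of a Bernoulli KL that handles these edge cases via masking, using the convention 0 * log(0/q) == 0. This assumes the arguments are probability tensors, and bernoulli_kl is a hypothetical helper name, not necessarily the fix that was merged:

    import torch

    def bernoulli_kl(p_probs, q_probs):
        # Elementwise KL(Bernoulli(p) || Bernoulli(q)) over probability tensors:
        #   p * log(p/q) + (1-p) * log((1-p)/(1-q))
        t1 = p_probs * (p_probs / q_probs).log()
        t1[q_probs == 0] = float('inf')  # p > 0 where q == 0  =>  KL is +inf
        t1[p_probs == 0] = 0             # 0 * log(0/q) := 0 (overrides the NaN/inf above)
        t2 = (1 - p_probs) * ((1 - p_probs) / (1 - q_probs)).log()
        t2[q_probs == 1] = float('inf')
        t2[p_probs == 1] = 0
        return t1 + t2

With p == q == torch.tensor([0.]) this returns 0, and with q == torch.tensor([1.]) it returns inf, matching the two Bernoulli cases in the list above.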
@fritzo Sorry about the Bernoulli distribution. I have fixed it locally and am currently waiting for the build to complete so I can test. With respect to the others, could you suggest what needs to be done?
No worries 😄 I'm glad we could get the implementations merged early so we can test things in Pyro.
I think it would be good to have tests of edge cases as in Neeraj's test_bernoulli_gradient. Let's add something like TestKL.test_edge_cases() with a few tests like:
class TestKL(TestCase):
    def test_edge_cases(self):
        self.assertEqual(kl_divergence(Bernoulli(0), Bernoulli(0)), 0)
        self.assertEqual(kl_divergence(Bernoulli(0), Bernoulli(1)),
                         float('inf'), allow_infinity=True)
        # q assigns zero mass to an outcome where p has positive mass,
        # so the KL divergence is +inf.
        self.assertEqual(kl_divergence(
            Categorical(variable([0.5, 0.5])),
            Categorical(variable([0, 1]))), float('inf'), allow_infinity=True)
        ...
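The same masking trick extends to the Categorical case exercised above. A rough sketch under the same assumptions as before (categorical_kl is a hypothetical helper name, not the library's implementation):

    import torch

    def categorical_kl(p_probs, q_probs):
        # KL(Categorical(p) || Categorical(q)), summing over the last dimension.
        t = p_probs * (p_probs / q_probs).log()
        t[q_probs == 0] = float('inf')  # p has mass where q has none  =>  +inf
        t[p_probs == 0] = 0             # 0 * log(0/q) := 0
        return t.sum(-1)

For the pair in the test above, the q_probs == 0 mask at the first outcome is what yields the expected float('inf').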
Fixed. Closing in favour of pytorch#4961.