danijar / dreamerv2

Mastering Atari with Discrete World Models

Home Page: https://danijar.com/dreamerv2


Difference in the KL loss terms in the paper and the code

shivakanthsujit opened this issue · comments

The KL balancing algorithm in the paper gives the posterior and prior terms as kl_loss = compute_kl(stop_grad(posterior), prior), so I had assumed the code would compute the loss as value = kld(dist(sg(post)), dist(prior)).

But the code has the terms reversed, formulating the KL loss (in networks.py, line 168) as value = kld(dist(prior), dist(sg(post))).

Does that have something to do with how the KL divergence function is implemented in tensorflow_probability?
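
For concreteness, the two orderings being compared could be sketched like this, assuming post_logits and prior_logits are the categorical latent logits; the names and shapes here are illustrative, not the repository's exact API:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Illustrative shapes: batch of 16, 32 categorical latents with 32 classes each.
post_logits = tf.random.normal([16, 32, 32])
prior_logits = tf.random.normal([16, 32, 32])

dist = lambda logits: tfd.Independent(tfd.OneHotCategorical(logits=logits), 1)
sg = tf.stop_gradient

# Ordering from the paper's pseudocode: KL(stop_grad(posterior) || prior).
kl_paper = tfd.kl_divergence(dist(sg(post_logits)), dist(prior_logits))

# Ordering quoted from networks.py: KL(prior || stop_grad(posterior)).
kl_code = tfd.kl_divergence(dist(prior_logits), dist(sg(post_logits)))
```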

KL balancing is implemented as a weighted average of two terms: the KL with a stop-gradient on the prior and the KL with a stop-gradient on the posterior.

The value you found in the code is only used for logging; it is not the quantity the gradient is computed from.
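
A minimal sketch of that weighted-average loss, reusing the illustrative post_logits and prior_logits from the sketch above and a balancing weight alpha (0.8 in the paper); the helper name balanced_kl is hypothetical, not the repository's API:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
dist = lambda logits: tfd.Independent(tfd.OneHotCategorical(logits=logits), 1)
sg = tf.stop_gradient

def balanced_kl(post_logits, prior_logits, alpha=0.8):
  # Term with the posterior stopped: gradients only train the prior
  # toward the posterior.
  kl_prior = tfd.kl_divergence(dist(sg(post_logits)), dist(prior_logits))
  # Term with the prior stopped: gradients only regularize the posterior
  # toward the prior.
  kl_post = tfd.kl_divergence(dist(post_logits), dist(sg(prior_logits)))
  # Weighted average of the two terms; this is the quantity the gradient
  # is taken of. Numerically both terms are the same KL; the stop-gradients
  # only change which side receives gradients in the backward pass.
  return alpha * kl_prior + (1 - alpha) * kl_post
```

With alpha = 0.8, the prior is pulled toward the posterior more strongly than the posterior is regularized toward the prior, which is the stated purpose of KL balancing in the paper.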