Question on KL rescaling!
ifiaposto opened this issue
Hello guys!
Could you please confirm that, for MNIST, the KL term is properly rescaled (according to Equation 8 in here)? To be more precise, to evenly distribute the KL loss across data points in `def get_losses_and_metrics(model, n_train)`, shouldn't the two functions below be changed as indicated?
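(For reference, this is my reading of the scaling in question, assuming Equation 8 is the standard dataset-size-normalized ELBO; I haven't checked the linked reference, so the exact form there may differ:)

```latex
% Per-example negative ELBO: the KL is divided by the training-set size N
\mathcal{L}(\theta)
  = -\frac{1}{N}\sum_{i=1}^{N} \mathbb{E}_{q_\theta(w)}\!\left[\log p(y_i \mid x_i, w)\right]
  + \frac{1}{N}\,\mathrm{KL}\!\left(q_\theta(w)\,\|\,p(w)\right)
```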
(1)

```python
def negative_log_likelihood(y, rv_y):
  del rv_y  # unused arg
  return -model.output.distribution.log_prob(tf.squeeze(y))
```
shouldn't it be:
```python
def negative_log_likelihood(y, rv_y):
  del rv_y  # unused arg
  return -model.output.distribution.log_prob(tf.squeeze(y)) * n_train
```
and
(2)

```python
def kl(y_true, y_sample):
  """KL-divergence."""
  del y_true  # unused arg
  del y_sample  # unused arg
  sampling_cost = sum(
      [l.kl_cost_weight + l.kl_cost_bias for l in model.layers])
  return sum(model.losses) * n_train + sampling_cost
shouldn't it be:
```python
def kl(y_true, y_sample):
  """KL-divergence."""
  del y_true  # unused arg
  del y_sample  # unused arg
  sampling_cost = sum(
      [l.kl_cost_weight + l.kl_cost_bias for l in model.layers])
  return sum(model.losses) + sampling_cost
```
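For context, a minimal sketch of where `sum(model.losses)` comes from, assuming the variational layers follow the usual Keras pattern (illustrated here with `tfp.layers.DenseReparameterization` as a stand-in, not the repo's own layer class): each such layer registers its KL term via `add_loss`, so `model.losses` holds one, possibly already rescaled, KL tensor per layer.

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Illustrative only: DenseReparameterization stands in for the repo's
# Bayesian layers. Each variational layer adds its KL divergence to
# model.losses via add_loss, so sum(model.losses) is the total KL.
model = tf.keras.Sequential([
    tfp.layers.DenseReparameterization(10, input_shape=(784,)),
])
print(len(model.losses))  # one KL tensor per variational layer
```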
Thanks!
OK, my bad. The KL is already normalized in the LeNet definition :)
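For anyone landing here later, a minimal sketch of what such definition-time normalization can look like, again assuming `tfp.layers`-style layers (the repo's LeNet may do this differently): dividing the divergence by `n_train` inside `kernel_divergence_fn` means the loss functions above need no extra `n_train` factor.

```python
import tensorflow_probability as tfp

# A sketch, not the repo's code: rescale the per-layer KL at definition
# time by passing a divergence function that divides by the training-set
# size, so the Keras loss already sees KL / n_train.
def make_scaled_kl_fn(n_train):
  return lambda q, p, _: tfp.distributions.kl_divergence(q, p) / n_train

layer = tfp.layers.DenseReparameterization(
    10, kernel_divergence_fn=make_scaled_kl_fn(n_train=60000))
```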