google / uncertainty-baselines

High-quality implementations of standard and SOTA methods on a variety of tasks.

Question on KL rescaling!

ifiaposto opened this issue

Hello guys!

Could you please confirm that, for MNIST, the KL term is properly rescaled (according to Equation 8 in here)? To be more precise, to evenly distribute the KL loss across the data points, in def get_losses_and_metrics(model, n_train):

(1)

    def negative_log_likelihood(y, rv_y):
      del rv_y  # unused arg
      return -model.output.distribution.log_prob(tf.squeeze(y))

shouldn't it be:

    def negative_log_likelihood(y, rv_y):
      del rv_y  # unused arg
      return -model.output.distribution.log_prob(tf.squeeze(y)) * n_train

and

(2)

    def kl(y_true, y_sample):
      """KL-divergence."""
      del y_true  # unused arg
      del y_sample  # unused arg
      sampling_cost = sum(
          [l.kl_cost_weight + l.kl_cost_bias for l in model.layers])
      return sum(model.losses) * n_train + sampling_cost

shouldn't it be:

    def kl(y_true, y_sample):
      """KL-divergence."""
      del y_true  # unused arg
      del y_sample  # unused arg
      sampling_cost = sum(
          [l.kl_cost_weight + l.kl_cost_bias for l in model.layers])
      return sum(model.losses) + sampling_cost

Thanks!
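
To spell out the scaling the question assumes, here is a toy check with made-up numbers (not values from the repo): with Keras averaging the per-example NLL over the batch, multiplying the NLL by n_train while keeping the full KL is, up to an overall factor of n_train, the same objective as keeping the averaged NLL and dividing the KL by n_train, which is the per-example form of Equation 8.

    import numpy as np

    # Toy numbers, purely illustrative -- not taken from the MNIST baseline.
    n_train = 60000.0
    nll_mean = 0.35     # batch-averaged negative log-likelihood
    kl_total = 1200.0   # total KL(q || p) over all weights

    # Option (1): scale the NLL up by n_train, keep the full KL.
    loss_scaled_nll = nll_mean * n_train + kl_total
    # Option (2): keep the averaged NLL, spread the KL over the data points.
    loss_scaled_kl = nll_mean + kl_total / n_train

    # The two objectives agree up to an overall factor of n_train,
    # so they share the same minimizer.
    np.testing.assert_allclose(loss_scaled_nll, loss_scaled_kl * n_train)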

OK, my bad. The KL is already normalized in the lenet definition :)
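
For context, a minimal sketch of what "normalized in the lenet definition" typically looks like, written here with tfp.layers rather than the repo's own layer classes; the names lenet_sketch and scaled_kl and the exact layer choices are illustrative assumptions, not the baseline's actual code. Each variational layer divides its KL by n_train, so sum(model.losses) already carries the 1/N factor from Equation 8 and the loss functions above need no extra rescaling.

    import tensorflow as tf
    import tensorflow_probability as tfp

    def lenet_sketch(n_train, num_classes=10):
      """Illustrative variational LeNet-style model; not the repo's lenet."""
      # Divide each layer's KL(q || p) by the dataset size so that
      # model.losses sums to KL / n_train, i.e. the per-example KL weight.
      scaled_kl = lambda q, p, _: tfp.distributions.kl_divergence(q, p) / n_train
      return tf.keras.Sequential([
          tfp.layers.Convolution2DFlipout(
              6, kernel_size=5, activation='relu',
              kernel_divergence_fn=scaled_kl),
          tf.keras.layers.MaxPooling2D(),
          tf.keras.layers.Flatten(),
          tfp.layers.DenseFlipout(
              num_classes, kernel_divergence_fn=scaled_kl),
      ])

With the divergence scaled this way inside the model, the kl metric above presumably multiplies sum(model.losses) back by n_train only to report the total, unnormalized KL.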