KL divergence not changing during training
gkwt opened this issue
Hello,
I am trying to build a single-layer BNN using the LinearReparameterization layer. I am unable to get it to give reasonable uncertainty estimates, so I started monitoring the KL term from the layers and noticed that it does not change at all from epoch to epoch. Even when I scale up the KL term in the loss, it remains unchanged.
I am not sure if this is a bug, or if I am not doing the training correctly.
My model:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from bayesian_torch.layers import LinearReparameterization


class BNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim=1):
        super().__init__()
        # Bayesian input layer; its forward returns (output, kl)
        self.layer1 = LinearReparameterization(input_dim, hidden_dim)
        self.layerf = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        kl_sum = 0
        x, kl = self.layer1(x)
        kl_sum += kl
        x = F.relu(x)
        x = self.layerf(x)
        return x, kl_sum
```
and my training loop:
```python
from bayesian_torch.models.dnn_to_bnn import get_kl_loss

model = BNN(X_train.shape[-1], 100).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

for epoch in pbar:
    running_kld_loss = 0
    running_mse_loss = 0
    running_loss = 0
    for datapoints, labels in dataloader_train:
        datapoints, labels = datapoints.to(device), labels.to(device)

        optimizer.zero_grad()
        output, kl = model(datapoints)
        kl = get_kl_loss(model)

        # calculate loss with kl term for Bayesian layers
        mse_loss = criterion(output, labels)
        loss = mse_loss + kl * kld_beta / batch_size

        loss.backward()
        optimizer.step()

        running_mse_loss += mse_loss.item()
        running_kld_loss += kl.item()
        running_loss += loss.item()

    status.update({
        'Epoch': epoch,
        'loss': running_loss / len(dataloader_train),
        'kl': running_kld_loss / len(dataloader_train),
        'mse': running_mse_loss / len(dataloader_train)
    })
```
When I print the KL loss, it starts at ~5.0 and does not decrease at all.
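To rule out a disconnected graph, I can inspect the gradients on the Bayesian layer right after backward() on a single batch (a minimal sketch reusing the objects above; I iterate over the layer's parameters rather than assuming their names):

```python
# Sketch: one manual batch to see whether the variational parameters
# of the Bayesian layer receive any gradient at all.
datapoints, labels = next(iter(dataloader_train))
datapoints, labels = datapoints.to(device), labels.to(device)

optimizer.zero_grad()
output, kl = model(datapoints)
loss = criterion(output, labels) + kl * kld_beta / batch_size
loss.backward()

for name, param in model.layer1.named_parameters():
    grad = None if param.grad is None else param.grad.norm().item()
    print(name, grad)  # None or 0.0 here would explain a frozen KL
```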
Hi @gkwt,
You are already getting the KL value back from the model's forward pass, so the get_kl_loss call recomputes the same quantity and overwrites it. Can you try commenting it out, as below?
```python
output, kl = model(datapoints)
# kl = get_kl_loss(model)
```
The problem persists even without the get_kl_loss function. I should note that the values are the same as before: backpropagation still does not change the KL value.
I have also tried this with LinearFlipout; the KL still seems unaffected by the optimizer. After initializing the model, I also added
```python
for param in model.parameters():
    param.requires_grad = True
```
to unfreeze the layers, but it had no effect on the training: the KL remains constant.
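One more check worth running in this situation (a minimal sketch): confirm that the parameter tensors held by the optimizer are the same objects as the ones the current model exposes. If they are not, optimizer.step() is updating tensors the loss no longer depends on.

```python
# Sketch: check that the optimizer is stepping the live model's parameters.
model_params = {id(p) for p in model.parameters()}
opt_params = {id(p) for group in optimizer.param_groups for p in group["params"]}
print(opt_params == model_params)  # False -> step() updates stale tensors
```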
There was a bug in my training loop. I was overwriting the model, and so the KL was not changing. Sorry for the confusion!
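Concretely, the failure mode looks roughly like this (a sketch with illustrative setup, not my exact code):

```python
# Sketch of the failure mode: an optimizer built for one model keeps
# references to that model's parameters even if the variable is rebound.
model = BNN(input_dim, hidden_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model = BNN(input_dim, hidden_dim)  # overwrites the model; the optimizer
                                    # still holds the old parameters

# Gradients now flow into the new model's parameters, but optimizer.step()
# updates the old ones, so the KL never moves.
# Fix: (re)create the optimizer after the final assignment to `model`.
```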