IntelLabs / bayesian-torch

A library for Bayesian neural network layers and uncertainty estimation in Deep Learning extending the core of PyTorch

add mu_kernel to delta_kernel in flipout layers?

burntcobalt opened this issue

Hi,

Shouldn't mu_kernel be added to delta_kernel in the code below?

Best,
Lewis

diff --git a/bayesian_torch/layers/flipout_layers/conv_flipout.py b/bayesian_torch/layers/flipout_layers/conv_flipout.py
index 4b3e88d..719cfdc 100644
--- a/bayesian_torch/layers/flipout_layers/conv_flipout.py
+++ b/bayesian_torch/layers/flipout_layers/conv_flipout.py
@@ -165,7 +165,7 @@ class Conv1dFlipout(BaseVariationalLayer_):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
 
-        delta_kernel = (sigma_weight * eps_kernel)
+        delta_kernel = (sigma_weight * eps_kernel) + self.mu_kernel 
 
         kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
                          self.prior_weight_sigma)

Hi Lewis,

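In the Flipout method (Wen et al., 2018), the mean and perturbation components are applied in separate operations, and their outputs are then summed, as shown below (Eqn. (4) in https://arxiv.org/pdf/1803.04386.pdf). Since mu_kernel already enters through the first convolution, adding it to delta_kernel would also push the mean weights through the sign-flipped perturbation term. I hope this helps clarify your question.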
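For a dense layer, the same decomposition can be written as below. This is a sketch rather than a copy of the paper's equation. It uses PyTorch's (out, in) weight layout instead of the paper's transposed convention and omits the activation. Here \overline{W} corresponds to mu_kernel, \widehat{\Delta W} = \sigma \circ \epsilon to delta_kernel, and s_n and r_n to sign_input and sign_output. The right-hand form shows that each example effectively sees the weights \overline{W} + \widehat{\Delta W} \circ r_n s_n^\top, with the mean appearing exactly once.

```latex
y_n = \overline{W} x_n + \bigl(\widehat{\Delta W}\,(x_n \circ s_n)\bigr) \circ r_n
    = \bigl(\overline{W} + \widehat{\Delta W} \circ r_n s_n^{\top}\bigr)\, x_n
```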

outputs (mean component, computed with mu_kernel and mu_bias):

    outputs = F.conv1d(x,
                       weight=self.mu_kernel,
                       bias=self.mu_bias,
                       stride=self.stride,
                       padding=self.padding,
                       dilation=self.dilation,
                       groups=self.groups)

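perturbed_outputs (perturbation component, computed with delta_kernel on sign-flipped inputs, then sign-flipped again on the output):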

    perturbed_outputs = F.conv1d(x * sign_input,
                                 bias=bias,
                                 weight=delta_kernel,
                                 stride=self.stride,
                                 padding=self.padding,
                                 dilation=self.dilation,
                                 groups=self.groups) * sign_output

The layer returns the sum, outputs + perturbed_outputs, along with the KL term:

    return outputs + perturbed_outputs, kl
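The two-term forward pass can be checked numerically. The sketch below uses a dense analogue (F.linear instead of F.conv1d) with hypothetical sizes and omits the bias terms. It also draws the ±1 signs with torch.randint, which is an assumption, since the snippets above do not show how sign_input and sign_output are drawn. It confirms that "mean term + sign-flipped perturbation term" equals applying per-example weights mu + delta * outer(r_n, s_n), with delta holding only sigma * eps.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, in_features, out_features = 4, 3, 2  # hypothetical sizes

# Variational parameters for a dense analogue of the flipout layer
mu_weight = torch.randn(out_features, in_features)
rho_weight = torch.randn(out_features, in_features)
sigma_weight = torch.log1p(torch.exp(rho_weight))  # softplus, as in the snippet
eps_weight = torch.randn(out_features, in_features)
delta_weight = sigma_weight * eps_weight  # perturbation only, mu is NOT added

x = torch.randn(batch, in_features)

# One random +/-1 sign vector per example on the inputs and on the outputs
sign_input = torch.randint(0, 2, (batch, in_features)).float() * 2 - 1
sign_output = torch.randint(0, 2, (batch, out_features)).float() * 2 - 1

# Flipout forward pass: mean term + sign-flipped perturbation term
outputs = F.linear(x, mu_weight)
perturbed_outputs = F.linear(x * sign_input, delta_weight) * sign_output
flipout = outputs + perturbed_outputs

# Reference: explicit per-example weights W_n = mu + delta * outer(r_n, s_n)
reference = torch.stack([
    F.linear(x[n], mu_weight + delta_weight * torch.outer(sign_output[n], sign_input[n]))
    for n in range(batch)
])

print(torch.allclose(flipout, reference, atol=1e-6))  # True
```

The design point is that each example gets its own pseudo-independent weight sample from only two batched matrix products. The mean enters once, through the first product.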

Best,
Ranganath

Ah, I see. Thank you, Ranganath.