Kernel_size
flydephone opened this issue · comments
The Conv2dReparameterization layer only allows square kernels (e.g., 2×2). However, some CNN models use non-square kernels (e.g., in inceptionresnetv2, the kernel size in block17 is 1×7).
So I modified the code in conv_variational.py from:
```python
self.mu_kernel = Parameter(
    torch.Tensor(out_channels, in_channels // groups, kernel_size,
                 kernel_size))
self.rho_kernel = Parameter(
    torch.Tensor(out_channels, in_channels // groups, kernel_size,
                 kernel_size))
self.register_buffer(
    'eps_kernel',
    torch.Tensor(out_channels, in_channels // groups, kernel_size,
                 kernel_size),
    persistent=False)
self.register_buffer(
    'prior_weight_mu',
    torch.Tensor(out_channels, in_channels // groups, kernel_size,
                 kernel_size),
    persistent=False)
self.register_buffer(
    'prior_weight_sigma',
    torch.Tensor(out_channels, in_channels // groups, kernel_size,
                 kernel_size),
    persistent=False)
```
to:
```python
self.mu_kernel = Parameter(
    torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
                 kernel_size[1]))
self.rho_kernel = Parameter(
    torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
                 kernel_size[1]))
self.register_buffer(
    'eps_kernel',
    torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
                 kernel_size[1]),
    persistent=False)
self.register_buffer(
    'prior_weight_mu',
    torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
                 kernel_size[1]),
    persistent=False)
self.register_buffer(
    'prior_weight_sigma',
    torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
                 kernel_size[1]),
    persistent=False)
```
Also, `kernel_size=d.kernel_size[0]` was changed to `kernel_size=d.kernel_size` in dnn_to_cnn.py.
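After this change, the full kernel tuple from the deterministic layer is forwarded, so all five weight tensors (`mu_kernel`, `rho_kernel`, `eps_kernel`, `prior_weight_mu`, `prior_weight_sigma`) take the standard Conv2d weight shape. A minimal sketch of that shape computation (the helper name and channel counts here are hypothetical, chosen only for illustration):

```python
def bayesian_conv2d_weight_shape(in_channels, out_channels, kernel_size, groups=1):
    """Hypothetical helper: the shape shared by mu_kernel, rho_kernel,
    eps_kernel, prior_weight_mu and prior_weight_sigma after the fix."""
    kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    return (out_channels, in_channels // groups, kh, kw)

# A non-square 1x7 kernel now yields the expected 4-D weight shape:
print(bayesian_conv2d_weight_shape(192, 128, (1, 7)))  # (128, 192, 1, 7)
# Square kernels passed as an int are unaffected:
print(bayesian_conv2d_weight_shape(64, 64, 3))         # (64, 64, 3, 3)
```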
@flydephone Thank you for creating the issue and your contribution to Bayesian-Torch. All the convolutional layers (Conv2dReparameterization, Conv3dReparameterization, ConvTranspose2dReparameterization, ConvTranspose3dReparameterization, Conv2dFlipout, Conv3dFlipout, ConvTranspose2dFlipout, ConvTranspose3dFlipout) are now enabled to support arbitrary kernel sizes.
Commit a8543ad (@msubedar) fixes this issue.