About the structure of ConvGuidedFilter
nanjingzhouyu opened this issue
Hi, I'm having trouble understanding the structure of ConvGuidedFilter. Could you please tell me which part of this code corresponds to the "Dilated Conv" and which part corresponds to the "pointwise convolution Block"?
And the transformation function F(I) is “conv_a” in the code, right?
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvGuidedFilter(nn.Module):
    def __init__(self, radius=1, norm=nn.BatchNorm2d):  # BatchNorm normalization by default
        super(ConvGuidedFilter, self).__init__()

        # Depthwise 3x3 conv with dilation=radius, used as a trainable box filter
        self.box_filter = nn.Conv2d(3, 3, kernel_size=3, padding=radius, dilation=radius, bias=False, groups=3)
        # Stack of 1x1 (pointwise) convs that predicts A from [cov_xy, var_x]
        self.conv_a = nn.Sequential(nn.Conv2d(6, 32, kernel_size=1, bias=False),
                                    norm(32),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(32, 32, kernel_size=1, bias=False),
                                    norm(32),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(32, 3, kernel_size=1, bias=False))
        # Start the box filter as a plain window sum (all weights = 1)
        self.box_filter.weight.data[...] = 1.0

    def forward(self, x_lr, y_lr, x_hr):
        _, _, h_lrx, w_lrx = x_lr.size()
        _, _, h_hrx, w_hrx = x_hr.size()

        # N: per-window normalizer; with all-ones weights it is the number of
        # valid pixels in each window (smaller near the borders)
        N = self.box_filter(x_lr.new_ones((1, 3, h_lrx, w_lrx)))
        ## mean_x
        mean_x = self.box_filter(x_lr) / N
        ## mean_y
        mean_y = self.box_filter(y_lr) / N
        ## cov_xy
        cov_xy = self.box_filter(x_lr * y_lr) / N - mean_x * mean_y
        ## var_x
        var_x = self.box_filter(x_lr * x_lr) / N - mean_x * mean_x

        ## A
        A = self.conv_a(torch.cat([cov_xy, var_x], dim=1))
        ## b
        b = mean_y - A * mean_x

        ## mean_A; mean_b: upsample the low-resolution coefficients to the high-resolution size
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)

        # mean_A = A_h, mean_b = b_h (high-resolution coefficients)
        return mean_A * x_hr + mean_b
```
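For comparison, the classic (non-learnable) fast guided filter computes A in closed form as cov_xy / (var_x + eps) and box-averages A and b before upsampling; ConvGuidedFilter replaces that division with the learned `conv_a` and upsamples A and b directly. The following is a minimal PyTorch sketch of the classic formula, under stated assumptions: `box_mean` and `fast_guided_filter` are hypothetical helper names, `eps` is the usual regularizer, and the window is a plain (2r+1)×(2r+1) box, which matches the 3×3 dilated `box_filter` only when radius=1.

```python
import torch
import torch.nn.functional as F


def box_mean(x, r):
    """Hypothetical helper: mean over a (2r+1)x(2r+1) window, divided by the
    number of valid pixels so borders are handled like N in ConvGuidedFilter."""
    c = x.shape[1]
    k = 2 * r + 1
    weight = torch.ones(c, 1, k, k, dtype=x.dtype, device=x.device)
    window_sum = F.conv2d(x, weight, padding=r, groups=c)
    count = F.conv2d(torch.ones_like(x), weight, padding=r, groups=c)
    return window_sum / count


def fast_guided_filter(x_lr, y_lr, x_hr, r=1, eps=1e-2):
    """Hypothetical helper: classic fast guided filter with closed-form A."""
    mean_x = box_mean(x_lr, r)
    mean_y = box_mean(y_lr, r)
    cov_xy = box_mean(x_lr * y_lr, r) - mean_x * mean_y
    var_x = box_mean(x_lr * x_lr, r) - mean_x * mean_x

    # Closed-form slope; ConvGuidedFilter learns this mapping with conv_a instead
    A = cov_xy / (var_x + eps)
    b = mean_y - A * mean_x

    # The classic filter averages A and b over windows before upsampling
    mean_A = box_mean(A, r)
    mean_b = box_mean(b, r)

    h, w = x_hr.shape[-2:]
    A_hr = F.interpolate(mean_A, (h, w), mode='bilinear', align_corners=True)
    b_hr = F.interpolate(mean_b, (h, w), mode='bilinear', align_corners=True)
    return A_hr * x_hr + b_hr
```

If `y_lr` is an exact affine function of `x_lr`, cov_xy / var_x recovers the slope everywhere, so with a tiny `eps` the output reproduces that same affine map applied to `x_hr`.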
Dilated Conv is `box_filter`; pointwise convolution Block is `conv_a`; F(I) is implemented here.
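To make the mapping in the reply concrete, the sketch below checks two properties (the helper names `window_counts` and `pointwise_matches_einsum` are hypothetical). First, the `box_filter` configuration, with its weights fixed to 1, counts how many taps of a 3×3 kernel with dilation=radius land inside the image, which is what N holds at initialization. Second, a 1×1 conv like those in `conv_a` only mixes channels independently at each pixel, so it equals a per-pixel matrix multiply.

```python
import torch
import torch.nn as nn


def window_counts(radius, h, w):
    """Hypothetical helper: per-pixel count of 3x3 dilated taps inside an h x w image."""
    # Same configuration as box_filter in ConvGuidedFilter, weights fixed to 1
    box_filter = nn.Conv2d(3, 3, kernel_size=3, padding=radius, dilation=radius, bias=False, groups=3)
    with torch.no_grad():
        box_filter.weight.fill_(1.0)
        return box_filter(torch.ones(1, 3, h, w))


def pointwise_matches_einsum():
    """Hypothetical helper: a 1x1 conv is a per-pixel linear map across channels."""
    pointwise = nn.Conv2d(6, 32, kernel_size=1, bias=False)
    feat = torch.randn(2, 6, 4, 4)
    with torch.no_grad():
        w = pointwise.weight[:, :, 0, 0]  # shape (32, 6)
        manual = torch.einsum('oc,bchw->bohw', w, feat)
        return torch.allclose(pointwise(feat), manual, atol=1e-5)


N = window_counts(radius=2, h=5, w=5)
print(N[0, 0])
print(pointwise_matches_einsum())
```

For radius=2 on a 5×5 image, the center pixel sees all 9 taps while the corners see only 4, which is why dividing by N rather than by 9 gives correct means at the borders.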
Thank you so much for the reply.