wuhuikai / DeepGuidedFilter

Official Implementation of Fast End-to-End Trainable Guided Filter, CVPR 2018

About the structure of ConvGuidedFilter

nanjingzhouyu opened this issue

Hi, I'm confused when I try to understand the structure of ConvGuidedFilter. Could you please tell me which part of this code is the "Dilated Conv" and which part is the "pointwise convolution block"?
And the transformation function F(I) is “conv_a” in the code, right?

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvGuidedFilter(nn.Module):
    def __init__(self, radius=1, norm=nn.BatchNorm2d):  # norm: normalization layer (BatchNorm by default)
        super(ConvGuidedFilter, self).__init__()

        self.box_filter = nn.Conv2d(3, 3, kernel_size=3, padding=radius, dilation=radius, bias=False, groups=3)
        self.conv_a = nn.Sequential(nn.Conv2d(6, 32, kernel_size=1, bias=False),
                                    norm(32),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(32, 32, kernel_size=1, bias=False),
                                    norm(32),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(32, 3, kernel_size=1, bias=False))
        self.box_filter.weight.data[...] = 1.0

    def forward(self, x_lr, y_lr, x_hr):
        _, _, h_lrx, w_lrx = x_lr.size()
        _, _, h_hrx, w_hrx = x_hr.size()

        # N: number of pixels covered by each window (fewer at the image borders)
        N = self.box_filter(torch.ones((1, 3, h_lrx, w_lrx), dtype=x_lr.dtype, device=x_lr.device))
        ## mean_x
        mean_x = self.box_filter(x_lr)/N
        ## mean_y
        mean_y = self.box_filter(y_lr)/N
        ## cov_xy
        cov_xy = self.box_filter(x_lr * y_lr)/N - mean_x * mean_y
        ## var_x
        var_x  = self.box_filter(x_lr * x_lr)/N - mean_x * mean_x

        ## A
        A = self.conv_a(torch.cat([cov_xy, var_x], dim=1))
        ## b
        b = mean_y - A * mean_x

        ## mean_A; mean_b
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)  # upsample
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)  # upsample

        return mean_A * x_hr + mean_b  # mean_A = A_h, mean_b = b_h
```
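For reference, the forward pass above can be read as the following equations. Here f_box and N are the box filter and normalizer from the code, g stands for conv_a, and ↑ is the bilinear F.interpolate call. The first A_l line is the classic, non-learned guided filter and is shown only for comparison; its regularizer ε does not appear in ConvGuidedFilter, where that closed-form ratio is replaced by g applied to the concatenation of cov_xy and var_x.

```latex
\begin{aligned}
\mu_x &= \frac{f_{\mathrm{box}}(x_l)}{N}, \qquad
\mu_y = \frac{f_{\mathrm{box}}(y_l)}{N} \\
\operatorname{cov}_{xy} &= \frac{f_{\mathrm{box}}(x_l \odot y_l)}{N} - \mu_x \odot \mu_y, \qquad
\sigma_x^2 = \frac{f_{\mathrm{box}}(x_l \odot x_l)}{N} - \mu_x \odot \mu_x \\
A_l &= \frac{\operatorname{cov}_{xy}}{\sigma_x^2 + \epsilon} \quad \text{(classic guided filter)} \\
A_l &= g\big([\operatorname{cov}_{xy}, \sigma_x^2]\big) \quad \text{(ConvGuidedFilter, } g = \texttt{conv\_a}\text{)} \\
b_l &= \mu_y - A_l \odot \mu_x \\
O_h &= {\uparrow}(A_l) \odot x_h + {\uparrow}(b_l)
\end{aligned}
```

One difference worth noticing: the classic guided filter averages the per-window coefficients over all windows covering a pixel before applying them, whereas the quoted ConvGuidedFilter upsamples A and b directly, so mean_A and mean_b in the code are upsampled values rather than box-filtered ones.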

The dilated conv is box_filter; the pointwise convolution block is conv_a; F(I) is implemented here.
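To see concretely why box_filter plays the role of the dilated conv and conv_a the pointwise block, here is a minimal sketch, assuming a recent PyTorch install. It rebuilds both layers with the same constructor arguments as the quoted class and checks them against explicit computations. The radius of 2, the 8x8 random inputs, and the helper dilated_mean are illustration choices, not part of the repository.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
radius = 2  # illustration value; the class default is radius=1

# Same arguments as ConvGuidedFilter.box_filter: a depthwise (groups=3) 3x3 conv whose
# taps are `radius` pixels apart (dilation) and whose padding keeps H x W unchanged.
box_filter = nn.Conv2d(3, 3, kernel_size=3, padding=radius, dilation=radius, bias=False, groups=3)
box_filter.weight.data[...] = 1.0  # all-ones weights turn the conv into a dilated box sum

# Same shape as the first layer of conv_a: a 1x1 ("pointwise") conv from 6 to 32 channels.
pointwise = nn.Conv2d(6, 32, kernel_size=1, bias=False)

x = torch.rand(1, 3, 8, 8)
feat = torch.rand(1, 6, 8, 8)

with torch.no_grad():
    N = box_filter(torch.ones(1, 3, 8, 8))  # in-bounds taps per pixel: 9 inside, fewer at borders
    mean_x = box_filter(x) / N               # local mean, as computed in forward()
    out = pointwise(feat)
    # A 1x1 conv applies the same 32x6 matrix independently at every pixel.
    ref = torch.einsum('oc,bchw->bohw', pointwise.weight[:, :, 0, 0], feat)


def dilated_mean(img, r):
    """Hypothetical reference helper: average the in-bounds pixels at offsets {-r, 0, r}."""
    _, _, h, w = img.shape
    result = torch.zeros_like(img)
    for i in range(h):
        for j in range(w):
            taps = [img[:, :, i + di, j + dj]
                    for di in (-r, 0, r) for dj in (-r, 0, r)
                    if 0 <= i + di < h and 0 <= j + dj < w]
            result[:, :, i, j] = torch.stack(taps).mean(dim=0)
    return result


print(torch.allclose(mean_x, dilated_mean(x, radius), atol=1e-6))  # True
print(torch.allclose(out, ref, atol=1e-6))  # True
```

So box_filter only computes local sums for the mean, variance, and covariance statistics, while all learned mixing across channels happens in the 1x1 layers of conv_a. Note that the quoted __init__ only initializes the box filter weights to 1.0; it does not set requires_grad=False, so they remain ordinary trainable parameters.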

Thank you so much for the reply.