naoto0804 / pytorch-inpainting-with-partial-conv

Unofficial pytorch implementation of 'Image Inpainting for Irregular Holes Using Partial Convolutions' [Liu+, ECCV2018]

Is there a wrong understanding in total variation?

GuardSkill opened this issue · comments

I find this does not conform to the original paper's method. I think the sum of the absolute values should go into the TV loss, and the TV loss is not the global difference over the whole picture; it only covers the area around the holes (P is the region of the 1-pixel dilation of the hole region).

def total_variation_loss(image):
    # shift one pixel and get difference (for both x and y direction)
    loss = torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \
        torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))
    return loss

Maybe it should be this?
def total_variation_loss(image, mask):
    hole_mask = 1 - mask  # 1 inside the holes, 0 on valid pixels
    loss = torch.sum(torch.abs(hole_mask[:, :, :, :-1] * (image[:, :, :, 1:] - image[:, :, :, :-1]))) + \
        torch.sum(torch.abs(hole_mask[:, :, :-1, :] * (image[:, :, 1:, :] - image[:, :, :-1, :])))
    return loss

More precisely, it should be the following code rather than the above (the code above does not account for the difference terms at the topmost/leftmost pixels of the dilated hole region):

import torch
import torch.nn as nn

def dilation_holes(hole_mask):
    # Dilate the binary hole mask by one pixel: convolve with an all-ones
    # 3x3 kernel, then mark every pixel whose neighbourhood touched a hole.
    b, ch, h, w = hole_mask.shape
    dilation_conv = nn.Conv2d(ch, ch, 3, padding=1, bias=False).to(hole_mask.device)
    torch.nn.init.constant_(dilation_conv.weight, 1.0)
    with torch.no_grad():
        output_mask = dilation_conv(hole_mask)
    updated_holes = output_mask != 0
    return updated_holes.float()
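
As an aside, the same one-pixel dilation can be obtained without constructing a convolution layer. A minimal alternative sketch, assuming hole_mask is a binary (0/1) tensor (dilation_holes_pool is a hypothetical name, not from the repo):

import torch.nn.functional as F

def dilation_holes_pool(hole_mask):
    # 3x3 max pooling of a binary mask is exactly a one-pixel dilation:
    # a pixel becomes 1 if anything in its 3x3 neighbourhood is 1.
    return F.max_pool2d(hole_mask, kernel_size=3, stride=1, padding=1)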

def total_variation_loss(image, mask):
    hole_mask = 1 - mask
    dilated_holes = dilation_holes(hole_mask)
    # keep only the pixel pairs whose members both lie in the dilated hole region P
    columns_in_Pset = dilated_holes[:, :, :, 1:] * dilated_holes[:, :, :, :-1]
    rows_in_Pset = dilated_holes[:, :, 1:, :] * dilated_holes[:, :, :-1, :]
    loss = torch.sum(torch.abs(columns_in_Pset * (image[:, :, :, 1:] - image[:, :, :, :-1]))) + \
        torch.sum(torch.abs(rows_in_Pset * (image[:, :, 1:, :] - image[:, :, :-1, :])))
    return loss
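
A quick smoke test of the corrected function; the shapes and hole position below are made up, and the mask follows the convention used in the snippets above (1 on valid pixels, 0 inside the hole):

image = torch.rand(2, 3, 64, 64)
mask = torch.ones(2, 3, 64, 64)
mask[:, :, 20:40, 20:40] = 0   # square hole
print(total_variation_loss(image, mask))  # scalar tensor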

Have you tried the code that you think it should be? Has it brought any improvement to the results compared to the GitHub author's?

Hi!
@GuardSkill shouldn't it be mean instead of sum in the total_variation_loss function?

loss = torch.mean(torch.abs(columns_in_Pset * (image[:, :, :, 1:] - image[:, :, :, :-1]))) + \
    torch.mean(torch.abs(rows_in_Pset * (image[:, :, 1:, :] - image[:, :, :-1, :])))
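
For what it's worth, mean and sum differ only by a constant factor (the number of elements), which the TV term's weight in the total loss can absorb; note, though, that torch.mean here divides by the count of all pixel pairs, not just the pairs inside P. A minimal sketch of the equivalence, reusing the intermediate tensors from the function above for illustration:

diff = image[:, :, :, 1:] - image[:, :, :, :-1]
masked = torch.abs(columns_in_Pset * diff)
# mean == sum / numel, so switching between them only rescales the gradient
assert torch.allclose(torch.mean(masked), torch.sum(masked) / masked.numel())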

I would argue that, while not exactly the same loss as the one proposed in the paper (L_tv), the total_variation_loss() implemented here should behave in just the same way.

Both total_variation_loss() and L_tv are computed on I_comp (output_comp in the code), not on I_out (output). I_comp contains:

  • the ground truth image I_gt (input) in the unmasked part
  • the reconstructed image I_out in the masked part

Since the ground truth image does not change with I_out, all 1-pixel shifts outside of the masked region always contribute the same total variation. Inside the masked region (as well as along its 1-pixel dilation) the TV loss instead depends on I_out.

As such, the loss implemented here is L_tv + constant: the two functions thus share the same gradient.
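
That argument is easy to check numerically. A minimal sketch (shapes and names are made up), showing that the gradient of a global, sum-based TV of I_comp with respect to I_out vanishes on all valid pixels, exactly as a hole-restricted loss would give:

import torch

def tv_sum(image):
    return torch.sum(torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1])) + \
           torch.sum(torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :]))

gt = torch.rand(1, 3, 8, 8)                       # stands in for I_gt
mask = torch.ones(1, 3, 8, 8)
mask[:, :, 2:5, 2:5] = 0                          # hole (mask == 0)
out = torch.rand(1, 3, 8, 8, requires_grad=True)  # stands in for I_out
comp = mask * gt + (1 - mask) * out               # I_comp

grad, = torch.autograd.grad(tv_sum(comp), out)
# The ground-truth part of I_comp contributes only a constant to the loss:
# the gradient w.r.t. I_out is zero on all valid (mask == 1) pixels.
assert torch.all(grad[mask.bool()] == 0)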

It also seems to me that the current implementation is slightly more efficient, as it requires neither computing the dilated mask nor masking the image.