imlixinyang / HiSD

Official pytorch implementation of paper "Image-to-image Translation via Hierarchical Style Disentanglement" (CVPR 2021 Oral).

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

作者您好,想请问一下关于discriminator的问题

zhushuqi2333 opened this issue · comments

` class Dis(nn.Module):
def init(self, hyperparameters):
super().init()
self.tags = hyperparameters['tags']
channels = hyperparameters['discriminators']['channels']
#[64, 128, 256, 512, 1024, 2048]
self.conv = nn.Sequential(
nn.Conv2d(hyperparameters['input_dim'], channels[0], 1, 1, 0),
*[DownBlock(channels[i], channels[i + 1]) for i in range(len(channels) - 1)],
nn.AdaptiveAvgPool2d(1),
)
self.fcs = nn.ModuleList([nn.Sequential(
nn.Conv2d(channels[-1] + #2048
# ALI part which is not shown in the original submission but help disentangle the extracted style.
#ALI部分未在原始提交中显示,但有助于解耦提取到的 style。
hyperparameters['style_dim'] + #256
# Tag-irrelevant part. Sec.3.4
self.tags[i]['tag_irrelevant_conditions_dim'], #2 2 2
# One for translated, one for cycle. Eq.4
len(self.tags[i]['attributes'] * 2), 1, 1, 0), #4 4 6
) for i in range(len(self.tags))]) #这里的i控制的是三个tag里面的哪个

def forward(self, x, s, y, i):
    f = self.conv(x)
    fsy = torch.cat([f, tile_like(s, f), tile_like(y, f)], 1)
    #按照第一维度,也就是列维度,叠加起来,也就是横着串起来
    return self.fcs[i](fsy).view(f.size(0), 2, -1) `

作者你好,关于判别器我有几个不太懂的点,还希望您可以教教我

  1. 对于判别器是怎么不去改变两个无关标签我不是很理解这其中的过程

  2. 判别器的forward那边最后的.view(f.size(0), 2, -1),第一维是batch_size,第二维我不懂是什么,为啥是2,第三维是控制的属性吗,这边看不太懂

  3. 关于计算生成器的对抗损失这边,为什么真实图片取的[:,0]和[:,1]的平均之和,而两张fake图片分别取的[:,0]和[:,1]的平均?这边不太理解。代码如下:
    ` def calc_gen_loss_real(self, x, s, y, i, j):#
    loss = 0
    out = self.forward(x, s, y, i)[:, :, j]#选到那个属性
    #比如是[8, 2, 2], 截取[:,:,1] 就变成了[8, 2]了
    loss += out[:, 0].mean()
    loss += out[:, 1].mean()
    return loss

    def calc_gen_loss_fake_trg(self, x, s, y, i, j):
    out = self.forward(x, s, y, i)[:, :, j]
    loss = - out[:, 0].mean()
    return loss

    def calc_gen_loss_fake_cyc(self, x, s, y, i, j):
    out = self.forward(x, s, y, i)[:, :, j]
    loss = - out[:, 1].mean()
    return loss `
    希望您可以解答我的疑惑,谢谢作者!

  1. 关于无关条件对于解耦的帮助:鉴别器可以看到无关标签,意味着鉴别器可以进一步区分什么样的图像才符合我们的目标(例如不是男的就更像戴眼镜,而是眼镜这个特征本身让图像更像戴眼镜),也就可以促使生成器的解耦了。
  2. 鉴别器输出的第二维代表着是鉴别cycle translation过程中的翻译后图像还是cycle回来后的图像(因此是两维,因为来自于不同阶段的输出,有着不同的鉴别难度),相对来说,前者要更难一点,而后者只要重构就可以了。
  3. 与2同样,理解了2就可以理解为什么真实图片都有,但假图片是分开的了。

非常感谢您的耐心解答!关于第2,3两点,我想我大概是有点明白了,在这里复述一下,希望您可以看一下我说的有没有什么错误。
因为真实图片,既不是翻译后的,也不是cycle后的,所以要把第二维的两列数据都加起来求平均,让其最小
而翻译后的图片,只需要关注第二维度的[:, 0],把第0列数据的平均值求出来,让其得分最大,能骗过鉴别器,cycle后的图像同理。

关于无条件对于解耦的帮助,我想我可能还是有点不太明白,会不会是我预处理数据的时候出错了,我用8张图片做了个测试,您看一下,我的y值是对的吗?
image
期待您的回复!

“因为真实图片,既不是翻译后的,也不是cycle后的,所以要把第二维的两列数据都加起来求平均,让其最小”
这里有两个小错误。真实样本实际上即是翻译后假样本对应的真样本,也是cycle后假样本对应的真样本。且这里是一个先平均再加和,因为这两维实际上是完全独立的,你可以遮住一个一个看。
你的y值从格式上是对的,从原理上来说,任何你认为在翻译某个tag的时候不应该改变的图像标签都可以作为tag-无关标签丢给鉴别器(比如在改变Tag “eyeglasses”的时候引入标签“young”和“male”两个标签),它的目的是在翻译前者的时候抑制后者的变化。

关于loss的计算上应该是懂了,太感谢您的耐心指导了!!带给了我很大的帮助!!
关于鉴别器部分,从原理上我大概是明白了,输入一张图片和给定tag-无关标签,然后抽象起来看,鉴别器执行大致是以下步骤
图和条件匹配->高分
图和条件不匹配->低分
也就是学习戴眼镜的过程中,不去转变性别和年龄,也就是抑制无关标签的变化。
但是对于代码的实现我还是不太明白,这是我生成的eyeglasses_with.txt文件,无关标签是一张图片对应的male和young真实值,但是鉴别器看到的只是 [-1,1]两个数字,想请教一下,鉴别器是怎么知道这两个数字具体对应哪两个属性的,期待您的回复!
image

不客气,一般来说,如果按照原始脚本的话,“-1 1”代表的就是male和young的属性,例如第一个数是-1即代表这是女性。这个具体的你可以参考这里。其中21和40就代表着相应的标签,你也可以自行进行更改或增减哈。

好的,我大概是明白了,谢谢您的耐心回复(~ ̄▽ ̄)~