About the G loss.
zeakey opened this issue · comments
I am confused about the G loss.
In the WassersteinGAN code (https://github.com/martinarjovsky/WassersteinGAN/blob/f81eafd2aa41e93698f203732f8f395abc70be02/main.py#L212), the author uses
errG = netD(fake)
where fake = netG(z).
However, in your implementation, the G loss is
gen_loss = -discriminator(generator(z)).mean()
Theoretically, I believe the G loss should be -D(G(z)), because G is expected to be able to 'cheat' D.
I think these are the same, except for a choice of sign. In that WGAN code, we have the discriminator loss
errD = errD_real - errD_fake
and the generator loss
errG = netD(fake)
whereas in this code we use
disc_loss = -discriminator(data).mean() + discriminator(generator(z)).mean()
and
gen_loss = -discriminator(generator(z)).mean()
The sign is flipped in both loss functions, but the overall effect is the same.
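It's important to note that WGAN does not use a cross-entropy loss: the critic output enters both losses linearly, with no sigmoid or log. Flipping the sign of both losses is therefore the same as training the critic -f in place of f, so the two formulations are equivalent as long as the flip is applied consistently to the discriminator and generator losses.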
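Here is a minimal sketch of why the two conventions match, written in plain Python rather than PyTorch so it runs without any dependencies. The toy linear critic, the toy generator, and the helper names (make_critic, losses_this_repo, losses_wgan_repo) are hypothetical and made up for illustration; only the four loss expressions come from the snippets quoted above. It checks that the WGAN-repo losses evaluated on the negated critic -f give the same numbers as this repo's losses evaluated on f.

```python
import random


def mean(values):
    values = list(values)
    return sum(values) / len(values)


def make_critic(w, b):
    # Hypothetical toy critic f(x) = w * x + b, standing in for netD / discriminator.
    return lambda x: w * x + b


def generator(z):
    # Hypothetical toy generator, standing in for netG / generator.
    return 0.5 * z + 1.0


def losses_this_repo(critic, data, noise):
    # Convention in this repo: the critic is pushed to be large on real data.
    fake = [generator(z) for z in noise]
    disc_loss = -mean(critic(x) for x in data) + mean(critic(x) for x in fake)
    gen_loss = -mean(critic(x) for x in fake)
    return disc_loss, gen_loss


def losses_wgan_repo(critic, data, noise):
    # Convention quoted from the linked main.py: errD = errD_real - errD_fake,
    # errG = netD(fake), so the critic is pushed to be small on real data.
    fake = [generator(z) for z in noise]
    err_d = mean(critic(x) for x in data) - mean(critic(x) for x in fake)
    err_g = mean(critic(x) for x in fake)
    return err_d, err_g


random.seed(0)
data = [random.gauss(3.0, 1.0) for _ in range(64)]   # stand-in "real" samples
noise = [random.gauss(0.0, 1.0) for _ in range(64)]  # latent samples z

f = make_critic(w=0.7, b=-0.2)


def neg_f(x):
    # The same critic with its output sign flipped.
    return -f(x)


ours = losses_this_repo(f, data, noise)
theirs = losses_wgan_repo(neg_f, data, noise)

# Both pairs agree, so the two setups describe the same training problem.
print(ours)
print(theirs)
```

The practical difference is only in how the critic output reads: with this repo's signs a higher score means "more real", while with the other convention a lower score does.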
Yes, I understand what you mean.
However, intuitively I think your formulation is more straightforward: f(x) is larger for true images and smaller for generated images.