akanimax / pro_gan_pytorch

Unofficial PyTorch implementation of the paper titled "Progressive growing of GANs for improved Quality, Stability, and Variation"



WGAN_GP Discriminator loss

ArtjomUEA opened this issue

For WGAN_GP, the Discriminator's loss is calculated as follows:
loss = (th.mean(fake_out) - th.mean(real_out) + (self.drift * th.mean(real_out ** 2)))
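Written out as a formula, the line above computes the following. In this notation D is the discriminator, x_i are the N real images and x̃_i the N generated images of a minibatch, and ε_drift stands for `self.drift`. The sketch assumes equal real and fake batch sizes and that `th.mean` averages over all discriminator outputs in the batch:

```latex
\mathcal{L}_D = \frac{1}{N}\sum_{i=1}^{N} D(\tilde{x}_i)
              - \frac{1}{N}\sum_{i=1}^{N} D(x_i)
              + \epsilon_{\text{drift}} \, \frac{1}{N}\sum_{i=1}^{N} D(x_i)^2
```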

Could you please explain what the last term (self.drift * th.mean(real_out ** 2)) does and where it comes from? I could not find any information about it in either the paper or the presentation. Thank you in advance!

@ArtjomUEA,

Sure. The term you highlighted is the drift penalty introduced (or rather, used) in the paper. Please refer to the last few lines of the last paragraph of Section A.1 in the paper's supplementary material. It says:

Additionally, we introduce a fourth term into the discriminator loss with an extremely small weight to keep the discriminator output from drifting too far away from zero ...
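To see why such a term helps, note that the Wasserstein part of the quoted loss, th.mean(fake_out) - th.mean(real_out), does not change if the same constant is added to every discriminator output. A gradient penalty is blind to such a shift as well, because adding a constant does not change the gradient. So, apart from the drift term, nothing in that loss pins down the absolute level of the outputs. The plain-Python sketch below illustrates this. It uses only the standard library, with lists standing in for the `fake_out` / `real_out` tensors. The function names, the sample scores, and the 0.001 weight are illustrative assumptions, not the repository's code. Shifting all scores by 100 leaves the Wasserstein term unchanged, while the drift term grows sharply:

```python
import math
from statistics import fmean
from typing import Sequence


def wasserstein_term(fake_out: Sequence[float], real_out: Sequence[float]) -> float:
    # Counterpart of th.mean(fake_out) - th.mean(real_out)
    return fmean(fake_out) - fmean(real_out)


def drift_term(real_out: Sequence[float], drift: float) -> float:
    # Counterpart of self.drift * th.mean(real_out ** 2): penalizes real-sample
    # scores for being far from zero, in either direction
    return drift * fmean([r * r for r in real_out])


def dis_loss(fake_out: Sequence[float], real_out: Sequence[float],
             drift: float = 0.001) -> float:
    # Quoted discriminator loss with the gradient penalty omitted.
    # drift=0.001 is an assumed illustrative weight, not read from the repository.
    return wasserstein_term(fake_out, real_out) + drift_term(real_out, drift)


# Hypothetical discriminator scores for 3 generated and 3 real images
fake = [-0.5, 0.2, 0.1]
real = [1.0, 0.8, 1.2]

# Shift every score by the same constant, as an unanchored critic could do
shifted_fake = [x + 100.0 for x in fake]
shifted_real = [x + 100.0 for x in real]

# The Wasserstein term cannot tell the two cases apart ...
print(math.isclose(wasserstein_term(fake, real),
                   wasserstein_term(shifted_fake, shifted_real)))
# ... but the drift term grows roughly 10,000-fold, so the full loss
# favors the unshifted, near-zero outputs
print(drift_term(real, 0.001), drift_term(shifted_real, 0.001))
print(dis_loss(fake, real), dis_loss(shifted_fake, shifted_real))
```

Because the weight is tiny, the drift term barely affects the loss while the outputs stay near zero. It only becomes significant once they wander far away.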

My implementation is a complete reproduction of all the techniques described in the paper. Please feel free to ask if you have any more questions. Suggestions, feedback, and contributions are highly welcome.

Best regards,
@akanimax

@akanimax,

Found it, thank you very much!