WGAN_GP Discriminator loss
ArtjomUEA opened this issue
For WGAN_GP, the Discriminator's loss is calculated as follows:
loss = (th.mean(fake_out) - th.mean(real_out) + (self.drift * th.mean(real_out ** 2)))
Could you please explain what the last term, (self.drift * th.mean(real_out ** 2)), does and where it comes from? I could not find any information about it in either the paper or the presentation. Thank you in advance!
Sure. The term you highlighted is the drift penalty introduced (or rather used) in the paper. Please refer to the last few lines of the last paragraph of Section A.1 in the supplementary material of the paper. It says:
Additionally, we introduce a fourth term into the discriminator loss with an extremely small weight to keep the discriminator output from drifting too far away from zero ...
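To make the role of the term concrete, here is a minimal sketch of the loss from the question with the drift penalty separated out. It assumes PyTorch imported as th, as in the snippet above. The function name, the gradient_penalty argument, and the default drift value of 0.001 are illustrative assumptions, not the repository's actual code. The Wasserstein part only depends on the difference between the mean outputs, so the outputs are free to shift away from zero; the squared term on the real outputs discourages that.

```python
import torch as th


def dis_loss_with_drift(real_out, fake_out, gradient_penalty=0.0, drift=0.001):
    """Hypothetical helper: WGAN-GP critic loss plus a drift penalty.

    real_out / fake_out: critic outputs on real and generated samples.
    gradient_penalty: assumed to be computed elsewhere (already weighted).
    drift: small weight for the drift term (0.001 is an assumed example).
    """
    # Wasserstein estimate: only the difference of the means matters here,
    # so it places no constraint on the absolute scale of the outputs.
    wasserstein = th.mean(fake_out) - th.mean(real_out)

    # Drift penalty: pulls the critic's outputs on real samples toward zero.
    drift_penalty = drift * th.mean(real_out ** 2)

    return wasserstein + gradient_penalty + drift_penalty


if __name__ == "__main__":
    real = th.tensor([1.0, 3.0])
    fake = th.tensor([0.0, 2.0])
    print(dis_loss_with_drift(real, fake).item())
```

Because the weight is tiny, the penalty barely changes the loss value for moderate outputs; it only becomes noticeable when the critic's outputs on real samples grow large in magnitude.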
My implementation is a complete reproduction of all the techniques described in the paper. Please feel free to ask if you have any more questions. Suggestions, feedback, and contributions are highly welcome.
Best regards,
@akanimax