openai / improved-gan

Code for the paper "Improved Techniques for Training GANs"

Home Page: https://arxiv.org/abs/1606.03498

Confused about how minibatch discrimination is implemented in TensorFlow

mlzxy opened this issue

What is f2? Based on the definition in the paper, I thought f1 alone was the whole feature. Moreover, I expect the second slice, half(..., 1), to be empty, which in my experience raises an exception in TensorFlow (I haven't tried it yet).

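    def half(tens, second):
        # Slice along axis 2, keeping columns [second * batch_size, (second + 1) * batch_size)
        m, n, _ = tens.get_shape()
        m = int(m)
        n = int(n)
        return tf.slice(tens, [0, 0, second * self.batch_size], [m, n, self.batch_size])

    # Sum over axis 2 within one half, divided by the total sum of the matching mask half
    f1 = tf.reduce_sum(half(masked, 0), 2) / tf.reduce_sum(half(mask, 0))
    f2 = tf.reduce_sum(half(masked, 1), 2) / tf.reduce_sum(half(mask, 1))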
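Whether the second slice is empty depends only on the length of the last axis. The pure-Python sketch below mimics what half selects. It assumes a nested-list tensor of shape (m, n, k) and passes batch_size explicitly instead of reading self.batch_size; the function is a hypothetical stand-in, not the repository's code. If k == 2 * batch_size, both halves are non-empty. If k == batch_size, the second half runs past the end. Plain Python then returns an empty slice, whereas tf.slice would be expected to reject the out-of-range size.

```python
def half(tens, second, batch_size):
    """Hypothetical pure-Python analogue of the `half` helper.

    tens: nested lists of shape (m, n, k).
    Keeps columns [second * batch_size, (second + 1) * batch_size) of the last axis.
    """
    start = second * batch_size
    return [[row[start:start + batch_size] for row in plane] for plane in tens]


# Last axis of length 2 * batch_size (batch_size = 2): both halves have data.
t = [[[0, 1, 2, 3]]]
print(half(t, 0, 2))  # [[[0, 1]]]
print(half(t, 1, 2))  # [[[2, 3]]]
```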

Thanks in advance for any clarification.

Hi @benbbear, I'm afraid this answer may come a bit late, but as far as I understand from the author's reference (https://arxiv.org/abs/1701.00160) and from reading the code, f2 is the set of minibatch-discrimination distances computed against the fake part of the data. If you look at the beginning of the function, you'll see:

    batch_size = int(image.get_shape()[0])
    assert batch_size == 2 * self.batch_size

This means that every minibatch of real data has a minibatch of fake data appended to it, so the discriminator can compute distances to both kinds of samples in a single forward pass. For clarification, here is the author's description from the referenced tutorial (p. 36): "The basic idea of minibatch features is to allow the discriminator to compare an example to a minibatch of generated samples and a minibatch of real samples. By measuring distances to these other samples in latent spaces, the discriminator can detect if a sample is unusually similar to other generated samples".
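The idea described above can be sketched in plain Python (no TensorFlow dependency). This is a minimal illustration, not the repository's implementation. The assumptions are: each sample's activations form a list of kernels, and each kernel is a vector (the rows M_{i,b} from the paper). The batch is stacked as batch_size real samples followed by batch_size generated ones. Each half is averaged over its other members; that normalization is this sketch's choice. Following the paper, the closeness for kernel b is exp(-L1 distance). A sample is never compared with itself, which is the role the mask plays in the quoted code.

```python
import math


def minibatch_features(activations, batch_size):
    """Hypothetical per-sample minibatch features, split by batch half.

    activations: 2 * batch_size samples; each sample is a list of kernels,
    each kernel a list of floats. The first half is assumed real, the second fake.
    Returns (f1, f2): for each sample i and kernel b, the mean closeness to the
    other real samples (f1) and to the other generated samples (f2).
    """
    total = len(activations)
    assert total == 2 * batch_size, "expects real and fake halves stacked"
    assert batch_size >= 2, "each half needs at least one other sample"
    num_kernels = len(activations[0])

    def closeness(i, j, b):
        # c_b(x_i, x_j) = exp(-||M_{i,b} - M_{j,b}||_1)
        dist = sum(abs(p - q) for p, q in zip(activations[i][b], activations[j][b]))
        return math.exp(-dist)

    def half_mean(i, second):
        # Compare only against the chosen half, skipping j == i (self-comparison).
        start = second * batch_size
        others = [j for j in range(start, start + batch_size) if j != i]
        return [sum(closeness(i, j, b) for j in others) / len(others)
                for b in range(num_kernels)]

    f1 = [half_mean(i, 0) for i in range(total)]
    f2 = [half_mean(i, 1) for i in range(total)]
    return f1, f2


# Two real samples at 0.0 and two generated samples at 5.0 (one 1-D kernel each).
f1, f2 = minibatch_features([[[0.0]], [[0.0]], [[5.0]], [[5.0]]], batch_size=2)
print(f2[2])  # [1.0]: generated sample 2 is identical to the other generated sample
```

In this toy batch, a generated sample that collapses onto another generated sample gets the maximum f2 value. The tutorial describes this "unusually similar" signal as what the discriminator can pick up on.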

Is there any available code for a minibatch discriminator in TensorFlow?