Hourout / GAN-keras

TensorFlow 2.x implementations of Generative Adversarial Networks.

Could you implement feature matching for Semi-Supervised GANs?

King-Of-Knights opened this issue

commented

As per the title.

Hi, could you point out more specifically which aspect of feature matching in Semi-Supervised GANs you mean?
paper: https://arxiv.org/abs/1606.01583

commented

@Hourout In Section 5 of https://arxiv.org/pdf/1606.03498.pdf, Ian Goodfellow mentions:
"This approach introduces an interaction between G and our classifier that we do not fully understand yet, but empirically we find that optimizing G using feature matching GAN works very well for semi-supervised learning, while training G using GAN with minibatch discrimination does not work at all"
I am not sure how to implement this feature-matching mechanism. Roughly, the idea should be to take the gap between the discriminator's last feature-extraction layer on real images and on generator-produced images and use it as the generator's loss; that should work better!
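For reference, the feature-matching objective in that paper matches batch statistics of an intermediate discriminator layer rather than per-image activations: the generator minimizes the squared L2 distance between the mean feature of real images and the mean feature of generated images. A minimal TensorFlow sketch (f_real / f_fake stand for any intermediate discriminator activation; the function name is made up here):

import tensorflow as tf

def feature_matching_loss(f_real, f_fake):
    # ||E[f(x)] - E[f(G(z))]||_2^2, means taken over the batch dimension
    mean_real = tf.reduce_mean(f_real, axis=0)
    mean_fake = tf.reduce_mean(f_fake, axis=0)
    return tf.reduce_sum(tf.square(mean_real - mean_fake))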

Hi, the main changes are as follows.

Rewrite the discriminator (Dnet)

import numpy as np
import tensorflow as tf

def discriminator(num_classes=10, image_shape=(28,28,1)):
    # collect intermediate activations used for the feature-matching loss
    feature = []
    image = tf.keras.Input(shape=image_shape)
    x = tf.keras.layers.Conv2D(32, kernel_size=3, strides=2, padding="same")(image)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
    feature.append(x)
    x = tf.keras.layers.Dropout(0.25)(x)
    x = tf.keras.layers.Conv2D(64, kernel_size=3, strides=2, padding="same")(x)
    x = tf.keras.layers.ZeroPadding2D(padding=((0,1),(0,1)))(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
    feature.append(x)
    x = tf.keras.layers.Dropout(0.25)(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.8)(x)
    x = tf.keras.layers.Conv2D(128, kernel_size=3, strides=2, padding="same")(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
    feature.append(x)
    x = tf.keras.layers.Dropout(0.25)(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.8)(x)
    x = tf.keras.layers.Conv2D(256, kernel_size=3, strides=1, padding="same")(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
    feature.append(x)
    x = tf.keras.layers.Dropout(0.25)(x)
    x = tf.keras.layers.Flatten()(x)
    # name the two heads so losses can be attached to them (and only them) by name
    valid = tf.keras.layers.Dense(1, activation="sigmoid", name="valid")(x)
    label = tf.keras.layers.Dense(num_classes+1, activation="softmax", name="label")(x)
    # flat output list: the two heads followed by the four feature maps
    dnet = tf.keras.Model(image, [valid, label] + feature)
    return dnet
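As a quick sanity check (hypothetical usage, not part of the original comment), the rewritten discriminator now exposes six outputs: the two classification heads followed by the four intermediate feature maps.

d = discriminator()
valid, label, *features = d(tf.zeros((2, 28, 28, 1)))   # two dummy MNIST-sized images
print(valid.shape, label.shape)                          # expected (2, 1) and (2, 11)
print([tuple(f.shape) for f in features])                # expected (2,14,14,32), (2,8,8,64), (2,4,4,128), (2,4,4,256)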

Rewrite the models

dnet = discriminator(num_classes, image_shape)
# losses are attached only to the two named heads; the feature outputs are
# used by the feature-matching model below and get no loss of their own
dnet.compile(loss={'valid': 'binary_crossentropy', 'label': 'categorical_crossentropy'},
             optimizer=tf.keras.optimizers.Adam(0.0002, 0.5),
             metrics=['accuracy'])

noise = tf.keras.Input(shape=(latent_dim,))
gnet = generator(latent_dim)
dnet.trainable = False
image = gnet(noise)
valid, _, *feature_fake = dnet(image)
sgan = tf.keras.Model(noise, valid)
sgan.compile(loss=['binary_crossentropy'],
             optimizer=tf.keras.optimizers.Adam(0.0002, 0.5),
             metrics=['accuracy'])

true_image = tf.keras.Input(shape=(28, 28, 1))
valid, _, *feature_real = dnet(true_image)
fea_list = []
for i in range(4):
    x = tf.keras.layers.Subtract()([feature_fake[i], feature_real[i]])
    x = tf.keras.layers.Multiply()([x, x])
    # mean squared feature gap per sample, shape (batch_size, 1)
    x = tf.keras.layers.Lambda(lambda t: tf.reduce_mean(t, axis=[1, 2, 3])[:, None])(x)
    fea_list.append(x)
# average the four per-layer terms into a single feature-matching output
x = tf.keras.layers.Average()(fea_list)
aux_sgan = tf.keras.Model([noise, true_image], x)
aux_sgan.compile(loss=['mse'],
                 optimizer=tf.keras.optimizers.Adam(0.0002, 0.5),
                 metrics=['mae'])
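One note on this wiring (my reading, not stated in the original comment): dnet is compiled while it is still trainable, so dnet.train_on_batch updates the discriminator, while sgan and aux_sgan are built after dnet.trainable = False and therefore only update the generator. A quick check:

# only generator variables should appear as trainable in the combined models
print(len(gnet.trainable_weights), len(sgan.trainable_weights), len(aux_sgan.trainable_weights))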

Rewrite the training step

d_loss_real = dnet.train_on_batch(batch_image, ...)
d_loss_fake = dnet.train_on_batch(batch_image_gen, ...)
...
feature_match_loss = aux_sgan.train_on_batch([batch_noise, batch_image], np.zeros((batch_size, 1)))
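For concreteness, the elided targets ("...") could follow the usual SGAN convention, where real images keep their own class and generated images get the extra fake class. This is a hypothetical sketch (batch_label, batch_noise and the other batch arrays are assumed to come from the surrounding training loop); the feature-matching update is then the aux_sgan line already shown above.

# hypothetical discriminator targets; dict keys match the named heads
real_y = {'valid': np.ones((batch_size, 1)),
          'label': tf.keras.utils.to_categorical(batch_label, num_classes + 1)}
fake_y = {'valid': np.zeros((batch_size, 1)),
          'label': tf.keras.utils.to_categorical(np.full(batch_size, num_classes), num_classes + 1)}

d_loss_real = dnet.train_on_batch(batch_image, real_y)
d_loss_fake = dnet.train_on_batch(batch_image_gen, fake_y)

# adversarial generator update: push the "valid" head towards 1 for generated images
g_loss = sgan.train_on_batch(batch_noise, np.ones((batch_size, 1)))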

The main point is how the mse (feature-matching) loss is implemented; this is just for reference.

commented

@Hourout Doesn't dnet need to be compiled here? After all, dnet.train_on_batch is used later.