keras-team / keras

Deep Learning for humans

Home Page: http://keras.io/

Unnecessary Warning for trainable parameters

mercy0387 opened this issue · comments

Consider a GAN model like the one below:

discriminator = get_discriminator_model()
discriminator.compile(loss='categorical_crossentropy', optimizer='adam')
discriminator.trainable = False

generator = get_generator_model()
gan = Sequential([generator, discriminator])
gan.compile(loss='categorical_crossentropy', optimizer='adam')

When discriminator.train_on_batch is called for the first time, I get a warning that says /path/to/keras/engine/training.py:973: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ? 'Discrepancy between trainable weights and collected trainable'

This warning did not appear in earlier versions (e.g. 2.0.2). I updated my Keras version to avoid the problem described in #8121.
That problem is resolved, but now this warning occurs.

I understand this warning is meaningful when parameters are unintentionally non-trainable. However, some models (like GANs) intentionally have non-trainable parameters. I think there should be some condition to decide whether the non-trainable parameters are intended or not.

I introduced the warning to fix #8121, so this might be my fault.

Do you have a minimal reproducible example where you get the warning?

What does your training code look like? I think if you set discriminator.trainable = False and then want to call discriminator.train_on_batch, you should call discriminator.compile in between to make sure the train function gets built based on the trainable parameters of your model.
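For example, a minimal sketch of that ordering (assuming discriminator is an already-built Model):

discriminator.trainable = False
# re-compile so the train function is rebuilt with the new trainable state
discriminator.compile(loss='categorical_crossentropy', optimizer='adam')
# subsequent calls to discriminator.train_on_batch now reflect the updated flag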

@julienr Thank you for the reply.
Sorry for the insufficient code. Here is a reproducible example:

import numpy as np
from keras.layers import Input, Dense, Activation
from keras.models import Model

# making discriminator
d_input = Input(shape=(2,))
d_output = Activation('softmax')(Dense(2)(d_input))
discriminator = Model(inputs=d_input, outputs=d_output)
discriminator.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])

# making generator
g_input = Input(shape=(2,))
g_output = Activation('relu')(Dense(2)(g_input))
generator = Model(inputs=g_input, outputs=g_output)

# making gan(generator -> discriminator)
discriminator.trainable = False
gan = Model(inputs=g_input, outputs=discriminator(g_output))
gan.compile(loss='categorical_crossentropy', optimizer='adam')

# training
BATCH_SIZE = 3
some_input_data = np.array([[1,2],[3,4],[5,6]])
some_target_data = np.array([[1,1],[2,2],[3,3]])
# update discriminator
generated = generator.predict(some_input_data, verbose=0)
X = np.concatenate((some_target_data, generated), axis=0)
y = [[0,1]]*BATCH_SIZE + [[1,0]]*BATCH_SIZE
d_metrics = discriminator.train_on_batch(X, y)
# update generator
g_metrics = gan.train_on_batch(some_input_data, [[0,1]]*BATCH_SIZE)
# loop these operations for batches...

Let me know if this is the wrong way to train just the generator part when calling gan.train_on_batch.

Thanks for the example code. I have a question though: if discriminator.trainable is False, will d_metrics = discriminator.train_on_batch(X, y) update anything?


Regardless, this made me think about the warning introduced to fix #8121 :

Currently, the Model class populates Model._collected_trainable_weights from Model.trainable_weights in Model.compile. As far as I can see, Model._collected_trainable_weights is only used in Model._make_train_function.
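In rough pseudocode, the behaviour being discussed looks like this (a sketch of the logic, not the actual Keras source):

class Model:
    def compile(self, ...):
        # snapshot taken at compile time; later changes to .trainable are not reflected here
        self._collected_trainable_weights = self.trainable_weights
        ...

    def _make_train_function(self):
        # parameter updates are built from the snapshot, not from the live trainable_weights
        updates = self.optimizer.get_updates(
            loss=self.total_loss,
            params=self._collected_trainable_weights)
        ...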

Can we remove Model._collected_trainable_weights altogether and just collect the trainable weights in Model._make_train_function? I think this would actually remove the need for the consistency check and the warning.

I could do this, what do you think @fchollet ?

@julienr Thank you for taking on this problem!

To answer your question: the discriminator's weights are still updated when train_on_batch is called after discriminator.trainable is set to False, probably because discriminator.compile was called before trainable was changed.

About your idea:
I think that if Model._collected_trainable_weights were collected in Model._make_train_function as you suggest, my code could not train the discriminator, because the discriminator has no trainable weights by the time Model._make_train_function is called.

It's too difficult for me to solve this problem, so I'm really grateful for your help :)

@julienr I found a workaround for this. I can use keras.engine.topology.Container to make a non-trainable discriminator, like this:

import numpy as np
from keras.layers import Input, Dense, Activation
from keras.models import Model
from keras.engine.topology import Container

# making discriminator
d_input = Input(shape=(2,))
d_output = Activation('softmax')(Dense(2)(d_input))
discriminator = Model(inputs=d_input, outputs=d_output)
discriminator.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])

discriminator_fixed = Container(inputs=d_input, outputs=d_output)    # -> I Add THIS CONTAINER!

# making generator
g_input = Input(shape=(2,))
g_output = Activation('relu')(Dense(2)(g_input))
generator = Model(inputs=g_input, outputs=g_output)

# making gan(generator -> discriminator)
discriminator_fixed.trainable = False    # -> change trainable of CONTAINER
gan = Model(inputs=g_input, outputs=discriminator_fixed(g_output))     # -> use CONTAINER
gan.compile(loss='categorical_crossentropy', optimizer='adam')

# training
BATCH_SIZE = 3
some_input_data = np.array([[1,2],[3,4],[5,6]])
some_target_data = np.array([[1,1],[2,2],[3,3]])
# update discriminator
generated = generator.predict(some_input_data, verbose=0)
X = np.concatenate((some_target_data, generated), axis=0)
y = [[0,1]]*BATCH_SIZE + [[1,0]]*BATCH_SIZE
d_metrics = discriminator.train_on_batch(X, y)
# update generator
g_metrics = gan.train_on_batch(some_input_data, [[0,1]]*BATCH_SIZE)
# loop these operations for batches...

Running this code, the warning disappeared!

And after that,

from keras import backend as K
print([K.eval(w) for w in discriminator.weights])
print([K.eval(w) for w in discriminator_fixed.weights])

these print the same values.

However, I don't know whether this is a good approach. Ideally, I wouldn't need to introduce extra variables like discriminator_fixed.

@mercy0387 Why don't you call discriminator.compile() after discriminator.trainable = False? Would it cause any problems in your code? I get the same warning as you did.

@lionlai1989 I want to update the discriminator's weights when discriminator.fit() is called, but not when gan.fit() is called.

If I call discriminator.compile() after discriminator.trainable = False, the discriminator's weights are not updated even when discriminator.fit() is called.

That's my understanding. Let me know if your understanding is different from mine.

Hi,

I have the same issue. My case is as described below (a code sketch follows the list):

  • Create discriminator model
  • Compile the model (at this point, it's trainable - default)
  • Create generator model (do not compile it)
  • Create adversarial model (Sequential)
    • add generator
    • add discriminator
  • Compile the model
  • Loop through my epochs:
    • Set the discriminator model to trainable True
    • Train the discriminator with a call to train_on_batch(X, y)
    • Set the discriminator model to trainable False
    • Train the adversarial model with a call to train_on_batch(X_noise, y_noise)
  • End of the loop
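In code, the loop is roughly this (a sketch; discriminator, adversarial, get_real_batch, and get_noise_batch are placeholder names, not my actual code):

for epoch in range(n_epochs):
    # train the discriminator on a real/fake batch
    X, y = get_real_batch()
    discriminator.trainable = True
    d_loss = discriminator.train_on_batch(X, y)

    # train the adversarial (stacked) model with the discriminator frozen
    X_noise, y_noise = get_noise_batch()
    discriminator.trainable = False
    a_loss = adversarial.train_on_batch(X_noise, y_noise)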

Is this warning causing issues during training or is it just a warning?

Thanks in advance.

Cheers,
Wilder

I think that trainable only takes effect after calling compile. Changing the trainable status has no effect without a new compile.

So if you set discriminator.trainable = False and call GAN.compile, and then set discriminator.trainable = True and call discriminator.compile, then GAN.train_on_batch will indeed not change the discriminator's weights, while discriminator.train_on_batch will.
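For example (a sketch, assuming discriminator and GAN are already built and GAN contains the discriminator):

discriminator.trainable = False
GAN.compile(loss='binary_crossentropy', optimizer='adam')             # GAN's train function treats D as frozen

discriminator.trainable = True
discriminator.compile(loss='binary_crossentropy', optimizer='adam')   # D's own train function updates D

# GAN.train_on_batch(...) will not change the discriminator's weights,
# while discriminator.train_on_batch(...) will.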

This can be a rather misleading warning. To programmatically verify that your model is training as expected in spite of the warning, you can do the following:

Assuming that you have a GAN setup where you

  1. Create and compile the discriminator
  2. Create the generator
  3. Set discriminator.trainable = False
  4. Create a combined model from the generator and the discriminator and compile it

check that the following conditions hold (a code sketch follows the list):

  • Save n_disc_trainable = len(discriminator.trainable_weights) after step 1 from above and make sure len(discriminator._collected_trainable_weights) == n_disc_trainable after step 4 from above.
  • Save n_gen_trainable = len(generator.trainable_weights) after step 2 from above and make sure len(combined._collected_trainable_weights) == n_gen_trainable after step 4 from above.
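In code, those checks look roughly like this (a sketch; the names discriminator, generator, and combined are assumed from the steps above):

# after step 1 (discriminator built and compiled)
n_disc_trainable = len(discriminator.trainable_weights)

# after step 2 (generator built)
n_gen_trainable = len(generator.trainable_weights)

# after step 4 (combined model compiled with discriminator.trainable = False)
assert len(discriminator._collected_trainable_weights) == n_disc_trainable
assert len(combined._collected_trainable_weights) == n_gen_trainable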

Keras uses the Model._collected_trainable_weights property during training and this one only changes if Model.compile is called (it is only at this point that the value of Model.trainable matters).

I get the same warning too, but I worked around it by calling discriminator.compile() after setting trainable. The steps are:

  • ...
  • discriminator.trainable = False
  • discriminator.compile(...)
  • ...
    When I forgot to call compile, the discriminator loss was higher!
    However, when I called compile, the loss began to decrease.

@abiro I agree this seems to be annoying for GAN users. The warning was introduced to fix #8121, to warn users that they have to call compile again after changing .trainable. I think it makes sense to have this warning, because this is otherwise surprising behaviour for new users.

If you know what you are doing (like implementing a GAN), you can either ignore the warning or use @mercy0387's solution of using a Container.

I'm not sure how we can fix this for GAN users while retaining the warning for new users. If anybody has ideas, I'm open to discuss them.

It seems to me the root of the problem is that you have to specify .trainable at the model level, but a model can be included in multiple other models, and you would rather be able to specify .trainable explicitly for each .train() invocation. One way to achieve this would be to add a trainable=[<list of models/layers>] argument to the train function, but that's a big change. What do you think?
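As a purely hypothetical sketch of that API (the trainable argument does not exist in Keras today):

# hypothetical trainable= argument, not an existing Keras feature
gan.train_on_batch(x_noise, y_fake, trainable=[generator])
discriminator.train_on_batch(x_batch, y_labels, trainable=[discriminator])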

Do we have to set D.trainable=False after loading the Model? @julienr

keras.engine.topology.Container appears to be deprecated. Where can we find the replacement?

@ColinConwell I found an alternative to Container here: #10023.

It is Network. But I think this name is too hard to find via Google because it's such a generic word...

@mercy0387 How do you make sure that the weights in the 'gan' model are being updated after the discriminator's train_on_batch? I tried K.eval(), but it's really hard to see this way...

Is there a function that keeps the two models (the trainable model and the container/network model) connected, so that when I train the model, the container will be updated?

I have code as follows:

# get models for the discriminator and segmentor; the fixed models are Network (Keras class) objects
discriminator, discriminator_fixed = model.get_discriminator()
segmentor, segmentor_fixed = model.get_segmentor()

## model that trains the discriminator, maximizing the MAE loss
segmentor_fixed.trainable = False
predictions_patch_fixed = segmentor_fixed(image_patch)
output_error_disc = discriminator([image_patch, predictions_patch_fixed, ground_truth_patch])
combined_discriminator = Model(inputs=[image_patch, ground_truth_patch], outputs=output_error_disc)
combined_discriminator.compile(optimizer=dopt, loss=neg_logcosh_disc)  ## neg_out_mae

## model that trains the segmentor, minimizing the MAE loss and the dice loss from the predictions
discriminator_fixed.trainable = False
predictions_patch = segmentor(image_patch)
output_error_seg = discriminator_fixed([image_patch, predictions_patch, ground_truth_patch])
combined_segmentor = Model(inputs=[image_patch, ground_truth_patch], outputs=[output_error_seg, predictions_patch])
combined_segmentor.compile(optimizer=opt, loss=combined_losses, loss_weights=combined_weights, metrics=combined_losses)

and I would like to train both models in GAN fashion: first the discriminator, maximizing the loss, and then the generator/segmentor, minimizing the loss.

Thanks,
Pedro

Thanks to @mercy0387's and @abiro's suggestions, this worked for me without getting a warning:

# Build discriminator
inputs_dsc, outputs_dsc = construct_discriminator()

discriminator = keras.models.Model(
    inputs_dsc,
    outputs_dsc,
    name='discriminator'
)

# Compile discriminator
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=self.optimizer,
    metrics=['accuracy']
)

# Build "frozen discriminator"
frozen_discriminator = keras.engine.network.Network(
    inputs_dsc,
    outputs_dsc,
    name='frozen_discriminator'
)

frozen_discriminator.trainable = False

# Debug 1/3: discriminator weights
n_disc_trainable = len(discriminator.trainable_weights)

# Build generator
inputs_gen, outputs_gen = construct_generator()

generator = keras.models.Model(
    inputs_gen,
    outputs_gen,
    name='generator'
)

# Debug 2/3: generator weights
n_gen_trainable = len(generator.trainable_weights)

# Build adversarial model
adversarial_model = keras.models.Model(
    inputs_gen,
    frozen_discriminator(outputs_gen),
    name='adversarial_model'
)

# Compile adversarial model
adversarial_model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(),
    metrics=['accuracy']
)

# Debug 3/3: compare if trainable weights correct
assert(len(discriminator._collected_trainable_weights) == n_disc_trainable)
assert(len(adversarial_model._collected_trainable_weights) == n_gen_trainable)

Training sequence is the same as @mercy0387's.

For a WGAN, it seems I also obtain normal results by using keras.engine.network.Network and without warnings (still need to check the results thoroughly).

I also think this is the neatest way to do it, since you actually want to put the architecture of the generator below the discriminator and the architecture of the discriminator above the generator. Thus, you indeed want a Network instance, not a Model instance.

Concerning the warning itself, its presence assumes that trainable must be set before calling compile, whereas trainable is not an argument of the compile function.
Thus, you could either make it an argument, so that consistency is ensured within compile and the warning becomes obsolete,
or keep the original flexibility of trainable being independent of compile by removing the warning (while putting an example in the docs to clarify this point for newcomers).

I would rather disable all warnings than keep seeing this annoying one. The disturbing thing is that the warning shows for every batch!

I tried several methods found online, but nothing silences this warning. Is there a way to disable all warnings? I do not care about any warnings; this unnecessary one already disturbs my work too much. Thank you.

Currently, I can only change the source code to remove this warning.
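For reference, a single warning can also be silenced from Python's standard warnings module without editing the Keras source (a sketch; the message pattern matches the start of the warning text):

import warnings

# suppress only this specific message; other warnings still show
warnings.filterwarnings(
    'ignore',
    message='Discrepancy between trainable weights and collected trainable')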

I found one solution:

base_discriminator = get_discriminator_model()
base_generator     = get_generator_model()
########
# discriminator
########
discriminator      = Model(inputs=base_discriminator.inputs, outputs=base_discriminator.outputs)
discriminator.compile(loss=losses,optimizer=optimizer,metrics=['accuracy'])

########
# GAN
########
generator          = Model(inputs=base_generator.inputs,outputs=base_generator.outputs)
frozen_D           = Model(inputs=base_discriminator.inputs, outputs=base_discriminator.outputs)
frozen_D.trainable = False

I am getting an error which I believe is somehow related to this conversation, would you please advise? https://stackoverflow.com/questions/55567933/invalidargumenterror-error-when-using-tensorflow-gpu-2-0-0a0-to-build-tf-keras

Does model.summary display the collected trainable weights or the standard trainable weights?

It's unclear from the summary whether your model is actually trainable or not (for example, setting trainable after compiling affects the summary).

Does this not actually affect the trainable weights in the model?

I ran into the same problem even though I compile the model after setting trainable.
Anyway, I save all weights to a debug log to check whether they have been updated (or frozen). For my code, the debug log proves this warning message is wrong: the weights are actually updated (or frozen) exactly as my code intends.
Please modify this disturbing warning message.

This will disable the warning. It's a temporary workaround but it doesn't involve any major changes like running your model in a container.

import keras

def _check_trainable_weights_consistency(self):
    return
keras.Model._check_trainable_weights_consistency = _check_trainable_weights_consistency

Can anyone give me a simple solution for how to write the GAN compile and discriminator trainable steps?

Using a Network worked for me; it got rid of the warnings. I was also able to check whether the weights (as tensors) were the same before and after compiling the models.

I also checked the weight values before and after each separate train_on_batch; from this I could assert that the correct model's weights are being updated in each of the separate train_on_batch steps.

I'm also working on a WGAN-GP, which uses a critic model built up from the generator and discriminator. Doing this has helped me confirm that the critic isn't updating the generator's weights, and that when I train the GAN only the generator weights are updated and the discriminator weights aren't.
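A minimal sketch of that kind of check (assuming discriminator, generator, and gan are the compiled models, and noise and valid are an input batch and its target labels):

import numpy as np
from keras import backend as K

def snapshot(model):
    # copy the current weight values so they can be compared later
    return [K.get_value(w) for w in model.weights]

d_before, g_before = snapshot(discriminator), snapshot(generator)
gan.train_on_batch(noise, valid)
d_after, g_after = snapshot(discriminator), snapshot(generator)

# training the gan should update the generator but leave the discriminator untouched
assert all(np.array_equal(a, b) for a, b in zip(d_before, d_after))
assert not all(np.array_equal(a, b) for a, b in zip(g_before, g_after))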

To further improve on veya2ztn's response:

base_discriminator: Model = self.build_discriminator()
base_generator: Model = self.build_generator()

discriminator = Model(inputs=base_discriminator.inputs, outputs=base_discriminator.outputs, name='discriminator')
discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])

generator = Model(inputs=base_generator.inputs, outputs=base_generator.outputs, name='generator')
frozen_discriminator = Model(inputs=base_discriminator.inputs, outputs=base_discriminator.outputs)
frozen_discriminator.trainable = False

discriminator_trainable_weights = len(discriminator.trainable_weights)  # for the asserts below
generator_trainable_weights = len(generator.trainable_weights)

noise_input = Input(shape=self.NOISE_SHAPE)
image = generator(noise_input)
validity_output = frozen_discriminator(image)

combined = Model(inputs=noise_input, outputs=validity_output, name='combined')
combined.compile(loss='binary_crossentropy', optimizer=self.optimizer)

assert len(discriminator._collected_trainable_weights) == discriminator_trainable_weights
assert len(combined._collected_trainable_weights) == generator_trainable_weights

This will disable the warning. It's a temporary workaround but it doesn't involve any major changes like running your model in a container.

def _check_trainable_weights_consistency(self):
    return
keras.Model._check_trainable_weights_consistency = _check_trainable_weights_consistency

Saved my sanity

I am getting the same warning: is this something to worry about?

/home/aindani/.conda/envs/tf_env_cpu/lib/python3.7/site-packages/keras/engine/training.py:297: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set model.trainable without calling model.compile after ?
'Discrepancy between trainable weights and collected trainable'

Nothing to worry about. See issue title, "unnecessary warning". You can disable this warning by running this code

def _check_trainable_weights_consistency(self):
    return
keras.Model._check_trainable_weights_consistency = _check_trainable_weights_consistency