Issue with 05_01_cyclegan_train using TensorFlow 2.0
braxtonj opened this issue
I'm trying to run CycleGAN to "Paint Like Monet" (using the Ukiyoe dataset instead of Monet) and am running into an issue when the model is saved. During training I get the following error while executing self.*.save( ... ) inside CycleGAN.save_model (where * is combined, g_BA or g_AB):
NotImplementedError                       Traceback (most recent call last)
in ()
      5     , test_B_file = TEST_B_FILE
      6     , batch_size=BATCH_SIZE
----> 7     , sample_interval=PRINT_EVERY_N_BATCHES)
10 frames
/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/engine/base_layer.py in get_config(self)
    497     # or that get_config has been overridden:
    498     if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
--> 499       raise NotImplementedError('Layers with arguments in `__init__` must '
    500                                 'override `get_config`.')
    501     return config
NotImplementedError: Layers with arguments in `__init__` must override `get_config`.
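The failing check is visible in the traceback: Keras collects the arguments of the layer's __init__, drops the ones the base config already records, and raises if anything besides self is left over while get_config still carries the _is_default marker. The pure-Python sketch below re-creates that logic in simplified form so it can be run without TensorFlow. SketchLayer, default, PaddingWithoutConfig and PaddingWithConfig are hypothetical names, and the real Keras code does more bookkeeping than this.

```python
import inspect


def default(method):
    """Flag a method as the un-overridden default (mimics the _is_default marker)."""
    method._is_default = True
    return method


class SketchLayer:
    """Hypothetical, heavily simplified stand-in for tf.keras.layers.Layer."""

    def __init__(self, name=None, trainable=True):
        self.name = name or type(self).__name__
        self.trainable = trainable

    @default
    def get_config(self):
        config = {'name': self.name, 'trainable': self.trainable}
        # getfullargspec on a bound method still lists 'self' as an argument.
        all_args = inspect.getfullargspec(self.__init__).args
        extra_args = [arg for arg in all_args if arg not in config]
        # 'self' is always left over, so more than one extra argument means the
        # subclass __init__ takes parameters the base config does not record.
        if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
            raise NotImplementedError(
                'Layers with arguments in `__init__` must override `get_config`.')
        return config


class PaddingWithoutConfig(SketchLayer):
    def __init__(self, padding=(1, 1), **kwargs):
        super().__init__(**kwargs)
        self.padding = tuple(padding)


class PaddingWithConfig(PaddingWithoutConfig):
    def get_config(self):
        # The override has no _is_default attribute, so the base check passes.
        config = super().get_config()
        config.update({'padding': self.padding})
        return config


try:
    PaddingWithoutConfig().get_config()
except NotImplementedError as err:
    print('error:', err)

print(PaddingWithConfig(padding=(3, 3)).get_config())
```

Any override of get_config is enough to get past the check, because the marker lives on the default method only.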
Attached is the notebook used to train the model (rename the file from .txt to .ipynb):
05_02_cyclegan_train_ukiyoe2photo.txt
You can find the solution here: https://stackoverflow.com/questions/50677544/reflection-padding-conv2d
You need to override get_config() in the ReflectionPadding2D class and have it call the parent class's get_config(). That worked for me.
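An override that only returns the parent's config is enough to silence the error, but when a saved model is loaded, Keras rebuilds each layer with from_config, whose default behaviour is cls(**config). Any constructor argument missing from the config silently falls back to its default. The sketch below models only that config round trip in plain Python; the class names and the padding=(3, 3) value are hypothetical.

```python
class ReflectionPaddingConfigSketch:
    """Hypothetical stand-in that models only the Keras config round trip."""

    def __init__(self, padding=(1, 1), name='reflection_padding2d'):
        self.padding = tuple(padding)
        self.name = name

    def base_config(self):
        # Simplified stand-in for what Layer.get_config() contributes.
        return {'name': self.name}

    @classmethod
    def from_config(cls, config):
        # Same behaviour as the default keras Layer.from_config.
        return cls(**config)


class SuperOnlyConfig(ReflectionPaddingConfigSketch):
    def get_config(self):
        # Returns only the parent's config, so padding is not recorded.
        return self.base_config()


class FullConfig(ReflectionPaddingConfigSketch):
    def get_config(self):
        config = self.base_config()
        config.update({'padding': self.padding})
        return config


for layer_cls in (SuperOnlyConfig, FullConfig):
    original = layer_cls(padding=(3, 3))
    restored = layer_cls.from_config(original.get_config())
    print(layer_cls.__name__, restored.padding)
# SuperOnlyConfig (1, 1)
# FullConfig (3, 3)
```

Because the padding layer has no weights, a reload with the wrong padding would not fail on weight shapes; it would just change the spatial size of the outputs.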
For the tensorflow_2 branch, I added the following code to cycleGAN.py before the CycleGAN class definition:
import tensorflow as tf
from tensorflow.keras.layers import InputSpec, Layer  # InputSpec is needed in __init__


class ReflectionPadding2D(Layer):
    """Reflection-pads the height and width axes of an NHWC tensor."""

    def __init__(self, padding=(1, 1), **kwargs):
        # padding is (width_pad, height_pad), matching how call() unpacks it.
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, s):
        if s[1] is None:
            return (None, None, None, s[3])
        w_pad, h_pad = self.padding
        return (s[0], s[1] + 2 * h_pad, s[2] + 2 * w_pad, s[3])

    def call(self, x, mask=None):
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')

    def get_config(self):
        # Overriding get_config avoids the NotImplementedError on save;
        # storing padding lets the layer be rebuilt with the same padding on load.
        config = super(ReflectionPadding2D, self).get_config()
        config.update({'padding': self.padding})
        return config
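To see what the layer's call() produces without running TensorFlow, the sketch below uses NumPy's np.pad with mode='reflect', which behaves like tf.pad's 'REFLECT' mode: the border row or column is mirrored but not duplicated. reflection_pad_nhwc is a hypothetical helper, and the 3x3 input is made up for illustration.

```python
import numpy as np


def reflection_pad_nhwc(x, h_pad, w_pad):
    """NumPy analogue of tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT').

    Only the height (axis 1) and width (axis 2) axes are padded; the batch and
    channel axes are untouched. Since border pixels are not repeated, each pad
    must be smaller than the size of the axis it pads.
    """
    return np.pad(x, [(0, 0), (h_pad, h_pad), (w_pad, w_pad), (0, 0)], mode='reflect')


# A single 3x3 single-channel "image" in NHWC layout.
image = np.arange(1, 10).reshape(1, 3, 3, 1)
padded = reflection_pad_nhwc(image, h_pad=1, w_pad=1)

print(padded.shape)  # (1, 5, 5, 1)
print(padded[0, :, :, 0])
# [[5 4 5 6 5]
#  [2 1 2 3 2]
#  [5 4 5 6 5]
#  [8 7 8 9 8]
#  [5 4 5 6 5]]
```

Each spatial axis grows by twice its pad, which is what compute_output_shape should report.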