Data augmentation and data shuffling problem?
fjparrado opened this issue · comments
Hi,
I have been trying to train on the Cityscapes dataset, but I was not able to reproduce the results. I am using TF 1.12.0.
Debugging the training code, I noticed the problem resides in the function next_batch(self, batch_size) in the cityscapes_tf_io.py file.
If I skip the data augmentation and the shuffle step, the images in the batch are fine:
with tf.name_scope('input_tensor'):
    # TFRecordDataset opens a binary file and reads one record at a time.
    # `tfrecords_file_paths` could also be a list of filenames, which will be read in order.
    dataset = tf.data.TFRecordDataset(tfrecords_file_paths)
    # The map transformation takes a function and applies it to every element
    # of the dataset.
    dataset = dataset.map(
        map_func=aug.decode,
        num_parallel_calls=CFG.DATASET.CPU_MULTI_PROCESS_NUMS
    )
    # if self._dataset_flags == 'train':
    #     dataset = dataset.map(
    #         map_func=aug.preprocess_image_for_train,
    #         num_parallel_calls=CFG.DATASET.CPU_MULTI_PROCESS_NUMS
    #     )
    # elif self._dataset_flags == 'val':
    #     dataset = dataset.map(
    #         map_func=aug.preprocess_image_for_val,
    #         num_parallel_calls=CFG.DATASET.CPU_MULTI_PROCESS_NUMS
    #     )
    # The shuffle transformation uses a finite-sized buffer to shuffle elements
    # in memory. The parameter is the number of elements in the buffer. For
    # completely uniform shuffling, set the parameter to be the same as the
    # number of elements in the dataset.
    # dataset = dataset.shuffle(buffer_size=512)
    # repeat num epochs
    dataset = dataset.repeat(self._epoch_nums)
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
    dataset = dataset.prefetch(buffer_size=batch_size * 16)
    iterator = dataset.make_one_shot_iterator()
However, if I uncomment the shuffle call, the image/label "pairs" no longer match:
Using the original code (uncommenting the data augmentation lines as well), the random transformations applied to the RGB image and to the label image are not the same.
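As an aside, a common way to keep random augmentations in sync between an image and its label mask is to draw the random parameters once and apply them to both tensors. This is only an illustration of the principle in plain NumPy (it is not the repo's actual aug module), assuming a random horizontal flip and a random crop:

```python
import numpy as np

def paired_random_flip_crop(image, label, crop_h, crop_w, rng):
    """Apply the SAME random flip and crop to an image and its label mask.

    Drawing the random decisions once and reusing them for both tensors
    keeps the pair aligned; sampling them independently for image and
    label is what desynchronizes the pair.
    """
    h, w = image.shape[:2]
    # Draw the random decisions once.
    flip = rng.random() < 0.5
    top = rng.integers(0, h - crop_h + 1)
    left = rng.integers(0, w - crop_w + 1)
    # Apply them identically to both tensors.
    if flip:
        image = image[:, ::-1]
        label = label[:, ::-1]
    image = image[top:top + crop_h, left:left + crop_w]
    label = label[top:top + crop_h, left:left + crop_w]
    return image, label

# Build a toy pair where channel 0 of the image encodes the label:
# pixel (i, j) has image value 3 * (i * 8 + j) and label value i * 8 + j.
rng = np.random.default_rng(0)
img = np.arange(8 * 8 * 3).reshape(8, 8, 3)
lbl = np.arange(8 * 8).reshape(8, 8)
img_c, lbl_c = paired_random_flip_crop(img, lbl, 4, 4, rng)
# The label crop covers exactly the same pixels as the image crop.
assert np.array_equal(img_c[:, :, 0] // 3, lbl_c)
```

In a tf.data pipeline the same idea applies: the map function must receive the image and label together and reuse one set of random values for both, rather than augmenting them in separate map calls.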
This is the code I used to visualize the images:
A = iterator.get_next(name='{:s}_IteratorGetNext'.format(self._dataset_flags))
with tf.Session() as sess:
    X = A[0].eval(session=sess)
with tf.Session() as sess:
    Y = A[1].eval(session=sess)
x = np.zeros((256, 512, 3))
x[:, :] = Y[10] / 125
plt.imshow(X[10] * 0.5 + 0.5); plt.show(); plt.imshow(x); plt.show()
Maybe I am doing something wrong, but I cannot figure it out... Do you have any suggestions?
Regards
@dragonvenenoso I will check it as soon as possible :)
I am sorry, no need, I am stupid :)
The code to visualize the images should be:
A = iterator.get_next(name='{:s}_IteratorGetNext'.format(self._dataset_flags))
with tf.Session() as sess:
    X = sess.run(tf.tuple(A))
x = np.zeros((256, 512, 3))
x[:, :] = X[1][13] / 125
plt.imshow(X[0][13] * 0.5 + 0.5); plt.show(); plt.imshow(x); plt.show()
And then everything is correct.
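For anyone hitting the same thing: the original visualization code evaluated A[0] and A[1] in two separate sessions, so the image and the label did not come from the same iterator draw (and with shuffling enabled, each draw is shuffled independently, which is why the mismatch only appeared once shuffle was uncommented). Even within a single session, two separate run calls would each advance the iterator. The behaviour can be mimicked with a plain Python iterator (an analogy, not TensorFlow code):

```python
# Each element stands for one batch: an (image_id, label_id) pair
# that the pipeline produced together.
batches = iter([('img0', 'lbl0'), ('img1', 'lbl1'), ('img2', 'lbl2')])

# WRONG: two separate "runs" advance the iterator twice, pairing the
# image of batch 0 with the label of batch 1.
X = next(batches)[0]   # consumes batch 0 -> 'img0'
Y = next(batches)[1]   # consumes batch 1 -> 'lbl1'
assert (X, Y) == ('img0', 'lbl1')  # mismatched pair

# RIGHT: fetch both values in a single step, which is what
# sess.run(tf.tuple(A)) does for the iterator outputs.
X, Y = next(batches)   # consumes batch 2 as one unit
assert (X, Y) == ('img2', 'lbl2')  # matched pair
```

Fetching every tensor you need in one sess.run call guarantees they all come from the same element of the dataset.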
Edit: I trained the model again and it works perfectly. Thanks a lot for sharing your code!
@dragonvenenoso ok:)