kekmodel / MPL-pytorch

Unofficial PyTorch implementation of "Meta Pseudo Labels"

Should the training set be shuffled?

ifsheldon opened this issue

Here, in the TF code, the training data is shuffled when it is loaded, so it seems we should shuffle it too:

def cifar10_train(params, batch_size=None):
  """Load CIFAR-10 data."""
  shuffle_size = batch_size * 16

  filenames = [os.path.join(CIFAR_PATH, 'train.bin')]
  record_bytes = 1 + (3 * 32 * 32)
  dataset = tf.data.FixedLengthRecordDataset(filenames, record_bytes)
  dataset = dataset.map(
      lambda x: _cifar10_parser(params, x, training=True),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.shuffle(shuffle_size).repeat()
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
  dataset = _optimize_dataset(dataset)

  return dataset
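Note that `dataset.shuffle(shuffle_size)` does not permute the whole dataset at once. tf.data keeps a buffer of `shuffle_size` elements (here `batch_size * 16`), repeatedly emits a random element from it, and refills the freed slot from the input. By default it also draws a new order on every pass, so each epoch produced by `.repeat()` is ordered differently. The following is a minimal pure-Python sketch of that buffered shuffle, written only for illustration; `buffered_shuffle` is a hypothetical helper, not a TensorFlow API.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int, rng: random.Random) -> Iterator[T]:
    """Approximate shuffle with a fixed-size buffer (hypothetical sketch of tf.data-style shuffling)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        # Buffer is full: emit a uniformly random element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


# With buffer_size=4, the first element emitted can only be one of the first 4 inputs.
out = list(buffered_shuffle(range(10), buffer_size=4, rng=random.Random(0)))
print(sorted(out) == list(range(10)))  # True
```

Because of the buffer, an element can be emitted at most `buffer_size - 1` positions earlier than its input position, although it can be delayed arbitrarily. A small buffer therefore mixes mostly nearby records.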

In your code, however, the dataset isn't shuffled when it is loaded:

labeled_loader = DataLoader(
        labeled_dataset,
        sampler=train_sampler(labeled_dataset),
        batch_size=args.batch_size,
        num_workers=args.workers,
        drop_last=True)  # the default of `shuffle` is False

I'm not sure I'm quoting the right TF data-loading code, since I don't know much about TF 1.x, and I know you've already shuffled the labeled dataset in x_u_split(). So I'm just posting my question here; it may not be an issue.

I think it makes sense to set shuffle=True on at least one of labeled_loader and unlabeled_loader, so that no two epochs produce the same sequence of batches.
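The concern can be illustrated with a small pure-Python sketch. `epoch_batches` is a hypothetical helper and is not part of the repository. With a fixed index order and `drop_last=True`, every epoch yields exactly the same batches. Drawing a fresh permutation for each epoch, which is similar in spirit to what `RandomSampler` does, changes them.

```python
import random
from typing import List, Optional


def epoch_batches(n: int, batch_size: int, rng: Optional[random.Random] = None) -> List[List[int]]:
    """Return the index batches for one epoch, dropping a final incomplete batch (like drop_last=True).

    With rng=None the order is sequential, like a DataLoader with shuffle=False and no sampler.
    With an rng, a fresh permutation is drawn on every call, i.e. once per epoch.
    """
    indices = list(range(n))
    if rng is not None:
        rng.shuffle(indices)  # new order each time this function is called
    num_full = n // batch_size
    return [indices[i * batch_size:(i + 1) * batch_size] for i in range(num_full)]


# Without shuffling, two epochs produce identical batches.
print(epoch_batches(10, 3) == epoch_batches(10, 3))  # True

# With a per-epoch permutation, consecutive epochs almost always differ.
rng = random.Random(0)
epochs = [epoch_batches(10, 3, rng) for _ in range(5)]
```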

This is because DistributedDataParallel (DDP) is used.

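When you use DDP, you shouldn't initialize the DataLoader with shuffle=True; initialize it with a DistributedSampler instead. The shuffling is then done by the sampler itself: RandomSampler in single-process training, and DistributedSampler (whose shuffle argument defaults to True) under DDP. The sampler is chosen as shown here: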

train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler
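The line above only picks the sampler class, and the shuffling happens inside the sampler. The following simplified pure-Python sketch shows the idea behind `DistributedSampler` and why DDP still gets shuffled data without `shuffle=True`. `ShardedShuffleSampler` is hypothetical and is not PyTorch's code. In the sketch, all ranks draw the same permutation from `seed + epoch`, pad it to a multiple of the number of replicas, and each rank keeps a strided slice. As with the real `DistributedSampler`, the order only changes between epochs if `set_epoch` is called before each epoch.

```python
import math
import random
from typing import List


class ShardedShuffleSampler:
    """Hypothetical, simplified stand-in for DistributedSampler (shuffle=True, drop_last=False)."""

    def __init__(self, dataset_len: int, num_replicas: int, rank: int, seed: int = 0) -> None:
        if not 0 <= rank < num_replicas:
            raise ValueError("rank must be in [0, num_replicas)")
        self.dataset_len = dataset_len
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0
        self.num_samples = math.ceil(dataset_len / num_replicas)
        self.total_size = self.num_samples * num_replicas

    def set_epoch(self, epoch: int) -> None:
        # Without this call the seed never changes, so every epoch repeats the same order.
        self.epoch = epoch

    def indices(self) -> List[int]:
        # Every rank uses the same seed, so all ranks agree on one global permutation.
        rng = random.Random(self.seed + self.epoch)
        order = list(range(self.dataset_len))
        rng.shuffle(order)
        # Pad by repeating indices so the permutation splits evenly across ranks.
        while len(order) < self.total_size:
            order += order[: self.total_size - len(order)]
        # Each rank takes every num_replicas-th index, starting at its own rank.
        return order[self.rank : self.total_size : self.num_replicas]


samplers = [ShardedShuffleSampler(10, num_replicas=2, rank=r) for r in range(2)]
print(sorted(samplers[0].indices() + samplers[1].indices()) == list(range(10)))  # True
```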

I see, so it's just an implementation difference. Thank you!