akanimax / pro_gan_pytorch

Unofficial PyTorch implementation of the paper titled "Progressive Growing of GANs for Improved Quality, Stability, and Variation"


If you don't mind, please let me know what environment you are using.

H1R0Y4 opened this issue

Hi akanimax,

I ran your code in my environment (torch 1.4.0, torchvision 0.5.0) and it failed with the error shown below.
I would like to know what environment you are using; if you don't mind, could you share it in a requirements.txt or something similar?
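For reference, this is how I checked the versions on my side (nothing repo-specific, just the standard version attributes):

import torch as th
import torchvision as tv

# print the installed versions so environments can be compared
print(th.__version__)   # 1.4.0 in my environment
print(tv.__version__)   # 0.5.0 in my environment
print(th.version.cuda)  # CUDA toolkit version this torch build targets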

import torch as th
import torchvision as tv
import pro_gan_pytorch.PRO_GAN as pg

# select the device to be used for training
device = th.device("cuda" if th.cuda.is_available() else "cpu")
data_path = "cifar-10/"

def setup_data(download=False):
    """
    setup the CIFAR-10 dataset for training the CNN
    :param batch_size: batch_size for sgd
    :param num_workers: num_readers for data reading
    :param download: Boolean for whether to download the data
    :return: classes, trainloader, testloader => training and testing data loaders
    """
    # data setup:
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    transforms = tv.transforms.ToTensor()

    trainset = tv.datasets.CIFAR10(root=data_path,
                                   transform=transforms,
                                   download=download)

    testset = tv.datasets.CIFAR10(root=data_path,
                                  transform=transforms, train=False,
                                  download=False)

    return classes, trainset, testset


if __name__ == '__main__':

    # some parameters:
    depth = 4
    # hyper-parameters per depth (resolution)
    num_epochs = [10, 20, 20, 20]
    fade_ins = [50, 50, 50, 50]
    batch_sizes = [128, 128, 128, 128]
    latent_size = 128

    # get the data. Ignore the test data and their classes
    _, dataset, _ = setup_data(download=True)

    # ======================================================================
    # This line creates the PRO-GAN
    # ======================================================================
    pro_gan = pg.ConditionalProGAN(num_classes=10, depth=depth, 
                                   latent_size=latent_size, device=device)
    # ======================================================================

    # ======================================================================
    # This line trains the PRO-GAN
    # ======================================================================
    pro_gan.train(
        dataset=dataset,
        epochs=num_epochs,
        fade_in_percentage=fade_ins,
        batch_sizes=batch_sizes
    )
    # ======================================================================

Output:

Files already downloaded and verified
Starting the training process ...


Currently working on Depth:  0
Current resolution: 4 x 4

Epoch: 1
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-c94a729137e7> in <module>
     59         epochs=num_epochs,
     60         fade_in_percentage=fade_ins,
---> 61         batch_sizes=batch_sizes
     62     )
     63     # ======================================================================

~/pro_gan_pytorch/pro_gan_pytorch/PRO_GAN.py in train(self, dataset, epochs, batch_sizes, fade_in_percentage, start_depth, num_workers, feedback_factor, log_dir, sample_dir, save_dir, checkpoint_factor)
   1044                     # optimize the discriminator:
   1045                     dis_loss = self.optimize_discriminator(gan_input, images,
-> 1046                                                            labels, current_depth, alpha)
   1047 
   1048                     # optimize the generator:

~/pro_gan_pytorch/pro_gan_pytorch/PRO_GAN.py in optimize_discriminator(self, noise, real_batch, labels, depth, alpha)
    863 
    864             loss = self.loss.dis_loss(real_samples, fake_samples,
--> 865                                       labels, depth, alpha)
    866 
    867             # optimize discriminator

~/pro_gan_pytorch/pro_gan_pytorch/Losses.py in dis_loss(self, real_samps, fake_samps, labels, height, alpha)
    343     def dis_loss(self, real_samps, fake_samps, labels, height, alpha):
    344         # define the (Wasserstein) loss
--> 345         fake_out = self.dis(fake_samps, labels, height, alpha)
    346         real_out = self.dis(real_samps, labels, height, alpha)
    347 

~/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

~/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
    150             return self.module(*inputs[0], **kwargs[0])
    151         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 152         outputs = self.parallel_apply(replicas, inputs, kwargs)
    153         return self.gather(outputs, self.output_device)
    154 

~/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
    160 
    161     def parallel_apply(self, replicas, inputs, kwargs):
--> 162         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
    163 
    164     def gather(self, outputs, output_device):

~/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
     83         output = results[i]
     84         if isinstance(output, ExceptionWrapper):
---> 85             output.reraise()
     86         outputs.append(output)
     87     return outputs

~/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
    392             # (https://bugs.python.org/issue2651), so we work around it.
    393             msg = KeyErrorMessage(msg)
--> 394         raise self.exc_type(msg)

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/USER/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/USER/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/USER/pro_gan_pytorch/pro_gan_pytorch/PRO_GAN.py", line 305, in forward
    out = self.final_block(y, labels)
  File "/home/USER/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/USER/pro_gan_pytorch/pro_gan_pytorch/CustomLayers.py", line 445, in forward
    labels = self.label_embedder(labels)  # [B x C]
  File "/home/USER/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/USER/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 114, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/home/USER/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/functional.py", line 1484, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: diff_view_meta->output_nr_ == 0 INTERNAL ASSERT FAILED at /pytorch/torch/csrc/autograd/variable.cpp:326, please report a bug to PyTorch. 

I then tried running the code with device set to the CPU and it worked fine, so I don't think it's a problem with your code; it looks like a problem on the DataParallel side. If you know of a good workaround, I would appreciate it if you could share it.
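For what it's worth, the only workaround I can think of is the sketch below. It assumes (unverified) that the assert is triggered by DataParallel replicating the label embedding across multiple GPUs, so pinning the process to a single GPU avoids the scatter/replicate path entirely:

# Untested sketch: restrict CUDA to one device before torch initializes,
# so nn.DataParallel effectively runs with a single replica. This assumes
# the INTERNAL ASSERT only fires during multi-GPU replication.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # must run before importing torch

import torch as th
import pro_gan_pytorch.PRO_GAN as pg

device = th.device("cuda" if th.cuda.is_available() else "cpu")
# ... build and train the ConditionalProGAN exactly as in the script above ...

If this is the known PyTorch 1.4 interaction between nn.Embedding and DataParallel, upgrading torch past 1.4.0 might also resolve it, but I have only confirmed that the CPU path works.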