This is a PyTorch implementation of Wasserstein GAN (WGAN), a generative adversarial network (GAN) variant that uses the Wasserstein distance instead of the Jensen-Shannon divergence to measure the similarity between the generated and real data distributions. The code generates new images that resemble the images in a given dataset.
- Python 3.x
- PyTorch 1.7 or higher
- torchvision
- numpy
- matplotlib
The `test()` function in the code can be used to check that the Discriminator and Generator models produce outputs of the correct shapes. To generate new images with WGAN, first download the dataset. This implementation uses the CelebA dataset, which can be downloaded from here.
After downloading and extracting the dataset, update the `image_path` variable in the code with the path to the directory containing the image files.
Next, adjust the hyperparameters (learning rate, batch size, number of epochs, and so on) to the desired values.
Run the code. The Generator model will generate new images during training, and the Discriminator model will evaluate how similar the generated images are to the real images. The generated images will be saved in the `samples` directory.
- `test()`: tests the Discriminator and Generator models by passing a random batch of images and a random batch of noise through them and checking that the outputs have the correct shapes.
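A minimal sketch of the shape check that `test()` performs. The DCGAN-style architectures and layer widths below are assumptions for illustration, not the exact models from the code:

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Downsamples a 64x64 image to a single critic score."""
    def __init__(self, channels_img, features_d):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels_img, features_d, 4, 2, 1),        # 64 -> 32
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d, features_d * 2, 4, 2, 1),      # 32 -> 16
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1),  # 16 -> 8
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d * 4, 1, 8, 1, 0),               # 8 -> 1
        )

    def forward(self, x):
        return self.net(x)

class Generator(nn.Module):
    """Upsamples a noise vector to a 64x64 image in [-1, 1]."""
    def __init__(self, noise_dim, channels_img, features_g):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(noise_dim, features_g * 4, 8, 1, 0),       # 1 -> 8
            nn.ReLU(),
            nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1),  # 8 -> 16
            nn.ReLU(),
            nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1),      # 16 -> 32
            nn.ReLU(),
            nn.ConvTranspose2d(features_g, channels_img, 4, 2, 1),        # 32 -> 64
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z)

def test():
    N, C, H, W, z_dim = 4, 3, 64, 64, 100
    disc = Discriminator(C, features_d=8)
    gen = Generator(z_dim, C, features_g=8)
    x = torch.randn(N, C, H, W)
    z = torch.randn(N, z_dim, 1, 1)
    assert disc(x).shape == (N, 1, 1, 1)
    assert gen(z).shape == (N, C, H, W)

test()
```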
- `device`: device to use for training, `"cuda"` if available, otherwise `"cpu"`
- `learning_rate`: learning rate for the optimizer
- `batch_size`: batch size for training
- `image_size`: size of the image
- `channels_img`: number of channels in the image
- `noise_dim`: dimension of the noise vector
- `num_epochs`: number of epochs to train for
- `features_disc`: number of features in the Discriminator
- `features_gen`: number of features in the Generator
- `critic_iterations`: number of times to train the Discriminator before each Generator update
- `weight_clip`: maximum absolute value of the weights in the Discriminator
- `image_path`: path to the directory containing the image files
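The hyperparameters above might be defined along these lines. The values shown are the WGAN paper's common defaults and are assumptions where the README does not pin a value; `image_path` is a hypothetical placeholder:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
learning_rate = 5e-5        # WGAN typically uses a small RMSprop learning rate
batch_size = 64
image_size = 64
channels_img = 3
noise_dim = 100
num_epochs = 5
features_disc = 64
features_gen = 64
critic_iterations = 5       # critic updates per generator update
weight_clip = 0.01          # clamp critic weights to [-0.01, 0.01]
image_path = "data/celeba"  # hypothetical path; point this at your dataset
```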
- `loader`: dataloader for the training set
- `generator`: the Generator model
- `discriminator`: the Discriminator model
- `initialize_weights`: initializes the weights of the models
- `optimizer_generator`: optimizer for the Generator model
- `optimizer_discriminator`: optimizer for the Discriminator model
- `reset_grad()`: helper function to reset the gradients of the optimizers
- `train_discriminator(images)`: trains the Discriminator model on a batch of real and fake images
- `train_generator()`: trains the Generator model using the output of the Discriminator
- `sample_vectors`: random noise vectors for generating new images
- `denorm(x)`: helper function to denormalize an image tensor
- `save_fake_images(index)`: saves a grid of generated images to the `samples` directory every 500 iterations
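One round of WGAN training, as performed by `train_discriminator` and `train_generator`, can be sketched as below: several critic updates with weight clipping, then one generator update. Tiny linear models stand in for the real networks so the sketch runs on its own; the loss terms follow the WGAN formulation, but the exact training functions in the code may differ in detail.

```python
import torch
import torch.nn as nn

noise_dim, critic_iterations, weight_clip = 8, 5, 0.01
critic = nn.Linear(16, 1)           # stand-in for the Discriminator (critic)
generator = nn.Linear(noise_dim, 16)  # stand-in for the Generator
opt_c = torch.optim.RMSprop(critic.parameters(), lr=5e-5)
opt_g = torch.optim.RMSprop(generator.parameters(), lr=5e-5)

real = torch.randn(32, 16)  # stand-in batch of real samples

for _ in range(critic_iterations):
    fake = generator(torch.randn(32, noise_dim))
    # Critic maximizes E[D(real)] - E[D(fake)], i.e. minimizes its negation.
    loss_critic = -(critic(real).mean() - critic(fake.detach()).mean())
    opt_c.zero_grad()
    loss_critic.backward()
    opt_c.step()
    # Weight clipping enforces the Lipschitz constraint from the WGAN paper.
    for p in critic.parameters():
        p.data.clamp_(-weight_clip, weight_clip)

# Generator maximizes E[D(fake)], i.e. minimizes -E[D(fake)].
loss_gen = -critic(generator(torch.randn(32, noise_dim))).mean()
opt_g.zero_grad()
loss_gen.backward()
opt_g.step()
```

Note the critic is trained `critic_iterations` times per generator step, and that `clamp_` keeps every critic weight inside `[-weight_clip, weight_clip]`.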
WGAN is a powerful technique for generating new images that resemble a given dataset. The Wasserstein distance helps to overcome some of the issues associated with the original GAN loss, such as mode collapse and vanishing gradients. By adjusting the hyperparameters, this code can be adapted to generate images from other image datasets.