akanimax / pro_gan_pytorch

Unofficial PyTorch implementation of the paper titled "Progressive Growing of GANs for Improved Quality, Stability, and Variation"

Samples not showing any progress during training for 1024x1024 model

tomasheiskanen opened this issue

I tried to train ProGAN on a multi-GPU instance, but the generated sample images seemed to be copies of each other.

On the left is the first sample at the first resolution level; on the right is the last epoch (27) of that level. The pixels seem exactly the same.
[image]
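A quick way to confirm the grids really are pixel-identical is to diff two of the saved sample files with numpy (the file names below are just placeholders for whatever ends up in the sample directory):

import numpy as np
from PIL import Image

# Compare two saved sample grids pixel-by-pixel (placeholder file names)
first = np.asarray(Image.open("samples/1/gen_img_epoch_1.png"))
last = np.asarray(Image.open("samples/1/gen_img_epoch_27.png"))
print("identical:", np.array_equal(first, last))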

The same pattern continues at higher resolutions, with some color variation that is probably due to the fade-ins.
[image]
[image]

I was using pytorch=0.4.1 with cuda90 and the pro_gan_pytorch-examples/implementation/train_network.py training script with the following config:

# Hyperparameters for the Model
img_dims:
  - 1024
  - 1024

# Pro GAN hyperparameters
use_eql: True
depth: 9
latent_size: 512
learning_rate: 0.003
beta_1: 0
beta_2: 0.99
eps: 0.00000001
drift: 0.001
n_critic: 1
use_ema: True
ema_decay: 0.999

# Training hyperparameters:
epochs:
  - 27
  - 54
  - 54
  - 54
  - 54
  - 54
  - 54
  - 54
  - 300

# % of epochs for fading in the new layer
fade_in_percentage:
  - 50
  - 50
  - 50
  - 50
  - 50
  - 50
  - 50
  - 50
  - 50

batch_sizes:
  - 512
  - 512
  - 512
  - 512
  - 512
  - 256
  - 128
  - 64
  - 32

loss_function: "wgan-gp"  # loss function to be used

num_samples: 16
num_workers: 90
feedback_factor: 5   # number of logs generated per epoch
checkpoint_factor: 10  # save the models after these many epochs

Here's part of the loss log as well. Could it be that it is training, but the model is not being copied to the CPU when generating samples?

0	7.445143699645996	-0.666088879108429
1	3.5027077198028564	0.26841363310813904
3	0.03858590126037598	1.2750496864318848
5	-1.2547638416290283	2.315767526626587
7	-1.9465951919555664	4.041855812072754
9	-2.7310028076171875	4.791447162628174
0	-3.052764892578125	5.088418960571289
1	-3.4719738960266113	5.497372150421143
3	-3.640378475189209	5.858086585998535
5	-3.8943538665771484	6.33864688873291
7	-3.9903652667999268	6.605764865875244
9	-3.8730273246765137	6.578725337982178
0	-3.7433838844299316	6.438271999359131
1	-3.77139949798584	6.152718544006348
3	-3.8825106620788574	5.874748706817627
5	-3.9730663299560547	5.793916702270508
7	-3.8050220012664795	5.687839031219482
9	-3.607787847518921	5.536232948303223
0	-3.7319984436035156	5.574792861938477
1	-3.6925082206726074	5.549401760101318
3	-3.807774782180786	5.701358795166016
5	-3.779202938079834	5.610644340515137
7	-3.6574831008911133	5.056485652923584
9	-3.5362486839294434	4.7012176513671875
0	-3.62532377243042	4.540830612182617
1	-3.7182655334472656	4.524333953857422
3	-3.5735902786254883	4.3521728515625
5	-3.5028626918792725	4.146020889282227
7	-3.413512706756592	4.014176845550537
9	-3.5509660243988037	3.9058191776275635
0	-3.393911838531494	3.8175041675567627
1	-3.3943676948547363	3.8033339977264404
3	-3.2545809745788574	3.689268112182617
5	-2.7422475814819336	3.8836066722869873
7	-2.5299720764160156	3.901505708694458
9	-3.1829843521118164	3.4741673469543457
0	-3.391071319580078	3.5319602489471436
1	-3.2669315338134766	3.4916226863861084
3	-3.2526116371154785	3.32706880569458
5	-3.003854513168335	3.124220132827759
7	-2.970198392868042	3.023137331008911
9	-3.0074851512908936	2.84783673286438
0	-2.8317761421203613	2.8252968788146973
1	-2.8186793327331543	2.7927169799804688
3	-2.913585662841797	2.7707386016845703
5	-2.9216248989105225	2.7827188968658447
7	-3.1883113384246826	2.8710832595825195
9	-3.051466941833496	2.7973804473876953
0	-3.005812644958496	2.6593427658081055
1	-2.852177381515503	2.3979153633117676
3	-2.6666369438171387	2.202028751373291
5	-2.686516761779785	2.1320979595184326
7	-2.59663987159729	1.9422683715820312
9	-2.5013532638549805	1.8353694677352905
0	-2.563154697418213	1.8157066106796265
1	-2.614731788635254	1.8588039875030518
3	-2.5844807624816895	1.7903416156768799
5	-2.60732364654541	1.682680606842041
7	-2.332916259765625	1.6604965925216675
9	-2.452439785003662	1.5584651231765747
0	-2.3747992515563965	1.5383646488189697
1	-2.4514126777648926	1.5030796527862549
3	-2.3756911754608154	1.4402575492858887
5	-2.3623554706573486	1.4759807586669922
7	-2.492440700531006	1.4793274402618408
9	-2.495957612991333	1.677773118019104
0	-2.670729637145996	1.825972080230713
1	-2.647500991821289	1.6406993865966797
3	-2.151350498199463	1.2234678268432617
5	-2.158658027648926	1.0205659866333008
7	-2.1201534271240234	0.9499409794807434
9	-2.060696601867676	0.9342047572135925
0	-2.1514177322387695	0.8344566822052002
1	-2.0865421295166016	0.8254259824752808
3	-2.012254238128662	0.8296260237693787
5	-2.0154666900634766	0.8031508326530457
7	-2.1361961364746094	0.9131251573562622
9	-2.169346809387207	1.1396573781967163
0	-2.3894355297088623	1.3240206241607666
1	-2.3853883743286133	1.415444254875183
3	-2.2191309928894043	0.9893991947174072
5	-1.858288288116455	0.5515211820602417
7	-1.9216136932373047	0.4365634024143219
9	-1.8308132886886597	0.39891186356544495
0	-1.952486276626587	0.369014710187912
1	-1.808902382850647	0.37214067578315735
3	-1.8346182107925415	0.37714967131614685
5	-1.8739262819290161	0.4017769396305084
7	-1.8889203071594238	0.47360825538635254
9	-1.9162927865982056	0.5228189826011658
0	-1.9699218273162842	0.6498271822929382
1	-1.9654009342193604	0.548102855682373
3	-1.960349440574646	0.44029688835144043
5	-1.741163730621338	0.3177216351032257
7	-1.7001712322235107	0.17174968123435974
9	-1.829635739326477	0.15137243270874023
0	-1.8278175592422485	0.31675127148628235
1	-1.6870824098587036	0.2879297733306885
3	-1.7959985733032227	0.2949194312095642
5	-1.8141270875930786	0.3123873472213745
7	-1.7969791889190674	0.22187906503677368
9	-1.6505796909332275	0.06910364329814911
0	-1.7593704462051392	0.10136839747428894
1	-1.6619659662246704	0.03943883627653122
3	-1.6369332075119019	-0.062033940106630325
5	-1.6174414157867432	0.08158985525369644
7	-1.5932389497756958	0.06622900068759918
9	-1.6893978118896484	0.11341078579425812
0	-1.6941074132919312	0.09012114256620407
1	-1.5865182876586914	0.007851353846490383
3	-1.6136672496795654	-0.009824557229876518
5	-1.5824973583221436	-0.06083100289106369
7	-1.6119481325149536	0.01717076078057289
9	-1.5229578018188477	-0.10854979604482651
0	-1.5676738023757935	-0.08700834214687347
1	-1.553999423980713	-0.11756245791912079
3	-1.4950313568115234	-0.1988004595041275
5	-1.466239333152771	-0.18715940415859222
7	-1.421967625617981	-0.16267475485801697
9	-1.4385203123092651	-0.20462502539157867
0	-1.4796195030212402	-0.1997477114200592
1	-1.461942195892334	-0.22641906142234802
3	-1.433573603630066	-0.28313198685646057
5	-1.459473729133606	-0.1353251188993454
7	-1.441795825958252	-0.2278439700603485
9	-1.3938238620758057	-0.3460708260536194
0	-1.4129831790924072	-0.3696824014186859
1	-1.3446242809295654	-0.3944481909275055
3	-1.421055555343628	-0.34057164192199707
5	-1.3732415437698364	-0.3205330967903137
7	-1.3133249282836914	-0.3759555518627167
9	-1.321431040763855	-0.4381392002105713
0	-1.420142650604248	-0.33854085206985474
1	-1.379390001296997	-0.31235355138778687
3	-1.3552571535110474	-0.27970457077026367
5	-1.2896485328674316	-0.5287396311759949
7	-1.2823657989501953	-0.46846550703048706
9	-1.2347792387008667	-0.5471756458282471
0	-1.295735239982605	-0.44667479395866394
1	-1.3137845993041992	-0.4586533308029175
3	-1.2428719997406006	-0.3409656584262848
5	-1.330362319946289	-0.35211002826690674
7	-1.283571481704712	-0.45962458848953247
9	-1.2806823253631592	-0.4117633104324341
0	-1.2148696184158325	-0.45754459500312805
1	-1.2371199131011963	-0.4458094537258148
3	-1.3305516242980957	-0.43113473057746887
5	-1.2759795188903809	-0.4516662657260895
7	-1.1464309692382812	-0.5135030150413513
9	-1.2015198469161987	-0.41749152541160583
0	-1.2296406030654907	-0.4396921992301941
1	-1.1728832721710205	-0.5512722730636597
3	-1.2005518674850464	-0.42940470576286316
5	-1.1692029237747192	-0.41121476888656616
7	-1.1729202270507812	-0.4978601038455963
9	-1.1450990438461304	-0.5678228735923767
0	-1.1326179504394531	-0.5848303437232971
1	-1.1378253698349	-0.5951849222183228
3	-1.149521827697754	-0.6170066595077515
5	-1.1457865238189697	-0.5778022408485413
7	-1.1343293190002441	-0.5437898635864258
9	-1.0890092849731445	-0.5257707834243774

@tomasheiskanen,

Thanks for the detailed description of the problem.
Could you please address the following few questions?

1.) How many GPUs were you running this on? There is a known bug for multi-GPU training with wgan-gp -> another_issue. Could you maybe try a different loss? relativistic-hinge works a lot better for me. (A quick way to check your setup for 1.) and 3.) is in the snippet after this list.)
2.) How many images do you have in your dataset?
3.) Again, the code is for Python==3.5.6 and PyTorch==1.0.0; could you try with these, maybe?
4.) I haven't seen a problem like this so far. @panovr, @minxdragon, you guys have trained your models recently, right? Did you encounter a similar problem?
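For 1.) and 3.), these standard PyTorch calls report the relevant details (nothing repo-specific):

import torch
print(torch.__version__)          # installed PyTorch version
print(torch.version.cuda)         # CUDA version PyTorch was built against
print(torch.cuda.device_count())  # number of GPUs visible to PyTorch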

Thanks.

Hope this helps!

cheers 🍻!
@akanimax

Thanks for your help @akanimax

It seems to work on 1 GPU but not on 8 GPUs. I start to see progress quite quickly with 1 GPU, but not with 8 GPUs.
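One way to force a single-GPU run without touching the training code is to restrict the visible devices before CUDA is initialized, e.g. at the top of the script:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # must be set before the first CUDA call

import torch as th
print(th.cuda.device_count())  # should now report 1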

My dataset is about 5000 1024x1024 images.

I changed to relativistic-hinge and installed the requirements with:

conda create -n progan_pytorch python=3.5.6 -y
source activate progan_pytorch
conda install pytorch=1.0.0 torchvision cuda100 cudatoolkit=10.0 -c pytorch -y
pip install pyyaml easydict

New training script:

import torch as th
import torchvision as tv
import pro_gan_pytorch.PRO_GAN as pg

# ImageFolder that drops the class label and returns only the image tensor
class NoClassImageFolder(tv.datasets.ImageFolder):
    def __init__(self, *args, **kwargs):
        super(NoClassImageFolder, self).__init__(*args, **kwargs)

    def __getitem__(self, index):
        return super(NoClassImageFolder, self).__getitem__(index)[0]

beta_1 = 0
beta_2 = 0.99
eps = 1e-8
drift = 0.001
n_critic = 1
use_eql = True
use_ema = True
ema_decay = 0.999
num_samples = 16
start_depth = 0

results_path = '../results'
data_path = '../dataset'

epochs = [27*2]+[54*2]*7+[300*2]

# 8 GPU (32GB)
batch_sizes = [512]*5+[256,128,64,32]
# 1 GPU (16GB)
# batch_sizes = [256,128,64,32,32,16,8,4,2]

fade_in_percentage = [50]*9

learning_rate = 0.003
depth = 9
latent_size = 512
feedback_factor = 3
checkpoint_factor = 1
num_workers = 16
loss = "relativistic-hinge"

log_dir = results_path+"/models/"
sample_dir = results_path+"/samples/"
save_dir = results_path+"/models/"

device=th.device("cuda")
th.backends.cudnn.benchmark = True

transforms = tv.transforms.ToTensor()
dataset = NoClassImageFolder(root=data_path,transform=transforms)

pro_gan = pg.ProGAN(depth=depth, latent_size=latent_size, learning_rate=learning_rate, beta_1=beta_1,
                    beta_2=beta_2, eps=eps, drift=drift, n_critic=n_critic, use_eql=use_eql,
                    loss=loss, use_ema=use_ema, ema_decay=ema_decay,
                    device=device)

pro_gan.train(dataset=dataset, epochs=epochs, batch_sizes=batch_sizes,
              fade_in_percentage=fade_in_percentage, num_samples=num_samples,
              start_depth=start_depth, num_workers=num_workers, feedback_factor=feedback_factor,
              log_dir=log_dir, sample_dir=sample_dir, save_dir=save_dir,
              checkpoint_factor=checkpoint_factor)

8 GPUs:
[image]

1 GPU:
[image]

@tomasheiskanen,

Thanks a lot for narrowing down the problem.
I'll look into this.
One suggestion: I find it works best when you keep the number of epochs equal across all resolutions and set the fade-in percentage to 50%. This seems to give the best results.
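In terms of the script above, that schedule would look something like this (54 is just an example value):

epochs = [54] * 9               # same number of epochs at every resolution
fade_in_percentage = [50] * 9   # fade the new layer in over the first half of each stage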
You could continue running on a single-GPU instance for now; I will fix the multi-GPU problem soon 👍.

Thanks!

cheers 🍻!
@akanimax

@akanimax

OK, I will keep that in mind. Thanks for looking into this.

How does the distribution across GPUs work? Do you divide the batches or the network across the GPUs?

Yup. I use the DataParallel feature from PyTorch. Looks like something is wrong with it.
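Roughly: DataParallel keeps a full replica of the model on each visible GPU and splits every input batch along dimension 0, so it is the batches, not the network, that get divided. A minimal illustration with plain PyTorch (nothing repo-specific):

import torch as th

# DataParallel replicates the module on every visible GPU, scatters the input
# batch across them, and gathers the outputs back on the default device.
net = th.nn.DataParallel(th.nn.Linear(512, 256)).cuda()

x = th.randn(64, 512).cuda()  # a batch of 64 gets split into per-GPU chunks
y = net(x)                    # forward pass runs on all replicas in parallel
print(y.shape)                # torch.Size([64, 256]), gathered on cuda:0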