rolux / stylegan2encoder

StyleGAN2 - Official TensorFlow Implementation

Home Page: http://arxiv.org/abs/1912.04958

Improving initialization

SimJeg opened this issue · comments

Dear @rolux,

Many thanks for porting the work of @Puzer to StyleGAN2. I noticed that the optimization sometimes fails due to a bad initialization of the dlatent variable W, so I tried finetuning two ResNets for a better initialization.

  • The first one is trained to predict W[:, 0] of shape 512 from the image X generated by W. The ResNet is initialized with ImageNet weights.
  • The second one is trained to predict W of shape (18, 512). We use style mixing for W to cover a wider distribution of images. The ResNet is initialized with the weights of the first ResNet.

The code is quick & dirty but functional. This initialization fixes some failure cases and speeds up convergence. Maybe by digging further in this direction, it would be possible to skip the optimization entirely, as was done in the neural style transfer field? Below is an example where all three initializations work well:

Zero initialization (current behavior) :
zero

ResNet 1 initialization :
resnet1

ResNet 2 initialization :
resnet2

import os
import numpy as np
import cv2

from keras.applications.imagenet_utils import preprocess_input
from keras.layers import Dense, Reshape
from keras.models import Sequential, Model, load_model
from keras.applications.resnet50 import ResNet50
from keras.optimizers import Adam

import pretrained_networks
import dnnlib.tflib as tflib


def get_batch(batch_size, Gs, image_size=224, Gs_minibatch_size=12, w_mix=None):
    """
    Generate a training batch of size batch_size.
    Returns a tuple (W, X) with W.shape = [batch_size, 18, 512] and X.shape = [batch_size, image_size, image_size, 3].
    If w_mix is not None, W = w_mix * W0 + (1 - w_mix) * W1 with
        - W0 generated from Z0 such that W0[:,i] = constant
        - W1 generated from Z1 such that W1[:,i] != constant

    Parameters
    ----------
    batch_size : int
        batch size
    Gs
        StyleGan2 generator
    image_size : int
    Gs_minibatch_size : int
        batch size for the generator
    w_mix : float

    Returns
    -------
    tuple
        dlatent W, images X
    """

    # Generate W0 from Z0
    Z0 = np.random.randn(batch_size, Gs.input_shape[1])
    W0 = Gs.components.mapping.run(Z0, None, minibatch_size=Gs_minibatch_size)

    if w_mix is None:
        W = W0
    else:
        # Generate W1 from Z1
        Z1 = np.random.randn(18 * batch_size, Gs.input_shape[1])
        W1 = Gs.components.mapping.run(Z1, None, minibatch_size=Gs_minibatch_size)
        W1 = np.array([W1[batch_size * i:batch_size * (i + 1), i] for i in range(18)]).transpose((1, 0, 2))

        # Mix styles between W0 and W1
        W = w_mix * W0 + (1 - w_mix) * W1

    # Generate X
    X = Gs.components.synthesis.run(W, randomize_noise=True, minibatch_size=Gs_minibatch_size, print_progress=True,
                                    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))

    # Preprocess images X for the Imagenet model
    X = np.array([cv2.resize(x, (image_size, image_size)) for x in X])
    X = preprocess_input(X.astype('float'))

    return W, X

def finetune(save_path, image_size=224, base_model=ResNet50, batch_size=2048, test_size=1024, n_epochs=6,
             max_patience=5):
    """
    Finetunes a ResNet50 to predict W[:, 0]

    Parameters
    ----------
    save_path : str
        path where to save the ResNet
    image_size : int
    base_model : keras model
    batch_size :  int
    test_size : int
    n_epochs : int
    max_patience : int

    Returns
    -------
    None

    """

    assert image_size >= 224

    # Load StyleGan generator
    _, _, Gs = pretrained_networks.load_networks('data/stylegan2-ffhq-config-f.pkl')

    # Build model
    if os.path.exists(save_path):
        print('Loading pretrained network')
        model = load_model(save_path, compile=False)
    else:
        base = base_model(include_top=False, pooling='avg', input_shape=(image_size, image_size, 3))
        model = Sequential()
        model.add(base)
        model.add(Dense(512))

    model.compile(loss='mse', metrics=[], optimizer=Adam(3e-4))
    model.summary()

    # Create a test set
    print('Creating test set')
    W_test, X_test = get_batch(test_size, Gs)

    # Iterate on batches of size batch_size
    print('Training model')
    patience = 0
    best_loss = np.inf

    while patience <= max_patience:
        W_train, X_train = get_batch(batch_size, Gs)
        model.fit(X_train, W_train[:, 0], epochs=n_epochs, verbose=True)
        loss = model.evaluate(X_test, W_test[:, 0])
        if loss < best_loss:
            print(f'New best test loss : {loss:.5f}')
            model.save(save_path)
            patience = 0
            best_loss = loss
        else:
            print(f'-------- test loss : {loss:.5f}')
            patience += 1


def finetune_18(save_path, base_model=None, image_size=224, batch_size=2048, test_size=1024, n_epochs=6,
                max_patience=8, w_mix=0.7):
    """
    Finetunes a ResNet50 to predict W[:, :]

    Parameters
    ----------
    save_path : str
        path where to save the ResNet
    image_size : int
    base_model : str
        path to the first finetuned ResNet50
    batch_size :  int
    test_size : int
    n_epochs : int
    max_patience : int
    w_mix : float

    Returns
    -------
    None

    """

    assert image_size >= 224
    if not os.path.exists(save_path):
        assert base_model is not None

    # Load StyleGan generator
    _, _, Gs = pretrained_networks.load_networks('data/stylegan2-ffhq-config-f.pkl')

    # Build model
    if os.path.exists(save_path):
        print('Loading pretrained network')
        model = load_model(save_path, compile=False)
    else:
        base_model = load_model(base_model)
        hidden = Dense(18 * 512)(base_model.layers[-1].input)
        outputs = Reshape((18, 512))(hidden)
        model = Model(base_model.input, outputs)
        # Initialize the new Dense(18 * 512) layer by tiling the weights of the first ResNet's Dense(512) head
        W, b = base_model.layers[-1].get_weights()
        model.layers[-2].set_weights([np.hstack([W] * 18), np.hstack([b] * 18)])

    model.compile(loss='mse', metrics=[], optimizer=Adam(1e-4))
    model.summary()

    # Create a test set
    print('Creating test set')
    W_test, X_test = get_batch(test_size, Gs, w_mix=w_mix)

    # Iterate on batches of size batch_size
    print('Training model')
    patience = 0
    best_loss = np.inf

    while patience <= max_patience:
        W_train, X_train = get_batch(batch_size, Gs, w_mix=w_mix)
        model.fit(X_train, W_train, epochs=n_epochs, verbose=True)
        loss = model.evaluate(X_test, W_test)
        if loss < best_loss:
            print(f'New best test loss : {loss:.5f}')
            model.save(save_path)
            patience = 0
            best_loss = loss
        else:
            print(f'-------- test loss : {loss:.5f}')
            patience += 1

if __name__ == '__main__':
    finetune('data/resnet.h5')
    finetune_18('data/resnet_18.h5', 'data/resnet.h5', w_mix=0.8)
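
For actually using the result, here is a minimal inference sketch (the file names and the preprocessing are assumptions that mirror get_batch above): predict an initial W from an aligned portrait and start the encoder/projector from it instead of from zeros or w_avg.

import cv2
import numpy as np
from keras.models import load_model
from keras.applications.imagenet_utils import preprocess_input

def predict_dlatent(image_path, resnet_path='data/resnet_18.h5', image_size=224):
    """Predict an initial W of shape (18, 512) from an aligned portrait."""
    model = load_model(resnet_path, compile=False)
    img = cv2.imread(image_path)[..., ::-1]              # BGR -> RGB, as generated by Gs
    img = cv2.resize(img, (image_size, image_size))
    X = preprocess_input(img[np.newaxis].astype('float'))
    return model.predict(X)[0]                           # (18, 512)

# W_init = predict_dlatent('aligned/portrait_01.png')    # hypothetical path
# ...then use W_init as the starting dlatent of the optimization.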
commented

Yes, that would be a huge improvement. But isn't ResNet initialization included with pbaylies' fork of the encoder? That one is well-maintained, and I'm not trying to duplicate it. It's also not too hard to port to StyleGAN2.

commented

Other than that, if you pull the latest changes, you can use the projector, which may be a better choice anyway.

@SimJeg was the inspiration and provided the initial code for doing the ResNets in my repo in the first place; nice work!

commented

@SimJeg: Ok, I've looked at this more closely ;)

Three ResNet initializations below, from left to right: @pbaylies, StyleGAN V1; @Quasimondo, via twitter, using your (18, 512) ResNet above; myself, ditto, after 5 minutes of training.

Mona Lisa ResNet

The benefit of your approach is clearly visible.

I'm still wondering though: is ResNet initialization just a useful encoder optimization for faster convergence, or can it be demonstrated that it actually leads to better convergence than initializing with w_avg and running puzer's encoder or tkarras' projector? And if so, does that happen with specific classes of portraits, and is there a sweet spot at 2K or 3K iterations after which initialization doesn't matter?

My understanding, after reading the Image2StyleGAN paper, is that ~5K iterations are sufficient to encode anything into W-space with pretty high fidelity, with the possible exception of subtly translated faces (for example: a misaligned portrait looks slightly worse than a banana after 5K iterations). I'd be curious to see a failure case that can be fixed with better initialization.

But even if it's just about speed, it may be a good idea to save everyone a few cycles by adding an option to download a pretrained ResNet.

commented

@SimJeg: I have trained a ResNet, and will post some results shortly.

resnet_18_20191231.h5 (best test loss: 0.04438)

If you get a TypeError: Unexpected keyword argument passed to optimizer: learning_rate, you'll need to upgrade Keras from 2.2.* to 2.3.*, where lr was renamed to learning_rate :(
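
Alternatively, a minimal compatibility sketch (paths and values assumed): loading the .h5 with compile=False skips the saved optimizer config, which is where the learning_rate keyword mismatch comes from, and the model can then be recompiled with whichever keyword the installed Keras accepts.

from keras.models import load_model
from keras.optimizers import Adam

model = load_model('data/resnet_18_20191231.h5', compile=False)  # skip saved optimizer config

try:
    model.compile(loss='mse', optimizer=Adam(learning_rate=3e-4))  # Keras >= 2.3
except TypeError:
    model.compile(loss='mse', optimizer=Adam(lr=3e-4))             # Keras 2.2.x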

predictions_paintings

predictions_covers

Very nice work, @rolux !

@SimJeg in my experience, the loss decreases gradually but keeps improving over time. My network architecture might have been a bit different: I had some layers added on after the main ResNet, mainly to avoid having a huge dense layer. I got my best performance from using a ResNet and just training it for longer, but I liked having support for EfficientNet, just to have more potential options. I think there is still a lot that could be explored here, as far as different possible architectures and configurations go.
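
Roughly, such a head could look like the sketch below (the sizes are assumptions, not the architecture described above): a bottleneck between the pooled ResNet features and the 18 * 512 output shrinks the final dense layer.

from keras.applications.resnet50 import ResNet50
from keras.layers import Dense, Reshape
from keras.models import Sequential

# Sketch: 2048 -> 1024 -> 18 * 512 instead of a single 2048 -> 18 * 512 dense layer
base = ResNet50(include_top=False, pooling='avg', input_shape=(224, 224, 3))
model = Sequential()
model.add(base)
model.add(Dense(1024, activation='elu'))   # bottleneck to keep the head smaller
model.add(Dense(18 * 512))
model.add(Reshape((18, 512)))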

commented

It took me a while to appreciate the fact (thanks to @pbaylies for the insight) that encoder output can have high visual quality, but bad semantics.

The W(18, 512) projector, for example, explores W-space so well that its output, beyond a certain number of iterations, becomes meaningless. It is so far from w_avg that the usual applications -- interpolate with Z -> W(1, 512) dlatents, apply direction vectors obtained from Z -> W(1, 512) samples, etc. -- won't work as expected.

To illustrate this, I have run the projector on the 8 samples I had posted above, once for 1000 and once for 5000 iterations, and plotted the visual quality -- the projector's dist_value -- and the semantic quality -- np.sqrt(np.mean(np.square(dlatent - w_avg))) -- of the results.

For comparison: the mean semantic score of Z -> W(1, 512) dlatents is 0.44; a score of 2 is okay, 4 is bad.

plots_iterations
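
For reference, a minimal sketch of how the semantic score and its baseline could be computed (assuming a loaded generator Gs, as in the training code above):

import numpy as np

def semantic_score(dlatent, w_avg):
    # RMS distance between a (18, 512) dlatent and w_avg (512,), broadcast over layers
    return np.sqrt(np.mean(np.square(dlatent - w_avg)))

w_avg = Gs.get_var('dlatent_avg')

# Baseline: mean score of ordinary Z -> W(1, 512) mappings (~0.44 above)
Z = np.random.randn(1000, Gs.input_shape[1])
W = Gs.components.mapping.run(Z, None, minibatch_size=12)[:, 0]
baseline = np.mean([semantic_score(w, w_avg) for w in W])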

To keep the dlatent closer to w_avg, one can either clip it or introduce a penalty, at the expense of some visual quality. Both options are present in pbaylies' encoder, but I haven't instrumented it yet.
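
Neither variant is benchmarked here, but conceptually they look like the sketch below (the radius and weight are arbitrary placeholders, not values from the encoder):

import numpy as np

def clip_dlatent(dlatent, w_avg, max_dist=2.0):
    # Pull a (18, 512) dlatent back toward w_avg if it drifts beyond max_dist
    delta = dlatent - w_avg
    dist = np.sqrt(np.mean(np.square(delta)))
    if dist > max_dist:
        dlatent = w_avg + delta * (max_dist / dist)
    return dlatent

def penalized_loss(visual_loss, dlatent, w_avg, weight=0.1):
    # Add an L2 pull toward w_avg to whatever image loss the optimization uses
    return visual_loss + weight * np.mean(np.square(dlatent - w_avg))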

Now what about ResNet initialization? I have added it to the projector, and tested it on 100 samples. The results suggest that, in this particular setup, it doesn't make a considerable difference.

plots_initialization

Of course, this is by no means the last word on the matter. I could train the ResNet for longer, and/or play with w_mix. I could instrument puzer's encoder, add w_avg and ResNet initialization, and see how it compares. Also, my choice of metrics for visual and semantic quality may be misguided, just as my choice of samples (100 faces from FFHQ).

I am not sure how meaningful that semantic quality measure is - I guess it depends on what one is looking for in that latent space. If I am not mistaken, it just measures how "normal" (or, as I prefer to say, how "boring") the resulting image is. To me, the really interesting images are those that are as far away from the mean as possible without showing artifacts or breaking up.

So the way I would interpret your semantic quality graph is that with ResNet initialization you are able to reach hard-to-find images that lie outside the normal distribution more quickly, and at least with the ResNet encoder I trained for myself, I subjectively feel that this is the case - right now I only run 200 iterations, and with "simple" faces it often gets to a similar-enough state after just 50-70 iterations. Of course, with 200 iterations you will not get all the details, like occluding hands, hair strands or certain glasses, but at least for my purposes it's more important to get a lot of okay faces than just a few perfect ones.

@rolux thank you for investigating this! I find it surprising that the ResNet doesn't seem to get you any initial advantage in the visual quality metric. Is this ResNet outputting into W(1, 512) or W(18, 512) space? Also note that in my code for training a ResNet, I added a parameter for using the truncation trick in the generated training data, so you can control your initial semantic quality output, as it were.

@Quasimondo you're right about the normal or boringness, as it were; I think this metric captures how well the model can represent or understand an image based on what it learned from the training data distribution. This is more useful if you're working with interpolations; if the goal is the image itself, then visual quality would be the important metric.

commented

@pbaylies: I'm outputting into W(18, 512). It's SimJeg's code, with the w_mix arg passed in the last get_batch call. The initial advantage, averaged over 100 runs, is not exactly zero, but very small.

plot_initialization

@Quasimondo: From what I've seen, the visual distance drop you are getting from your ResNet seems to be more significant. Maybe you've made more changes than just that one fix? Trained it for longer? Or you're using the encoder and not the projector?

If your specific use case involves encoding a large number of images, then getting acceptable results after 200 iterations would be a huge improvement. On the other hand, if you just want a few faces with high accuracy, then initialization doesn't seem to matter.

With regards to semantics: I totally agree that among Z -> W(1, 512) mappings, the interesting faces are usually the ones further away from w_avg. It's just that you can push projections much further than that.

For example, take a look at this video, or the still below:

semantics

On the left is a Z -> W(1, 512) face, ψ=0.75, with a semantics score of 0.28. On the right is the same face projected into W(18, 512), it=5000, with a score of 3.36. They both transition along the same "surprise" vector. On the left, this looks gimmicky, but visually okay. On the right, you have to multiply the vector by 10 to achieve a comparable amount of change, which leads to obvious artifacts. As long as you obtain your vectors from annotated Z -> W(1, 512) samples, you're going to run into this problem.
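
To make that concrete, "apply a direction vector" means roughly the following (a sketch; the direction and coefficients are placeholders): the same (512,) vector, learned from annotated Z -> W(1, 512) samples, is added to every one of the 18 layers.

import numpy as np

def apply_direction(dlatent, direction, coeff):
    # dlatent: (18, 512); direction: (512,), added to every layer
    return dlatent + coeff * direction[np.newaxis, :]

# For a Z -> W(1, 512) face, a small coefficient is enough:
# edited = apply_direction(w_face, surprise_direction, 1.5)
# For the far-from-w_avg projection, roughly 10x as much is needed, which is
# where the artifacts come from:
# edited = apply_direction(w_projection, surprise_direction, 15.0)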

Should you just try to adjust your vectors more cleverly, or find better ones? My understanding is that this won't work, and that there is no outer W-space where you can smoothly interpolate between all the cool projections that are missing from the regular inner W-space mappings. (Simplified: Z is a unit vector, a point on a 512D-sphere. Ideally, young-old would be north-south pole, male-female east-west pole, smile-unsmile front-back pole, and so on. W(1, 512) is a learned deformation of that surface that accounts for the uneven distribution of features in FFHQ. W(18, 512) is a synthesizer option that allows for style mixing and can be abused for projection. But all the semantics of StyleGAN reside on W(1, 512). W(18, 512) vectors filled with 18 Z -> W(1, 512) mappings already belong to a different species. High-quality projections are paintings of faces.)

Should you use the encoder "just for the image"? As far as I can see, nothing keeps you from projecting arbitrary video into StyleGAN. If that happens to be a well-aligned portrait shot, are you the first person who can make a StyleGAN face sneeze or stick out her tongue? Or just the inventor of an extremely energy inefficient video codec?

@rolux Yes, I did some more changes to the training method and also trained for longer. The major change is probably that I also mix W's resulting from previous descents into the training set, because those are extremely unlikely to be returned by just random initialization (even with style mixing) - and before you ask: no, the W's of my examples were not part of the training. The other changes are just using a bigger training set and training for more epochs.
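
Roughly, that modification could look like the sketch below (not the actual training code; get_batch is the function posted above, and projected_W is assumed to be an array of previously recovered (18, 512) dlatents):

import cv2
import numpy as np
from keras.applications.imagenet_utils import preprocess_input
import dnnlib.tflib as tflib

def get_batch_with_projections(batch_size, Gs, projected_W, w_mix=0.7,
                               p_projected=0.25, image_size=224):
    # Start from a regular synthetic batch, then substitute a fraction of it
    # with dlatents recovered by earlier projection runs
    W, X = get_batch(batch_size, Gs, w_mix=w_mix)
    n = int(batch_size * p_projected)
    if n and len(projected_W):
        W[:n] = projected_W[np.random.choice(len(projected_W), n)]
        # Re-render only the replaced dlatents so X stays consistent with W
        X_new = Gs.components.synthesis.run(
            W[:n], randomize_noise=True, minibatch_size=12,
            output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
        X[:n] = preprocess_input(np.array(
            [cv2.resize(x, (image_size, image_size)) for x in X_new]).astype('float'))
    return W, X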

I like your point that high-quality projections are just paintings of faces - I have not analyzed which layers of W are mostly responsible for "rare" details, but I suspect that the heavy lifting is all done by the style layers.

commented

@Quasimondo: To get a better sense of which layer does what, I used to render these style grids:

https://youtu.be/hx51TqJ_adE

Top row: style target, midpoint, style source, 0-3 mix (coarse styles, the "pose"), 4-7 mix (middle styles, the "person"), 8-17 mix (fine styles, the "style"). Below: single-layer mixes from 0 to 17.

Check 2 for hair, 4 for shape and smile, 6 for gender and eyes, 8 for light, 10 for lipstick, etc.
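
In code, those grids boil down to something like this (a sketch; w_target and w_source are two (18, 512) dlatents, e.g. from Gs.components.mapping):

import numpy as np

def style_mix(w_target, w_source, layers):
    # Copy the given layer indices from the style source into the style target
    w = w_target.copy()
    w[list(layers)] = w_source[list(layers)]
    return w

pose   = style_mix(w_target, w_source, range(0, 4))    # 0-3: coarse styles, the "pose"
person = style_mix(w_target, w_source, range(4, 8))    # 4-7: middle styles, the "person"
style  = style_mix(w_target, w_source, range(8, 18))   # 8-17: fine styles, the "style"
single = [style_mix(w_target, w_source, [i]) for i in range(18)]   # the single-layer mixes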

Maybe that also helps to visualize when the semantics of a projection get bad. Top: 1000 iterations, bottom: 5000 iterations. In the bottom grid, 9 and 11 start to look unhealthy.

style_grid_wavg_monalisa1000

style_grid_wavg_monalisa5000

commented

@Quasimondo: The changes you made to the ResNet training process sound interesting, I'll try to find out how much of that I can reproduce.

I'm still a bit reluctant to add it to the repo because it seems like a step down a slippery slope: if I get this to work, I would want to try out EffNet initialization for comparison, then clipping vs. penalty to keep dlatents closer to w_avg, and so on. I would probably end up with a buggy, untested and less well-maintained implementation of half of pbaylies' encoder. So maybe, if something comes out of this, I'd rather submit it as a pull request for that one. We'll see...

@rolux that's basically the same slippery slope I went down with Puzer's encoder in the first place; go ahead if you like, it could use a rewrite by now, or at least a solid refactoring + ablation test. Note that I do have code already in there for training ResNets and EfficientNets (surely buggy / out of date by now given that the library I had targeted had just been released), and there is code for using the mapping network to generate mixed latents for training in (18, 512) space.

commented

This discussion was an interesting read, which kind of answers my question about the "tiled" projector (#21).

Given the semantic issue with W(18,*) compared to W(1,*), I think it would make sense to accept this commit (kreativai@2036fb8) from this pull request (#9).
Before reading this, I was scratching my head wondering whether the "tiled" projector was necessarily an improvement (in terms of visual quality, it appears to be), and why it was not used in Nvidia's default implementation of the projector. Providing a command-line argument to toggle this modification off would let people like me learn about its existence (without diving into the code), and about its limitations.

Moreover, if I understand correctly, the grids shown in this post were obtained by projecting "Mona Lisa" with W(18,*) in both cases. I find it interesting that the difference in number of iterations (1000 vs. 5000) leads to:

  • barely different projection results (the difference is noticeable if you look closely, but it is not that big),
  • very different interpolation results (with the visible artifacts mentioned by Rolux for images 9 and 11, but not just that: longer hair on images 1 and 3, a smaller head and a different pose on image 2, swapped gender on image 6, an artifact on image 7, no eyebrow and a very different result on image 8, etc.).

1000 iterations
5000 iterations

This contrasts with my experience with the default projector from Nvidia's original repository, which seemed to converge fast (although it cannot perfectly fit the real image because it uses W(1,*)).

default projector

Now, I am left wondering if the default projector suffers from the same dependency on the number of iterations: it visually looked like it converged, but would more iterations change its semantics without me realizing it? That would be disappointing.

commented

Hey there,

I am coming back to this thread because the trade-off between visual quality and semantic quality is discussed in this paper:

Tov, O., Alaluf, Y., Nitzan, Y., Patashnik, O., & Cohen-Or, D. (2021). Designing an Encoder for StyleGAN Image Manipulation. arXiv preprint arXiv:2102.02766.
https://arxiv.org/abs/2102.02766

It is done by the people behind https://github.com/orpatashnik/StyleCLIP

Semantic quality is called "editability".
Visual quality is divided into two elements:

  • distortion (distance between the input and the projection)
  • perception (realism of the projection).
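
Mapped back onto the measurements used earlier in this thread, rough proxies for two of these terms could look like the sketch below (illustrative only; the paper uses learned metrics, and perception/realism is omitted because it needs one).

import numpy as np

def distortion(image, projection):
    # pixel-space distance between the input and its projection
    return np.mean(np.square(image.astype('float') - projection.astype('float')))

def editability_proxy(dlatent, w_avg):
    # smaller RMS distance to w_avg correlated with better editability above
    return np.sqrt(np.mean(np.square(dlatent - w_avg)))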