NVlabs / imaginaire

NVIDIA's Deep Imagination Team's PyTorch Library

How to stabilize the MUNIT training process?

NYCXI opened this issue

I am training MUNIT on my own dataset (which is similar to synthetic2cityscape), but the training process is very unstable. How can I make the training procedure more stable?
Some of the charts look like this:
[three chart images]
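
With logging_iter set to 10, raw GAN loss curves tend to look noisy even when training is healthy, so it can help to smooth the logged values before judging stability. The snippet below is only a rough sketch of exponential-moving-average smoothing with bias correction. The function name `ema_smooth`, the `alpha` value, and the sample numbers are made up for illustration and are not part of imaginaire.

```python
def ema_smooth(values, alpha=0.98):
    """Return a bias-corrected exponential moving average of `values`.

    `alpha` close to 1 means heavier smoothing. This is a hypothetical
    helper for inspecting logged losses, not an imaginaire function.
    """
    smoothed = []
    avg = 0.0
    for step, value in enumerate(values, start=1):
        avg = alpha * avg + (1.0 - alpha) * value
        # Divide by (1 - alpha^step) so early points are not pulled toward 0.
        smoothed.append(avg / (1.0 - alpha ** step))
    return smoothed


if __name__ == "__main__":
    # Made-up loss values, only to show the call.
    raw = [1.0, 3.0, 0.5, 2.5, 1.0, 3.5, 0.8, 2.2]
    for r, s in zip(raw, ema_smooth(raw, alpha=0.9)):
        print(f"raw={r:.2f} smoothed={s:.3f}")
```

If the smoothed curve is roughly flat or slowly trending, the spikes may just be per-batch noise. A smoothed curve that keeps drifting or oscillating widely would be a stronger sign of real instability.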
The YAML file is below; it is modified from configs/projects/munit/summer2winter_hd/ampO1.yaml:

pretrained_weight: 17gYCHgWD9xM_EFqid1S3b3MXBjIvElAI
inference_args:
    # Translates images from domain A to B or from B to A.
    a2b: True
    # Samples the style code from the prior distribution or uses the style code
    # encoded from the input images in the other domain.
    random_style: True

# How often do you want to log the training stats.
logging_iter: 10
# Maximum number of training iterations.
max_iter: 100000
# Whether to benchmark speed or not.
speed_benchmark: True

image_display_iter: 500
image_save_iter: 5000
snapshot_save_iter: 5000
trainer:
    type: imaginaire.trainers.munit
    model_average_config:
        enabled: True
    amp_config:
        enabled: True
    gan_mode: hinge
    perceptual_mode: vgg19
    perceptual_layers: 'relu_4_1'
    loss_weight:
        gan: 1
        image_recon: 10
        content_recon: 1
        style_recon: 1
        perceptual: 0
        cycle_recon: 10
        gp: 0
        consistency_reg: 0
    init:
        type: orthogonal
        gain: 1
gen_opt:
    type: adam
    lr: 0.0001
    adam_beta1: 0.5
    adam_beta2: 0.999
    lr_policy:
        type: constant
dis_opt:
    type: adam
    lr: 0.0004
    adam_beta1: 0.5
    adam_beta2: 0.999
    lr_policy:
        type: constant
gen:
    type: imaginaire.generators.munit
    latent_dim: 8
    num_filters: 64
    num_filters_mlp: 256
    num_res_blocks: 4
    num_mlp_blocks: 2
    num_downsamples_style: 4
    num_downsamples_content: 3
    content_norm_type: instance
    style_norm_type: none
    decoder_norm_type: instance
    weight_norm_type: spectral
    pre_act: True
dis:
    type: imaginaire.discriminators.munit
    patch_wise: True
    num_filters: 48
    max_num_filters: 1024
    num_layers: 5
    activation_norm_type: none
    weight_norm_type: spectral

# Data options.
data:
    # Name of this dataset.
    name: fusion2cityscape
    # Which dataloader to use?
    type: imaginaire.datasets.unpaired_images
    # How many data loading workers per GPU?
    num_workers: 8
    input_types:
        - images_a:
            # If not specified, is None by default.
            ext: png
            # If not specified, is None by default.
            num_channels: 3
            # If not specified, is None by default.
            normalize: True
        - images_b:
            # If not specified, is None by default.
            ext: png
            # If not specified, is None by default.
            num_channels: 3
            # If not specified, is None by default.
            normalize: True

    # Train dataset details.
    train:
        # Input LMDBs.
        roots:
            - /data/hdd2/zhangrui/dataset/fusion2cityscape_raw/train
        # Batch size per GPU.
        batch_size: 8
        # Data augmentations to be performed in given order.
        augmentations:
            # First resize all inputs to this size.
            resize_h_w: 480, 640 
            # Horizontal flip?
            horizontal_flip: True
            # Crop size.
            random_crop_h_w: 480, 640

    # Val dataset details.
    val:
        # Input LMDBs.
        roots:
            - /data/hdd2/zhangrui/dataset/fusion2cityscape_raw/test
        # Batch size per GPU.
        batch_size: 1
        # If resize_h_w is not given, then it is assumed to be same as crop_h_w.
        augmentations:
            center_crop_h_w: 480, 640

test_data:
    # Name of this dataset.
    name: fusion2cityscape
    # Which dataloader to use?
    type: imaginaire.datasets.unpaired_images
    input_types:
        - images_a:
              ext: png
              num_channels: 3
              normalize: True
        - images_b:
              ext: png
              num_channels: 3
              normalize: True

    # Which labels to be concatenated as final output label from dataloader.
    paired: False
    # Test dataset details.
    test:
        is_lmdb: False
        roots:
            - /data/hdd2/zhangrui/dataset/fusion2cityscape_raw/test
        # Batch size per GPU.
        batch_size: 1
        # If resize_h_w is not given, then it is assumed to be same as crop_h_w.
        augmentations:
            resize_smallest_side: 1024
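
For readers unfamiliar with the `gan_mode: hinge` setting in the config above, here is a minimal sketch of the standard hinge GAN objective, written in plain Python over lists of discriminator scores. It is the textbook formulation, not a copy of imaginaire's implementation, and the function names `hinge_dis_loss` and `hinge_gen_loss` are hypothetical.

```python
def hinge_dis_loss(real_scores, fake_scores):
    """Standard hinge loss for the discriminator.

    Penalizes real scores below +1 and fake scores above -1.
    """
    real_term = sum(max(0.0, 1.0 - s) for s in real_scores) / len(real_scores)
    fake_term = sum(max(0.0, 1.0 + s) for s in fake_scores) / len(fake_scores)
    return real_term + fake_term


def hinge_gen_loss(fake_scores):
    """Standard hinge loss for the generator: raise the scores of fakes."""
    return -sum(fake_scores) / len(fake_scores)


if __name__ == "__main__":
    # Made-up discriminator outputs, only to show the calls.
    real = [2.0, 0.5]
    fake = [-2.0, 0.0]
    print("discriminator loss:", hinge_dis_loss(real, fake))
    print("generator loss:", hinge_gen_loss(fake))
```

Because the discriminator term saturates at zero once real scores exceed +1 and fake scores fall below -1, a discriminator loss that sits near zero for long stretches may indicate that the discriminator is overpowering the generator. With `dis_opt.lr` at four times `gen_opt.lr` in this config, that balance is one thing worth checking.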

Sorry about my poor English.

Can I ask how you set up and installed imaginaire to train MUNIT?