Unofficial implementation of Image Super-Resolution via Iterative Refinement in PyTorch

Image Super-Resolution via Iterative Refinement

Paper | Project

Brief

This is an unofficial PyTorch implementation of Image Super-Resolution via Iterative Refinement (SR3).

Some implementation details follow the paper's description but may differ from the actual SR3 structure, because some details are missing from the paper.

  • We use ResNet blocks and channel-concatenation conditioning, as in vanilla DDPM.
  • We apply the attention mechanism to low-resolution features (16×16), as in vanilla DDPM.
  • We encode $\gamma$ with a FiLM structure, as WaveGrad does, and inject the embedding without an affine transformation.
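The three design points above can be sketched in NumPy as follows. This is an illustrative sketch under stated assumptions, not the repository's model code: `noise_level_encoding` and `add_noise_embedding` are hypothetical names, the random `weight` matrix stands in for a learned linear layer, placing the condition before the noisy image in the concatenation is an assumption, and "without an affine transformation" is read here as adding a per-channel shift instead of applying both a scale and a shift.

```python
import math

import numpy as np


def noise_level_encoding(noise_level, dim):
    """Sinusoidal encoding of a continuous noise level (WaveGrad-style).

    noise_level: array of shape (B,); dim: even embedding size.
    Returns an array of shape (B, dim) laid out as [sin(...), cos(...)].
    """
    half = dim // 2
    # Geometric ladder of frequencies, as in transformer positional encodings.
    freqs = np.exp(-math.log(1e4) * np.arange(half) / half)
    angles = noise_level[:, None] * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


def add_noise_embedding(features, embedding, weight):
    """Inject the embedding into (B, C, H, W) features as a per-channel shift.

    `weight` (dim, C) stands in for a learned linear layer (assumption).
    There is no scale term, i.e. no (1 + scale) * x + shift modulation.
    """
    shift = embedding @ weight  # (B, C)
    return features + shift[:, :, None, None]


rng = np.random.default_rng(0)
batch = 2

# Conditioning by channel concatenation: the bicubic-upsampled LR image
# is stacked with the noisy HR estimate along the channel axis.
sr_bicubic = rng.standard_normal((batch, 3, 128, 128))
x_noisy = rng.standard_normal((batch, 3, 128, 128))
unet_input = np.concatenate([sr_bicubic, x_noisy], axis=1)  # (2, 6, 128, 128)

# Embed one continuous noise level per sample and add it to a 16x16 feature map.
gamma = np.array([0.3, 0.9])
emb = noise_level_encoding(gamma, dim=64)  # (2, 64)
feats = rng.standard_normal((batch, 32, 16, 16))
out = add_noise_embedding(feats, emb, rng.standard_normal((64, 32)))
```

Because the shift is constant over height and width, the same noise-level information reaches every spatial position of a feature map.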

If you just want to upscale 64×64 px images to 512×512 px with the pretrained model, check out this Google Colab script.

Status

Conditional generation (super resolution)

  • 16×16 -> 128×128 on FFHQ-CelebaHQ
  • 64×64 -> 512×512 on FFHQ-CelebaHQ

Unconditional generation

  • 128×128 face generation on FFHQ
  • 1024×1024 face generation by a cascade of 3 models

Training Step

  • log / logger
  • metrics evaluation
  • multi-gpu support
  • resume training / pretrained model
  • standalone validation script
  • Weights and Biases Logging 🌟 NEW

Results

Note: We currently set the maximum budget of reverse steps to 2000. Because the number of model parameters is limited by the Nvidia 1080Ti, image noise and hue deviation occasionally appear in high-resolution images, resulting in low scores. There is a lot of room for optimization, and we welcome contributions of more extensive experiments and code enhancements.

Tasks/Metrics        SSIM(+)   PSNR(+)   FID(-)   IS(+)
16×16 -> 128×128     0.675     23.26     -        -
64×64 -> 512×512     0.445     19.87     -        -
128×128              -         -
1024×1024            -         -

Usage

Environment

# Recreate the PyTorch environment from the dependency files, using either of the following commands.
conda env create -f core/environment.yml
conda create --name pytorch --file core/environment.txt

Pretrained Model

This paper is based on "Denoising Diffusion Probabilistic Models", and we implement both the DDPM and SR3 network structures, which take the timestep and the noise level $\gamma$, respectively, as the embedding input. In our experiments, the SR3 model achieves better visual results with the same number of reverse steps and the same learning rate. Select the JSON file with the corresponding suffix to train each model.
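The difference between the two embedding inputs can be sketched as follows. This assumes a linear β schedule over 2000 steps (matching the reverse-step budget in the Results note); the start/end values and all function names are illustrative assumptions, not the repository's code. DDPM conditions on the integer timestep t, while SR3-style training draws a continuous noise level uniformly between two adjacent levels of the schedule. Treating the square root of the cumulative product of alphas as that level is also an assumption.

```python
import numpy as np


def sqrt_alpha_levels(n_timestep=2000, linear_start=1e-6, linear_end=1e-2):
    """sqrt(cumprod(1 - beta)) for a linear beta schedule, with a leading 1.0.

    Index 0 means "no noise"; index t is the level after t forward steps.
    """
    betas = np.linspace(linear_start, linear_end, n_timestep)
    alphas_cumprod = np.cumprod(1.0 - betas)
    return np.sqrt(np.append(1.0, alphas_cumprod))


def sample_timesteps(n_timestep, batch, rng):
    """DDPM-style conditioning: one integer timestep in [1, n_timestep] per sample."""
    return rng.integers(1, n_timestep + 1, size=batch)


def sample_noise_level(levels, t, rng):
    """SR3/WaveGrad-style conditioning: a continuous level drawn uniformly
    between the two adjacent schedule levels around each timestep t."""
    # levels is decreasing, so levels[t] <= levels[t - 1]
    return rng.uniform(levels[t], levels[t - 1])


def q_sample(y0, sqrt_gamma, noise):
    """Noisy target: sqrt_gamma * y0 + sqrt(1 - sqrt_gamma**2) * noise."""
    g = sqrt_gamma.reshape(-1, *([1] * (y0.ndim - 1)))
    return g * y0 + np.sqrt(1.0 - g ** 2) * noise


rng = np.random.default_rng(0)
levels = sqrt_alpha_levels()
y0 = rng.standard_normal((4, 3, 128, 128))        # stand-in HR batch
noise = rng.standard_normal(y0.shape)
t = sample_timesteps(2000, batch=4, rng=rng)      # DDPM would embed t
sqrt_gamma = sample_noise_level(levels, t, rng)   # SR3 would embed this value
y_noisy = q_sample(y0, sqrt_gamma, noise)
# A denoiser f(sr_bicubic, y_noisy, sqrt_gamma) would then be trained to
# predict `noise`, e.g. with an L1 loss: np.abs(f(...) - noise).mean().
```

Conditioning on a continuous level rather than an integer index decouples the trained network from one fixed step count, which is the motivation WaveGrad gives for this choice.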

Tasks                                Platform (Code: qwer)
16×16 -> 128×128 on FFHQ-CelebaHQ    Google Drive | Baidu Yun
64×64 -> 512×512 on FFHQ-CelebaHQ    Google Drive | Baidu Yun
128×128 face generation on FFHQ      Google Drive | Baidu Yun
# Download the pretrained model and set "resume_state" in [sr|sample]_[ddpm|sr3]_[resolution option].json:
"resume_state": [your pretrain model path]

Data Preparation

New Start

If you don't have the data, you can prepare it with the following steps:

Download the dataset and convert it to LMDB or PNG format using the script:

# Resize to get 16×16 LR_IMGS and 128×128 HR_IMGS, then prepare 128×128 fake SR_IMGS by bicubic interpolation
python data/prepare_data.py --path [dataset root] --out [output root] --size 16,128 -l
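The resize step described in the comment above can be approximated with Pillow as below. This is a sketch under assumptions, not the code of data/prepare_data.py: it resizes the shorter side and center-crops to a square, uses bicubic resampling throughout, and the helper names are hypothetical.

```python
from PIL import Image


def resize_and_center_crop(img, size, resample=Image.BICUBIC):
    """Resize so the shorter side equals `size`, then center-crop a square."""
    w, h = img.size
    scale = size / min(w, h)
    new_w = max(size, round(w * scale))
    new_h = max(size, round(h * scale))
    img = img.resize((new_w, new_h), resample)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))


def make_triplet(img, lr_size=16, hr_size=128):
    """Return (LR, HR, fake SR); SR is the LR image upsampled bicubically to HR size."""
    img = img.convert("RGB")
    hr = resize_and_center_crop(img, hr_size)
    lr = resize_and_center_crop(img, lr_size)
    sr = lr.resize((hr_size, hr_size), Image.BICUBIC)
    return lr, hr, sr
```

For example, `lr, hr, sr = make_triplet(Image.open(path))` yields 16×16, 128×128 and 128×128 images, i.e. one sample for each of the lr_16, hr_128 and sr_16_128 folders shown in the Own Data layout.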

Then change the datasets config to match your data path and image resolution:

"datasets": {
    "train": {
        "dataroot": "dataset/ffhq_16_128", // [output root] in the prepare_data.py script
        "l_resolution": 16, // low resolution to be super-resolved
        "r_resolution": 128, // high resolution
        "datatype": "lmdb", // lmdb or img (path of image files)
    },
    "val": {
        "dataroot": "dataset/celebahq_16_128", // [output root] in prepare.py script
    }
},

Own Data

You can also use your own image data with the following steps. Some examples are provided in the dataset folder.

First, organize the images in the following layout:

# high-resolution, low-resolution, and bicubic-interpolated image folders
dataset/celebahq_16_128/
├── hr_128
├── lr_16
└── sr_16_128

Then change the dataset config to match your data path and image resolution:

"datasets": {
    "train|val": { // train and validation part
        "dataroot": "dataset/celebahq_16_128",
        "l_resolution": 16, // low resolution to be super-resolved
        "r_resolution": 128, // high resolution
        "datatype": "img", // lmdb or img (path of image files)
    }
},
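With datatype "img", the folder names follow from the config values as in the layout above: hr_{r_resolution}, lr_{l_resolution} and sr_{l_resolution}_{r_resolution}. A hypothetical stdlib helper can list missing folders before training. It only mirrors the naming shown in this README and is not how the repository's dataset loader resolves paths; the `need_lr` flag reflects the note under Inference Alone that the LR directory is not needed there.

```python
from pathlib import Path


def expected_dirs(dataroot, l_resolution, r_resolution):
    """Folder paths implied by the layout shown above (hypothetical helper)."""
    root = Path(dataroot)
    return {
        "HR": root / f"hr_{r_resolution}",
        "LR": root / f"lr_{l_resolution}",
        "SR": root / f"sr_{l_resolution}_{r_resolution}",
    }


def check_layout(dataset_opt, need_lr=True):
    """Return the expected folders that do not exist for an "img" dataset config."""
    if dataset_opt.get("datatype") != "img":
        raise ValueError("this check only applies to datatype 'img'")
    dirs = expected_dirs(
        dataset_opt["dataroot"],
        dataset_opt["l_resolution"],
        dataset_opt["r_resolution"],
    )
    if not need_lr:
        dirs.pop("LR")
    return [str(p) for p in dirs.values() if not p.is_dir()]
```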

Training/Resume Training

# Use sr.py and sample.py to train the super resolution task and unconditional generation task, respectively.
# Edit json files to adjust network structure and hyperparameters
python sr.py -p train -c config/sr_sr3.json

Test/Evaluation

# Edit the JSON file to add the pretrained model path, then run the evaluation
python sr.py -p val -c config/sr_sr3.json

# Standalone quantitative evaluation with SSIM/PSNR metrics on a given result root
python eval.py -p [result root]
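For reference, the PSNR reported in the Results table is conventionally defined as 20·log10(MAX / √MSE). A minimal NumPy sketch, assuming 8-bit images with MAX = 255; it is not eval.py's implementation, which may differ in details such as color space or border handling.

```python
import numpy as np


def psnr(img1, img2, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images."""
    a = np.asarray(img1, dtype=np.float64)
    b = np.asarray(img2, dtype=np.float64)
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(20 * np.log10(max_val / np.sqrt(mse)))
```

Converting to float64 before subtracting matters: subtracting two uint8 arrays directly would wrap around instead of producing negative differences.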

Inference Alone

Set the HR (original high-resolution images) and SR (images to be processed) image paths as in the Own Data step. The HR directory contents can be copied from SR, and the LR directory is not needed.

# run the script
python infer.py -c [config file]
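Since the HR directory can simply mirror SR for inference, the copy can be scripted. A small stdlib sketch, assuming the hr_{r}/sr_{l}_{r} folder naming from the Own Data layout; `fill_hr_from_sr` is a hypothetical helper, not part of the repository.

```python
import shutil
from pathlib import Path


def fill_hr_from_sr(dataroot, l_resolution, r_resolution):
    """Copy sr_{l}_{r} into hr_{r} so inference has a matching HR folder."""
    root = Path(dataroot)
    sr_dir = root / f"sr_{l_resolution}_{r_resolution}"
    hr_dir = root / f"hr_{r_resolution}"
    if not sr_dir.is_dir():
        raise FileNotFoundError(sr_dir)
    # dirs_exist_ok (Python 3.8+) lets this refresh an existing HR folder.
    shutil.copytree(sr_dir, hr_dir, dirs_exist_ok=True)
    return hr_dir
```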

Weights and Biases 🎉

The library now supports experiment tracking, model checkpointing and model prediction visualization with Weights and Biases. You will need to install W&B and log in with your access token.

pip install wandb

# get your access token from wandb.ai/authorize
wandb login

W&B logging has been added to sr.py, sample.py and infer.py. Pass -enable_wandb to start logging.

  • -log_wandb_ckpt: Pass this argument along with -enable_wandb to save model checkpoints as W&B Artifacts. Both sr.py and sample.py support model checkpointing.
  • -log_eval: Pass this argument along with -enable_wandb to save the evaluation results as interactive W&B Tables. Only sr.py supports this feature; if you run sample.py in eval mode, the generated images are automatically logged as an image media panel.
  • -log_infer: When running infer.py, pass this argument along with -enable_wandb to log the inference results as interactive W&B Tables.

You can find more on using these features here. 🚀

Acknowledgements

Our work is based on the following theoretical works:

and we benefit a lot from the following projects:

About

License: Apache License 2.0

