chaofengc / ITER

PyTorch codes for "Iterative Token Evaluation and Refinement for Real-World Super-Resolution", AAAI 2024

Pipeline of ITER. The input $I_l$ first passes through a distortion removal network $E_l$ to obtain the initially restored tokens $S_l$, which are indices of quantized features in the VQGAN codebook. A reverse discrete diffusion process, conditioned on $S_l$, then generates textures. The process starts from completely masked tokens $S_T$. The refinement network (also called the de-masking network) $\phi_r$ generates refined outputs $S_{T-1}$ with $S_l$ as a condition. The evaluation network $\phi_e$ then evaluates $S_{T-1}$ to obtain the evaluation mask $m_{T-1}$, which determines, through a masked sampling process, which tokens to keep and which to refine at step $T-1$. This process is repeated $T$ times to obtain the de-masked outputs $S_0$, from which the VQGAN decoder $D_H$ reconstructs the restored image $I_{sr}$. We found that $T\leq8$ is enough to get good results with ITER, which makes it much more efficient than other diffusion-based approaches.
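The evaluate-and-refine loop described above can be sketched in a few lines of Python. This is a toy illustration under stated assumptions, not the repository's implementation: `refine` and `evaluate` are hypothetical stand-ins for $\phi_r$ and $\phi_e$, `MASK` and `CODEBOOK_SIZE` are made-up constants, and the keep/refine rule (keep a token only when it agrees with the condition $S_l$) was chosen only so the example runs without trained networks.

```python
import random

MASK = -1           # hypothetical id marking a masked token
CODEBOOK_SIZE = 16  # hypothetical codebook size for this toy example


def refine(tokens, cond, rng):
    """Toy stand-in for phi_r: fill every masked position, keep the rest."""
    out = []
    for tok, c in zip(tokens, cond):
        if tok != MASK:
            out.append(tok)                           # already accepted token
        elif rng.random() < 0.7:
            out.append(c)                             # follow the condition S_l
        else:
            out.append(rng.randrange(CODEBOOK_SIZE))  # random codebook index
    return out


def evaluate(tokens, cond):
    """Toy stand-in for phi_e: True means keep, False means refine again."""
    return [tok == c for tok, c in zip(tokens, cond)]


def iter_decode(cond, num_steps=8, seed=0):
    """Run num_steps rounds of refinement starting from fully masked tokens."""
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    rng = random.Random(seed)
    tokens = [MASK] * len(cond)                 # S_T: everything masked
    for step in range(num_steps, 0, -1):
        proposal = refine(tokens, cond, rng)    # candidate for S_{step-1}
        if step == 1:
            return proposal                     # S_0: accept all tokens
        keep = evaluate(proposal, cond)         # evaluation mask m_{step-1}
        # Re-mask the rejected tokens so the next round refines them.
        tokens = [p if k else MASK for p, k in zip(proposal, keep)]
    return tokens


if __name__ == "__main__":
    s_l = [3, 1, 4, 1, 5, 9, 2, 6]  # made-up initially restored tokens
    print(iter_decode(s_l, num_steps=8))
```

In the real model the tokens form a 2D grid and both networks are learned; the sketch only shows the control flow of masking, refining, evaluating, and re-masking.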

TODO List

  • Release ITER_x4 model.
  • Release ITER_x2 model.
  • Release training and testing codes.
  • Release training datasets.

🔧 Dependencies and Installation

# git clone this repository
git clone https://github.com/chaofengc/ITER.git
cd ITER 

# create new anaconda env
conda create -n iter python=3.8
conda activate iter

# install python dependencies
pip3 install -r requirements.txt
python setup.py develop

⚡Quick Inference

python inference_iter.py -s 2 -i ./testset/lrx4/frog.png
python inference_iter.py -s 4 -i ./testset/lrx4/frog.png
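To process several images, the command above can be repeated once per file. The helper below is a small sketch that only builds the argument lists. `build_inference_commands` is a hypothetical name, not part of the repository, and restricting `scale` to 2 and 4 is an assumption based on the two examples shown here.

```python
def build_inference_commands(image_paths, scale):
    """Return one `inference_iter.py` argument list per input image.

    Only the `-s` and `-i` options from the examples above are used.
    """
    if scale not in (2, 4):
        # Assumption: only the x2 and x4 models shown in this README exist.
        raise ValueError("scale must be 2 or 4")
    return [
        ["python", "inference_iter.py", "-s", str(scale), "-i", str(path)]
        for path in image_paths
    ]


if __name__ == "__main__":
    for cmd in build_inference_commands(["./testset/lrx4/frog.png"], 4):
        print(" ".join(cmd))
```

Each list can be passed to `subprocess.run(cmd, check=True)` from the repository root to actually run the model.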

👨‍💻Train the model

⏬ Download Datasets

The training datasets can be downloaded from 🤗 Hugging Face. You may also refer to FeMaSR to prepare your own training data.

‍🔁 Training

Below are brief examples for training the model. Please modify the corresponding configuration files to suit your needs.

Stage I: Train the Swin-VQGAN

accelerate launch --multi_gpu --num_processes=8 --mixed_precision=bf16 basicsr/train.py -opt options/train_ITER_HQ_stage.yml

Stage II & III: Train the LQ encoder and the refinement network

accelerate launch --main_process_port=29600 --multi_gpu --num_processes=8 --mixed_precision=bf16 basicsr/train.py -opt options/train_ITER_LQ_stage_X2.yml

accelerate launch --main_process_port=29600 --multi_gpu --num_processes=8 --mixed_precision=bf16 basicsr/train.py -opt options/train_ITER_LQ_stage_X4.yml
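The launch commands above differ only in the config file and the optional port, so they can be generated from a few parameters when adapting them to another machine. The following is a sketch under assumptions: `build_train_command` is a hypothetical helper, and dropping `--multi_gpu` when a single process is requested is an assumption about a typical `accelerate` setup that this README does not confirm for this codebase.

```python
def build_train_command(config, num_gpus=8, port=None, mixed_precision="bf16"):
    """Assemble an `accelerate launch` command like the ones shown above."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    cmd = ["accelerate", "launch"]
    if port is not None:
        cmd.append(f"--main_process_port={port}")
    if num_gpus > 1:
        cmd.append("--multi_gpu")  # assumption: omitted for a single process
    cmd += [
        f"--num_processes={num_gpus}",
        f"--mixed_precision={mixed_precision}",
        "basicsr/train.py",
        "-opt",
        config,
    ]
    return cmd


if __name__ == "__main__":
    # Prints the Stage II & III x4 command from above, scaled down to 4 GPUs.
    print(" ".join(build_train_command(
        "options/train_ITER_LQ_stage_X4.yml", num_gpus=4, port=29600)))
```

With the default arguments the helper reproduces the Stage I command exactly, which is a quick way to check that it has not drifted from the documented invocation.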

📝 Citation

If you find this code useful for your research, please cite our paper:

@inproceedings{chen2024iter,
  title={Iterative Token Evaluation and Refinement for Real-World Super-Resolution},
  author={Chaofeng Chen and Shangchen Zhou and Liang Liao and Haoning Wu and Wenxiu Sun and Qiong Yan and Weisi Lin},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  year={2024},
}

⚖️ License

This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License and NTU S-Lab License 1.0.

❤️ Acknowledgement

This project is based on BasicSR.
