LabShuHangGU / Adaptive-Token-Dictionary

CVPR2024 - Transcending the Limit of Local Window: Advanced Super-Resolution Transformer with Adaptive Token Dictionary

Adaptive Token Dictionary

This repository is an official implementation of the paper "Transcending the Limit of Local Window: Advanced Super-Resolution Transformer with Adaptive Token Dictionary", CVPR, 2024.

[Arxiv] [visual results] [pretrained models]

By Leheng Zhang, Yawei Li, Xingyu Zhou, Xiaorui Zhao, and Shuhang Gu.

Abstract: Single Image Super-Resolution is a classic computer vision problem that involves estimating high-resolution (HR) images from low-resolution (LR) ones. Although deep neural networks (DNNs), especially Transformers for super-resolution, have seen significant advancements in recent years, challenges remain, particularly the limited receptive field caused by window-based self-attention. To address this issue, we introduce a group of auxiliary Adaptive Token Dictionaries into the SR Transformer and establish the ATD-SR method. The token dictionary learns prior information from training data and adapts the learned prior to a specific test image through an adaptive refinement step. The refinement strategy not only provides global information to all input tokens but also groups image tokens into categories. Based on the category partitions, we further propose a category-based self-attention mechanism designed to leverage distant but similar tokens for enhancing input features. Experimental results show that our method achieves the best performance on various single image super-resolution benchmarks.
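
The category partition described in the abstract can be illustrated with a toy sketch. This is not the repository's implementation: it assumes similarity is a plain dot product and that each token is assigned to its single most similar dictionary entry. The adaptive refinement step and the attention computation are omitted, and the function names are hypothetical.

```python
def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def assign_categories(tokens, dictionary):
    """Map each dictionary index to the token indices most similar to it."""
    groups = {}
    for token_index, token in enumerate(tokens):
        # Similarity of this token to every dictionary entry (assumed: dot product).
        scores = [dot(token, entry) for entry in dictionary]
        # The category is the index of the highest-scoring entry.
        best = max(range(len(dictionary)), key=scores.__getitem__)
        groups.setdefault(best, []).append(token_index)
    return groups


# Four 2-D image tokens and a dictionary with two entries.
tokens = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.8]]
dictionary = [[1.0, 0.0], [0.0, 1.0]]
groups = assign_categories(tokens, dictionary)
# groups == {0: [0, 1], 1: [2, 3]}
```

In a category-based attention scheme, attention would then be restricted to tokens that share a group, so tokens in the same category can exchange information regardless of how far apart they are in the image.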



Contents

  1. Environment
  2. Fast Inference
  3. Training
  4. Testing
  5. Results
  6. Visual Results
  7. Citation
  8. Acknowledgements

Environment

  • Python 3.9
  • PyTorch 2.0.1

Installation

git clone https://github.com/LabShuHangGU/Adaptive-Token-Dictionary.git

conda create -n ATD python=3.9
conda activate ATD

pip install -r requirements.txt
python setup.py develop
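
After installation, it can help to confirm that the active interpreter matches the versions listed under Environment. The following is a minimal sketch and not part of the repository; the function name `check_environment` is hypothetical, and it only reports versions rather than enforcing them.

```python
import sys
from importlib import metadata


def check_environment(expected_python=(3, 9), expected_torch="2.0.1"):
    """Report whether the running Python and installed PyTorch match the README."""
    python_version = sys.version_info[:2]
    try:
        # "torch" is the distribution name under which PyTorch is published.
        torch_version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        torch_version = None  # PyTorch is not installed in this environment.
    return {
        "python": python_version,
        "python_ok": python_version == tuple(expected_python),
        "torch": torch_version,
        # Builds may carry suffixes such as "+cu118", so compare the prefix only.
        "torch_ok": torch_version is not None
        and torch_version.startswith(expected_torch),
    }


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

Other versions may well work; a mismatch here is only a hint about where to look if installation or training fails.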

Fast Inference

Use inference.py for fast inference on a single image or on multiple images in the same folder.

# For classical SR
python inference.py -i test_image.png -o results/test/ --scale 4 --task classical
python inference.py -i test_images/ -o results/test/ --scale 4 --task classical

# For lightweight SR
python inference.py -i test_image.png -o results/test/ --scale 4 --task lightweight
python inference.py -i test_images/ -o results/test/ --scale 4 --task lightweight

The ATD SR model processes the image test_image.png or images within the test_images/ directory. The results will be saved in the results/test/ directory.

Training

Data Preparation

  • Download the training dataset DF2K (DIV2K + Flickr2K) and put it in the folder ./datasets.
  • It is recommended to follow the data preparation steps from BasicSR for faster data reading.

Training Commands

  • Refer to the training configuration files in the ./options/train folder for detailed settings.
  • ATD (Classical Image Super-Resolution)
# batch size = 8 (GPUs) × 4 (per GPU)
# training dataset: DF2K

# ×2 scratch, input size = 64×64, 300k iterations
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --use-env --nproc_per_node=8 --master_port=1145  basicsr/train.py -opt options/train/000_ATD_SRx2_scratch.yml --launcher pytorch
# ×2 finetune, input size = 96×96, 250k iterations
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --use-env --nproc_per_node=8 --master_port=1145  basicsr/train.py -opt options/train/001_ATD_SRx2_finetune.yml --launcher pytorch

# ×3 finetune, input size = 96×96, 250k iterations
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --use-env --nproc_per_node=8 --master_port=1145  basicsr/train.py -opt options/train/002_ATD_SRx3_finetune.yml --launcher pytorch

# ×4 finetune, input size = 96×96, 250k iterations
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --use-env --nproc_per_node=8 --master_port=1145  basicsr/train.py -opt options/train/003_ATD_SRx4_finetune.yml --launcher pytorch
  • ATD-light (Lightweight Image Super-Resolution)
# batch size = 2 (GPUs) × 16 (per GPU)
# training dataset: DIV2K

# ×2 scratch, input size = 64×64, 500k iterations
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --use-env --nproc_per_node=2 --master_port=1145  basicsr/train.py -opt options/train/101_ATD_light_SRx2_scratch.yml --launcher pytorch

# ×3 finetune, input size = 64×64, 250k iterations
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --use-env --nproc_per_node=2 --master_port=1145  basicsr/train.py -opt options/train/102_ATD_light_SRx3_finetune.yml --launcher pytorch

# ×4 finetune, input size = 64×64, 250k iterations
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --use-env --nproc_per_node=2 --master_port=1145  basicsr/train.py -opt options/train/103_ATD_light_SRx4_finetune.yml --launcher pytorch
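
Both settings above use the same effective batch size (8 × 4 = 32 for ATD, 2 × 16 = 32 for ATD-light). When fewer GPUs are available, the launch line and the effective batch size change together. The sketch below is a hypothetical helper, not part of the repository, that rebuilds the launch line from a GPU list using only the flags shown above; it assumes the per-GPU batch size is set separately in the YAML configuration.

```python
import shlex


def total_batch_size(num_gpus, per_gpu):
    """Effective batch size under data-parallel training: GPUs x samples per GPU."""
    return num_gpus * per_gpu


def build_train_command(config_path, gpu_ids, master_port=1145):
    """Assemble the README's distributed launch line for a config and GPU set."""
    if not gpu_ids:
        raise ValueError("gpu_ids must contain at least one device index")
    visible = ",".join(str(i) for i in gpu_ids)
    argv = [
        "python", "-m", "torch.distributed.launch", "--use-env",
        f"--nproc_per_node={len(gpu_ids)}",  # one process per visible GPU
        f"--master_port={master_port}",
        "basicsr/train.py", "-opt", config_path,
        "--launcher", "pytorch",
    ]
    return f"CUDA_VISIBLE_DEVICES={visible} " + shlex.join(argv)


if __name__ == "__main__":
    print(total_batch_size(8, 4), total_batch_size(2, 16))
    print(build_train_command(
        "options/train/101_ATD_light_SRx2_scratch.yml", [0, 1]))
```

Running with a different number of GPUs than listed changes the effective batch size unless the per-GPU value is adjusted to compensate, which may affect how closely the reported training schedule is reproduced.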

Testing

Data Preparation

  • Download the testing data (Set5 + Set14 + BSD100 + Urban100 + Manga109 [download]) and put them in the folder ./datasets.

Pretrained Models

  • Download the pretrained models and put them in the folder ./experiments/pretrained_models.

Testing Commands

  • Refer to the testing configuration files in the ./options/test folder for detailed settings.
  • ATD (Classical Image Super-Resolution)
python basicsr/test.py -opt options/test/001_ATD_SRx2_finetune.yml
python basicsr/test.py -opt options/test/002_ATD_SRx3_finetune.yml
python basicsr/test.py -opt options/test/003_ATD_SRx4_finetune.yml
  • ATD-light (Lightweight Image Super-Resolution)
python basicsr/test.py -opt options/test/101_ATD_light_SRx2_scratch.yml
python basicsr/test.py -opt options/test/102_ATD_light_SRx3_finetune.yml
python basicsr/test.py -opt options/test/103_ATD_light_SRx4_finetune.yml
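
To evaluate every released configuration in one pass, the six commands above can be wrapped in a short loop. This is a sketch under the assumption that it is run from the repository root with the pretrained models already in place; the helper names are hypothetical, and the default is a dry run that only prints the commands.

```python
import subprocess
import sys

# Configuration files exactly as listed in the testing commands above.
TEST_CONFIGS = (
    "options/test/001_ATD_SRx2_finetune.yml",
    "options/test/002_ATD_SRx3_finetune.yml",
    "options/test/003_ATD_SRx4_finetune.yml",
    "options/test/101_ATD_light_SRx2_scratch.yml",
    "options/test/102_ATD_light_SRx3_finetune.yml",
    "options/test/103_ATD_light_SRx4_finetune.yml",
)


def build_test_commands(configs=TEST_CONFIGS, python_executable=sys.executable):
    """Return one argv list per configuration file."""
    return [[python_executable, "basicsr/test.py", "-opt", cfg] for cfg in configs]


def run_all(commands, dry_run=True):
    """Print each command, and execute it unless dry_run is set."""
    for argv in commands:
        print(" ".join(argv))
        if not dry_run:
            # check=True stops the loop at the first failing evaluation.
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    run_all(build_test_commands())
```

Calling `run_all(build_test_commands(), dry_run=False)` executes the evaluations sequentially.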

Results

  • Classical Image Super-Resolution

  • Lightweight Image Super-Resolution

Visual Results

  • Complete visual results can be downloaded from the [visual results] link listed at the top of this README.

Citation

@misc{zhang2024transcending,
      title={Transcending the Limit of Local Window: Advanced Super-Resolution Transformer with Adaptive Token Dictionary}, 
      author={Leheng Zhang and Yawei Li and Xingyu Zhou and Xiaorui Zhao and Shuhang Gu},
      year={2024},
      eprint={2401.08209},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgements

This code is built on BasicSR.

License: Apache License 2.0

