kkgoer / query-selected-attention

Official implementation for "QS-Attn: Query-Selected Attention for Contrastive Learning in I2I Translation" (CVPR 2022)

QS-Attn: Query-Selected Attention for Contrastive Learning in I2I Translation (CVPR 2022)

https://arxiv.org/abs/2203.08483

Unpaired image-to-image (I2I) translation often requires maximizing the mutual information between the source and the translated images across different domains, which is critical for the generator to keep the source content and avoid unnecessary modifications. Self-supervised contrastive learning has already been applied successfully to I2I. By constraining features from the same location to be closer than those from different locations, it implicitly ensures that the result takes its content from the source. However, previous work imposes the constraint using features from random locations, which may not be appropriate since some locations contain less information about the source domain. Moreover, a feature by itself does not reflect its relations with other features. This paper addresses these problems by intentionally selecting significant anchor points for contrastive learning. We design a query-selected attention (QS-Attn) module, which compares feature distances in the source domain and yields an attention matrix with a probability distribution in each row. We then select queries according to their significance, measured from that distribution. The selected queries serve as anchors for the contrastive loss. At the same time, the reduced attention matrix is used to route features in both domains, so that source relations are maintained in the synthesis. We validate the proposed method on three different I2I datasets, showing that it improves image quality without adding learnable parameters.
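The selection and routing steps described above can be illustrated with a small NumPy sketch. This is not the code in this repository; it is a minimal illustration under stated assumptions. The function names (`select_queries`, `info_nce`), the use of row entropy as the significance measure, the dot-product similarity, and the temperature of 0.07 are all assumptions made for the example. The abstract only says that significance is computed from each row's distribution.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def select_queries(src_feat, tgt_feat, num_queries):
    """Pick significant query rows and route both domains with them.

    src_feat, tgt_feat: (HW, C) features from the source and translated
    images at the same spatial locations (hypothetical layout).
    """
    # Attention computed in the source domain only; each row sums to 1.
    attn = softmax(src_feat @ src_feat.T, axis=1)            # (HW, HW)
    # ASSUMPTION: significance = low row entropy (a peaked distribution).
    entropy = -(attn * np.log(attn + 1e-12)).sum(axis=1)     # (HW,)
    idx = np.argsort(entropy)[:num_queries]                  # most peaked rows
    attn_qs = attn[idx]                                      # reduced matrix (N, HW)
    # The same reduced matrix routes features in both domains.
    return attn_qs @ src_feat, attn_qs @ tgt_feat, idx


def info_nce(anchors, positives, temperature=0.07):
    """Contrastive loss: row i of `anchors` should match row i of `positives`."""
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    logits = (a @ p.T) / temperature                         # diagonal = positive pairs
    m = logits.max(axis=1, keepdims=True)
    log_norm = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return float(np.mean(log_norm - np.diag(logits)))


# Toy usage with random features: 64 locations, 32 channels, 16 anchors.
rng = np.random.default_rng(0)
src = rng.normal(size=(64, 32))
tgt = rng.normal(size=(64, 32))
src_routed, tgt_routed, picked = select_queries(src, tgt, num_queries=16)
loss = info_nce(tgt_routed, src_routed)
```

Because the attention matrix comes from plain matrix products and a softmax, the selection step introduces no trainable weights, which is consistent with the claim that the method adds no learnable parameters.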



QS-Attn applies attention to select anchors for contrastive learning in a single-direction I2I task.

Getting Started

Prerequisites

  • Ubuntu 16.04
  • NVIDIA GPU + CUDA CuDNN
  • Python 3

Please use pip install -r requirements.txt to install the dependencies.

Pretrained Models

Coming soon!

Training

  • Download the horse2zebra dataset:
bash ./datasets/download_qsattn_dataset.sh horse2zebra
  • Train the global model:
python train.py \
--dataroot=datasets/horse2zebra \
--name=horse2zebra_qsattn_global \
--QS_mode=global
  • You can use visdom to view the training loss: run python -m visdom.server and open http://localhost:8097 in a browser.

Inference

  • Test the global model:
python test.py \
--dataroot=datasets/horse2zebra \
--name=horse2zebra_qsattn_global \
--QS_mode=global

Citation

If you use this code for your research, please cite:

@article{hu2022qs,
  title={QS-Attn: Query-Selected Attention for Contrastive Learning in I2I Translation},
  author={Hu, Xueqi and Zhou, Xinyue and Huang, Qiusheng and Shi, Zhengyi and Sun, Li and Li, Qingli},
  journal={arXiv preprint arXiv:2203.08483},
  year={2022}
}
