qumengxue / siri-vg

SiRi: A Simple Selective Retraining Mechanism for Transformer-based Visual Grounding

This repository is an official PyTorch implementation of the ECCV 2022 paper SiRi: A Simple Selective Retraining Mechanism for Transformer-based Visual Grounding.

Introduction

We investigate a new training mechanism to improve the Transformer encoder, named Selective Retraining (SiRi), which continually updates the parameters of the encoder while periodically re-initializing the remaining parameters as training goes on. In this way, the model can be better optimized on top of an enhanced encoder. The figure below shows the training process of SiRi. For more details, please refer to our paper.

SiRi
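
The schedule described in the introduction can be sketched in a few lines of plain Python. This is a toy illustration under stated assumptions, not the repository's implementation: parameters are stand-in floats, the actual training step is omitted, and the `encoder.` name prefix as well as the helpers `fresh_init` and `selective_reinit` are hypothetical names chosen for this example.

```python
import random


def fresh_init(names, seed):
    """Toy initializer: draw a new value in [-0.1, 0.1] for every parameter name."""
    rng = random.Random(seed)
    return {name: rng.uniform(-0.1, 0.1) for name in names}


def selective_reinit(params, encoder_prefix, seed):
    """Keep parameters whose name starts with encoder_prefix; re-initialize the rest."""
    fresh = fresh_init(params, seed)
    return {
        name: value if name.startswith(encoder_prefix) else fresh[name]
        for name, value in params.items()
    }


# Pretend these values came out of one training period (hypothetical names).
trained = {"encoder.layer0.weight": 5.0, "decoder.layer0.weight": 7.0, "head.bias": 9.0}

# Start of the next period: the encoder carries over, everything else restarts.
next_period = selective_reinit(trained, encoder_prefix="encoder.", seed=0)
```

In a real PyTorch model the same selection would presumably operate on named parameters or state-dict entries rather than floats; which modules count as "the encoder" is defined by the paper and the code, not by this sketch.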

Updates

  • 2022-7-25: Released the code and model links for SiRi with an MDETR-like model on the REC task.
  • We will release code and models for TransVG and for other VL tasks such as RES.

Installation

Environment:

  • We provide instructions for installing the dependencies via conda. First, clone the repository locally:
    git clone https://github.com/qumengxue/siri-vg.git
    
  • Make a new conda env and activate it:
    conda create -n siri python=3.8
    conda activate siri
    
  • Install the packages listed in requirements.txt:
    pip install -r requirements.txt
    

Dataset preparation

  • Prepare the COCO training set ("train2014").
  • Download the pre-processed annotations, converted to COCO format, that are provided by MDETR.
  • Modify the config file under configs/ to match your dataset paths, in particular coco_path and refexp_ann_path.

For more installation details, please see the MDETR repository, on which our code is built.
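
The config edit from the dataset preparation steps can also be scripted. The following is a minimal sketch, assuming the files under configs/ are JSON and that coco_path and refexp_ann_path are top-level keys; the helper name `set_dataset_paths` and the example paths are hypothetical.

```python
import json
from pathlib import Path


def set_dataset_paths(config_file, coco_path, refexp_ann_path):
    """Rewrite the two dataset paths in a JSON config and return the updated dict."""
    config_file = Path(config_file)
    config = json.loads(config_file.read_text())
    config["coco_path"] = str(coco_path)
    config["refexp_ann_path"] = str(refexp_ann_path)
    # All other keys are written back unchanged.
    config_file.write_text(json.dumps(config, indent=4) + "\n")
    return config


# Example call (hypothetical local paths), run from the repository root:
# set_dataset_paths("configs/refcocog.json", "/data/coco", "/data/mdetr_annotations")
```

Editing the file by hand works equally well; the script is only convenient when switching between several dataset locations.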

Training

  • For example, to train with 2 decoders and 8 retraining periods on RefCOCOg, run
    sh refcocog.sh
    
  • To run only the initial training on its own, run
    python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py \
      --dataset_config configs/refcocog.json \
      --batch_size 18 \
      --output-dir exps/refcocog_retrain_2decoder_1/ \
      --ema \
      --lr 5e-5 \
      --lr_backbone 5e-5 \
      --text_encoder_lr 1e-5 \
      --num_queries 16 \
      --no_contrastive_align_loss \
      --cascade_num 2
    

Evaluation

  • Training via the *.sh scripts automatically evaluates the model after each round of SiRi, so you can check the results directly.
  • To evaluate a single model, run
    python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py \
      --dataset_config configs/refcocog.json \
      --batch_size 18 \
      --output-dir exps/ \
      --ema \
      --lr 5e-5 \
      --lr_backbone 5e-5 \
      --text_encoder_lr 1e-5 \
      --num_queries 16 \
      --no_contrastive_align_loss \
      --cascade_num 2 \
      --resume exps/refcocog_1d.pth \
      --eval
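
The training and evaluation commands above differ only in the output directory and the trailing --resume and --eval flags. As a convenience, they could be assembled from one helper; this is a sketch that uses only the flags and values shown in this README, and the function name `build_command` and its parameters are hypothetical.

```python
import shlex


def build_command(output_dir, resume=None, evaluate=False, nproc_per_node=4):
    """Assemble the launch command as an argv list, using only flags shown in this README."""
    cmd = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}", "--use_env", "main.py",
        "--dataset_config", "configs/refcocog.json",
        "--batch_size", "18",
        "--output-dir", output_dir,
        "--ema",
        "--lr", "5e-5",
        "--lr_backbone", "5e-5",
        "--text_encoder_lr", "1e-5",
        "--num_queries", "16",
        "--no_contrastive_align_loss",
        "--cascade_num", "2",
    ]
    if resume is not None:
        cmd += ["--resume", resume]  # checkpoint to load
    if evaluate:
        cmd.append("--eval")  # evaluation only, no training
    return cmd


train_cmd = build_command("exps/refcocog_retrain_2decoder_1/")
eval_cmd = build_command("exps/", resume="exps/refcocog_1d.pth", evaluate=True)
print(shlex.join(eval_cmd))  # shell-quoted string for copy and paste
```

Either list can be passed to `subprocess.run(cmd, check=True)` from the repository root to launch the job.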
    

Model Zoo

TASK1: Referring Expression Comprehension

  • RefCOCO

    Model              val     testA   testB   Checkpoint
    MDETR* + SiRi      85.83   88.56   81.27   gdrive
    MDETR* + MT SiRi   85.82   89.11   81.08   gdrive

  • RefCOCO+

    Model              val             testA           testB           Checkpoint
    MDETR* + SiRi      76.68 (76.63)   82.01 (81.99)   66.33 (66.86)   gdrive
    MDETR* + MT SiRi   77.47 (77.53)   83.04 (82.47)   67.11 (67.89)   gdrive

  • RefCOCOg

    Model              val     test    Checkpoint
    MDETR* + SiRi      76.63   76.46   gdrive
    MDETR* + MT SiRi   77.39   76.80   gdrive

TASK2: Referring Expression Segmentation

Coming soon!

Citing SiRi

@inproceedings{qu2022siri,
  title={SiRi: A Simple Selective Retraining Mechanism for Transformer-based Visual Grounding},
  author={Qu, Mengxue and Wu, Yu and Liu, Wu and Gong, Qiqi and Liang, Xiaodan and Russakovsky, Olga and Zhao, Yao and Wei, Yunchao},
  booktitle={ECCV},
  year={2022}
}

Acknowledgement

Our code is built on the previous work MDETR.

About

License: Apache License 2.0

