RetroMAE

Codebase for RetroMAE and beyond.

What's New

Released Models

We have uploaded some checkpoints to the Hugging Face Hub.

| Model | Description | Link |
|---|---|---|
| RetroMAE | Pre-trained on Wikipedia and BookCorpus | Shitao/RetroMAE |
| RetroMAE_MSMARCO | Pre-trained on the MSMARCO passage data | Shitao/RetroMAE_MSMARCO |
| RetroMAE_MSMARCO_finetune | Finetuned RetroMAE_MSMARCO on the MSMARCO passage data | Shitao/RetroMAE_MSMARCO_finetune |
| RetroMAE_MSMARCO_distill | Finetuned RetroMAE_MSMARCO on the MSMARCO passage data by minimizing the KL-divergence with a cross-encoder | Shitao/RetroMAE_MSMARCO_distill |
| RetroMAE_BEIR | Finetuned RetroMAE on the MSMARCO passage data for BEIR (using the official negatives provided by BEIR) | Shitao/RetroMAE_BEIR |

You can load them easily using the identifier strings. For example:

from transformers import AutoModel
model = AutoModel.from_pretrained('Shitao/RetroMAE')
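
As in most BERT-style dense retrievers, the [CLS] hidden state can be used as the text embedding. The snippet below is only a minimal sketch of how to obtain it with the checkpoint loaded above (the input sentence is illustrative, not from the repo):

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Shitao/RetroMAE')
model = AutoModel.from_pretrained('Shitao/RetroMAE')

# encode one sentence and take the [CLS] vector as its embedding
inputs = tokenizer('RetroMAE is a retrieval-oriented pre-training method.', return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs)
embedding = outputs.last_hidden_state[:, 0]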

State of the Art Performance

RetroMAE provides a strong initialization for dense retrievers: after fine-tuning on in-domain data, it delivers high-quality supervised retrieval performance in the corresponding scenario. Besides, it substantially improves the pre-trained model's transferability, leading to superior zero-shot performance on out-of-domain datasets.

MSMARCO Passage

  • Model pre-trained on Wikipedia and BookCorpus:

| Model | MRR@10 | Recall@1000 |
|---|---|---|
| BERT | 0.346 | 0.964 |
| RetroMAE | 0.382 | 0.981 |

  • Model pre-trained on MSMARCO:

| Model | MRR@10 | Recall@1000 |
|---|---|---|
| coCondenser | 0.382 | 0.984 |
| RetroMAE | 0.393 | 0.985 |
| RetroMAE (distillation) | 0.416 | 0.988 |

BEIR Benchmark

| Model | Avg NDCG@10 (18 datasets) |
|---|---|
| BERT | 0.371 |
| Condenser | 0.407 |
| RetroMAE | 0.452 |
| RetroMAE v2 | 0.491 |

Installation

git clone https://github.com/staoxiao/RetroMAE.git
cd RetroMAE
pip install .

For development, install as editable:

pip install -e .
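
For a quick sanity check that the install picked up the package modules used by the commands below (this assumes pretrain and bi_encoder are the top-level modules installed by the package, as the launch commands suggest):

python -c "import pretrain, bi_encoder"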

Workflow

This repo supports two stages: pre-training and finetuning. First, pre-train RetroMAE on a general corpus (or the downstream dataset) with the masked language modeling loss. Then, finetune RetroMAE on the downstream dataset with a contrastive loss. To achieve better performance, you can also finetune RetroMAE by distilling the scores provided by a cross-encoder. For the detailed workflow, please refer to our examples.

Pretrain

python -m torch.distributed.launch --nproc_per_node 8 \
  -m pretrain.run \
  --output_dir {path to save ckpt} \
  --data_dir {your data} \
  --do_train True \
  --model_name_or_path bert-base-uncased \
  --pretrain_method {retromae or dupmae}

Finetune

python -m torch.distributed.launch --nproc_per_node 8 \
  -m bi_encoder.run \
  --output_dir {path to save ckpt} \
  --model_name_or_path Shitao/RetroMAE \
  --do_train \
  --corpus_file ./data/BertTokenizer_data/corpus \
  --train_query_file ./data/BertTokenizer_data/train_query \
  --train_qrels ./data/BertTokenizer_data/train_qrels.txt \
  --neg_file ./data/train_negs.tsv
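
After finetuning, the bi-encoder can be used for retrieval by encoding queries and passages and ranking passages by inner product. The following is only a rough sketch with the transformers API, assuming [CLS] embeddings and using the released Shitao/RetroMAE_MSMARCO_finetune checkpoint for concreteness (the query and passages are made up for illustration); it is not the repo's own inference script:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Shitao/RetroMAE_MSMARCO_finetune')
model = AutoModel.from_pretrained('Shitao/RetroMAE_MSMARCO_finetune')
model.eval()

def encode(texts):
    # batch-tokenize and take the [CLS] hidden state as the dense embedding
    batch = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors='pt')
    with torch.no_grad():
        out = model(**batch)
    return out.last_hidden_state[:, 0]

query_emb = encode(['what is dense retrieval'])
passage_emb = encode(['Dense retrieval maps queries and passages to vectors.',
                      'RetroMAE is a masked auto-encoder for retrieval-oriented pre-training.'])
scores = query_emb @ passage_emb.T  # inner-product relevance scores, higher is better
print(scores)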

Examples

Citation

If you find our work helpful, please consider citing us:

@inproceedings{RetroMAE,
  title={RetroMAE: Pre-Training Retrieval-oriented Language Models Via Masked Auto-Encoder},
  author={Shitao Xiao and Zheng Liu and Yingxia Shao and Zhao Cao},
  url={https://arxiv.org/abs/2205.12035},
  booktitle={EMNLP},
  year={2022},
}

About

Codebase for RetroMAE and beyond.

License: Apache License 2.0

