Official repository containing the pruner and the training scripts for R2-D2

Pruning the Index Contents for Memory Efficient Open-Domain QA

This repository contains the official implementation accompanying our preprint. The sources present in this repository can be used to train new models.

Please note that our paper is accompanied by two repositories. If you are interested in running model inference in the pipeline instead, check the R2-D2-pipeline repository.

If you use this code, please cite our preprint:




Set your system's locale.

export LANG=en_US.UTF-8 LC_ALL=en_US.UTF-8

Install this package using Python 3.6.

git clone
cd scalingQA; python -m pip install -r requirements.txt; python install



The following hyperlinks contain the preprocessed datasets required for training.

These files were created using DPR retrieval over all 21M passages of Wikipedia.
Additionally, we release the original files we used as inputs to DPR to generate the preprocessed datasets for the readers and the reranker.

If you would like to process your custom data, follow the "Retrieving the Data via DPR (Optional)" guide later in this README.


SQLite database of 21M passages is available here.
Embedding matrix for full 21M passages is available here.

Training R2-D2 models

Passage Reranker

Data Pre-processing

The datasets mentioned above comprise a set of the best-retrieved passages and one ground truth passage, if it exists. For several samples, no retrieved passage contains an answer, and the ground truth is unknown. Those samples should be removed from the reranker training data using the following command:

grep -v '"gt_index": -1, "hit_rank": -1,' [INPUT] > [FILTERED_OUTPUT]
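The same filtering can be sketched in Python by parsing each line as JSON instead of matching the exact substring. This is a minimal sketch, assuming the files are JSON-lines (one example per line) with the `gt_index` and `hit_rank` fields shown in the grep pattern; the function names are hypothetical and are not part of the repository.

```python
import json


def keep_for_reranker(line):
    """Return False for examples with no ground truth and no retrieved hit."""
    example = json.loads(line)
    # Mirrors the grep pattern: drop only when BOTH fields are -1.
    return not (example.get("gt_index") == -1 and example.get("hit_rank") == -1)


def filter_jsonl(lines):
    """Keep the non-empty lines that pass the reranker filter."""
    return [line for line in lines if line.strip() and keep_for_reranker(line)]


# Toy input with hypothetical extra fields.
sample = [
    '{"id": 0, "gt_index": 12, "hit_rank": 3, "predicted_indices": [5, 12]}',
    '{"id": 1, "gt_index": -1, "hit_rank": -1, "predicted_indices": [7, 8]}',
    '{"id": 2, "gt_index": -1, "hit_rank": 0, "predicted_indices": [9, 4]}',
]
kept = filter_jsonl(sample)
kept_ids = [json.loads(line)["id"] for line in kept]
```

Unlike the grep command, this variant does not depend on key order or whitespace in the serialized JSON.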

Training the Model

The scripts for passage reranker training can be found in the folder scalingqa/reranker/training. See help for more information about training configuration.

python -m --help

Our results should be easily replicable using several ready-made scripts, e.g. for the NQ dataset:

python -m

Note that a GPU with at least 12 GB of RAM (tested on GeForce RTX 2080 Ti) is required for training.

Reranker Outside the Pipeline

The passage reranker can be run separately on input in the same format as the training data. See help for more information:

python -m scalingqa.reranker.run_reranker --help

Extractive Reader

Data Pre-processing

The extractive reader always expects at least one answer span per training example. To ensure this, run:

python -m scalingqa.extractivereader.run_extractive_reader_filter
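To illustrate what "at least one answer span per training example" means, here is a minimal sketch of such a filter. It is not the repository's filtering script: the field names `passages` and `answers` and the case-insensitive substring match are assumptions made for illustration only.

```python
def has_answer_span(passages, answers):
    """True if any answer string occurs in any passage (case-insensitive)."""
    lowered = [passage.lower() for passage in passages]
    return any(answer.lower() in passage for answer in answers for passage in lowered)


def filter_examples(examples):
    """Keep only examples where at least one answer span can be located."""
    return [ex for ex in examples if has_answer_span(ex["passages"], ex["answers"])]


# Toy examples using the hypothetical field names.
examples = [
    {"question": "q1", "passages": ["Paris is the capital of France."], "answers": ["paris"]},
    {"question": "q2", "passages": ["Berlin is in Germany."], "answers": ["Vienna"]},
]
filtered = filter_examples(examples)
```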

The filtering script can be configured. An example of a configuration file for the filter is:


Training the Model

To train your own model, use:

python -m scalingqa.extractivereader.run_extractive_reader_train

An example of a configuration file for the training script is:


If you want to learn more about the usage of our scripts, read the descriptions in the configuration files. There are also ready-to-run toy examples in



To replicate training of our model for NaturalQuestions-Open run:


for TriviaQA-Open:


The scripts expect that all data files are already in the .data folder (see configurations in scalingqa/extractivereader/replicate). They also run the filtering.

Generative Reader

Training the Model

The run-files for replicating our results on NQ and Trivia are available in the folder scalingqa/generative_reader/training. To run the training, adjust the config dictionary right inside the file (you will probably want to set the paths to your data and to the output directories).

config = {
    "save_dir": ".saved",  # where the checkpoints will be saved
    "results": ".results",  # where validation results will be saved
    "validate_after_steps": 500,  # validation period, divided by 2 after 2/3 of training 

    # Data
    "data_cache_dir": ".data/reader/NQ/ranked/", # where the preprocessed datafiles will be cached
    "train_data": ".data/reader/NQ/ranked/NQ-open_TRAINING_maxlen_5_ms_with_dpr_annotation.jsonl_dpr_official_nqsingle_of_impossible.jsonl",
    "val_data": ".data/reader/NQ/ranked/NQ-open_DEV_maxlen_5_ms_with_dpr_annotation.json_dpr_official_nqsingle_of_impossible.jsonl",
    "test_data": ".data/reader/NQ/ranked/NQ-open_TEST.jsonl_nq-open_dpr_official_nqsingle_of_impossible.jsonl",
    "pass_database": ".index/wiki2018_dpr_blocks.db",  # database of passages and titles

    # number of passages encoded from mini-batch
    #   for training dataset there is always the ground truth passage and the rest is filled with the others recommended by retriever
    #   for validation dataset only the passages from retriever are used
    "context_length": 25,  # number of passages at the input of FiD
    # ...

Afterwards, simply run the module, e.g. to replicate the results of FiD-large on NQ:

python -m

Note that training is expected to run with an on-hardware batch size of 1. FiD-large on NQ takes about 9 days to converge on a single RTX 8000 48GB GPU.
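The comments on "context_length" and "validate_after_steps" in the config above can be read as the following logic. This is a sketch of one possible interpretation, not the repository's implementation: the function names, the placement of the ground truth passage at the first position, and the exact point at which the validation period is halved are assumptions.

```python
def build_context(retrieved, context_length, gt_index=None):
    """Select passage indices for the reader input.

    Training (gt_index given): ground truth passage plus top retrieved passages.
    Validation (gt_index is None): top retrieved passages only.
    """
    if gt_index is None:
        return list(retrieved[:context_length])
    # Assumption: the ground truth goes first and is not duplicated.
    others = [idx for idx in retrieved if idx != gt_index]
    return [gt_index] + others[: context_length - 1]


def validation_period(step, total_steps, base_period=500):
    """Halve the validation period once more than 2/3 of training has passed."""
    if 3 * step > 2 * total_steps:
        return base_period // 2
    return base_period


train_context = build_context([5, 3, 9, 7], context_length=3, gt_index=9)
val_context = build_context([5, 3, 9, 7], context_length=2)
```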

Common Use-Cases

To evaluate a checkpoint on the test data, add its path to the config dictionary under the "pre_initialize" key and set "test_only" to True:

config = {
    "pre_initialize": PATH_TO_CHECKPOINT,
    "test_only": True,
    # ...

To resume training from a checkpoint, use "resume_training" and "resume_checkpoint" analogously to the previous example:

config = {
    "resume_checkpoint": PATH_TO_CHECKPOINT,
    "resume_training": True,
    # ...

You can also train the system in mixed precision (see the flag "fp16"). Note that while the system seems to converge after the initial updates, we have never fully trained it in this mode, and thus cannot guarantee that it works as intended.

To check that everything works, you can try the toy-example run-file, which runs the FiD-base training using just 2 retrieved passages (runs on a 12 GB GPU).

Exporting the Checkpoint for R2-D2 Pipeline

To use the trained checkpoint in the R2-D2 pipeline, the checkpoint needs to be restructured so that it contains just a state dictionary and a model configuration. This can be done via script scalingqa/generative_reader/training/

python -m INPUT_FILE OUTPUT_FILE [fp16]

You can use the option fp16 to save the checkpoint in 16-bit precision.

Retrieving the Data via DPR (Optional)

Here we describe how to process, via the retriever, a custom dataset that follows the same format as NQ-open or TriviaQA-open.
First, you will need to adjust the configuration in scalingqa/retriever/ script by changing the contents of the config dictionary at the start of the file. Here is an example of how this configuration might look:

config = {
    # Omit an option if you do not have the corresponding file in your split
    # (e.g. if your split has no test set, comment out the "test_data_file" option here)
    # Path to your training data
    "training_data_file": ".data/nqopen/nq-open_train_short_maxlen_5_ms_with_dpr_annotation.jsonl",
    # Path to your validation data
    "validation_data_file": ".data/nqopen/nq-open_dev_short_maxlen_5_ms_with_dpr_annotation.jsonl",
    # Path to your test data
    "test_data_file": ".data/nqopen/NQ-open-test.jsonl",

    # Output directory, where to save files with retrieval information
    "output_directory": "retrieved_data",

    # Path to your passage embeddings
    "embeddings": ".embeddings/DPR_nqsingle_official.h5",
    # Path to database containing passages
    "db_path": ".wikipedia/wiki2018_dpr_blocks.db",
    # Path to retriever model
    "model_path": ".checkpoints/",
    # How many top-K passage indices to save into the output file
    "topK_extract": 400,
    # ...
  • Note that the Trivia files also contain a "human_answer" entry for each example, which is used to supervise the FiD reader.
  • This code does exact retrieval (dot-product with the embedding matrix). Therefore, if you use the full matrix of 21M passages in this step, you will need to fit it into your RAM (~65GB).
  • You can find download URLs for the compressed index/database/retriever in every R2-D2-pipeline configuration (for example, check configurations/pipeline/NQ/r2d2_full.json to get the files needed to run this code snippet).
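To make the memory note above concrete, the sketch below estimates the size of the embedding matrix and shows what exact dot-product retrieval means on toy data. It is illustrative only: the embedding dimension of 768 and 4-byte float storage are assumptions, and the function names are hypothetical rather than taken from the repository.

```python
import heapq


def embedding_matrix_gigabytes(num_passages, dim, bytes_per_value):
    """Size of a dense embedding matrix in gigabytes (10^9 bytes)."""
    return num_passages * dim * bytes_per_value / 1e9


def top_k_exact(query, passage_embeddings, k):
    """Indices of the k passages with the highest dot product with the query."""
    scores = [
        sum(q * p for q, p in zip(query, embedding))
        for embedding in passage_embeddings
    ]
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)


# Assumed: 21M passages, 768-dimensional vectors, 4 bytes per value.
full_index_gb = embedding_matrix_gigabytes(21_000_000, 768, 4)

# Toy retrieval over three 2-dimensional passage embeddings.
best = top_k_exact([1.0, 0.0], [[0.2, 0.9], [0.8, 0.1], [0.5, 0.5]], k=2)
```

Under these assumptions the matrix occupies roughly 64.5 GB, which is consistent with the ~65GB figure above.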

Afterwards simply run the module to extract the DPR's predictions.

python -m scalingqa.retriever.extract_DPR_predictions

Pruning the Index Contents

1. Constructing Golden Dataset (dataset with relevant and irrelevant passages)

For building NQ-Golden set, run script scalingqa/index_pruning/dataset/NQ/

python -m scalingqa.index_pruning.dataset.NQ.build_dataset

The script uses 4 parameters. They are not passed on the command line; please edit them directly in the script's main() function:

raw_passages_db = ".index/wiki2018_dpr_blocks.db"
output_folder = ".data/nq_corpus_pruning"
training_data_source = ".data/nqopen/nq-open_train_short_maxlen_5_ms_with_dpr_annotation.jsonl"
validation_data_source = ".data/nqopen/nq-open_dev_short_maxlen_5_ms_with_dpr_annotation.jsonl"

Note that the data here are the same as the inputs to DPR used to generate the data for reranker and reader training. The validation and test sets for this task are built from nq-open's validation set. You should end up with 176,628 examples for training, 4,332 examples for validation, and 8,698 examples for testing on NQ.

Similarly, you can use scalingqa/index_pruning/dataset/Trivia/ to build Trivia-Golden dataset.

2. Training the Irrelevant Passage Classifier (Pruner)


python -m[NQ|TRIVIA]

Adjust the parameters in the config if needed; in particular, you might be interested in setting paths to your data. For example, the defaults for NQ dataset are:

    "data_cache_dir": '.data/nq_corpus_pruning',
    "training_data": "train.jsonl",
    "validation_data": "val.jsonl",
    "test_data": "test.jsonl",

The training takes about 1.5h on 2080Ti 12 GB GPU for both datasets. In the paper we use the following checkpoints.

3. Inferring Irrelevant Passage's Probabilities

Now that the model is trained, the next step is to extract the irrelevance probability for each passage into an h5 matrix via:

python -m scalingqa.index_pruning.inference.run_irrelevant_doc_predictor

The parameters can again be adjusted inside the run-file's config:

    "passage_source": ".data/index/psgs_w100.tsv", # all passages from DPR
    "prob_file": ".pruning/psgs_w100_pruneprobs.h5", # output file
    "cls_checkpoint": ".saved/" # checkpoint from training

This is usually the longest step. For 21M passages, it takes about 24h to extract the probabilities. To get the Wikipedia passages, you can use this link available in the official DPR implementation.
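Since this step streams over all passages, it helps to read the passage file in batches rather than loading it at once. The following is a minimal sketch of such a batched reader, assuming a tab-separated file with a header row and the columns id, text, title; the function name is hypothetical and the classifier call itself is left out.

```python
import csv
import io


def iter_passage_batches(tsv_file, batch_size):
    """Yield lists of (id, text, title) tuples from a TSV file with a header row."""
    reader = csv.reader(tsv_file, delimiter="\t")
    next(reader, None)  # skip the header row (assumed: id, text, title)
    batch = []
    for row in reader:
        batch.append((row[0], row[1], row[2]))
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # final, possibly smaller batch


# Toy file standing in for the passage source.
toy_file = io.StringIO("id\ttext\ttitle\n1\tfoo\tA\n2\tbar\tB\n3\tbaz\tC\n")
batches = list(iter_passage_batches(toy_file, batch_size=2))
```

For reference, 21M passages in about 24h corresponds to roughly 243 passages per second (21,000,000 / 86,400).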

You can get the extracted probabilities we used in the paper from the following links:

4. Choosing the Relevant Documents

Now, prune the index (manually) via the jupyter-notebook file scalingqa/index_pruning/inference/get_pruning_index.ipynb. There, you can select the number of passages to keep or manually adjust the threshold for the pruner. Running the notebook will create a file containing the set of all passage indices to keep in the index.
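The two selection modes described above (a fixed number of passages, or a probability threshold) can be sketched as follows. This is not the notebook's code: the function names are hypothetical, and the sketch assumes the stored value is the probability that a passage is irrelevant, so passages with the lowest values are kept.

```python
def select_by_count(irrelevance_probs, num_keep):
    """Indices of the num_keep passages with the lowest irrelevance probability."""
    order = sorted(range(len(irrelevance_probs)), key=irrelevance_probs.__getitem__)
    return sorted(order[:num_keep])


def select_by_threshold(irrelevance_probs, max_irrelevance):
    """Indices of passages whose irrelevance probability does not exceed the threshold."""
    return [i for i, prob in enumerate(irrelevance_probs) if prob <= max_irrelevance]


# Toy probabilities for four passages.
probs = [0.9, 0.1, 0.5, 0.3]
keep_two = select_by_count(probs, num_keep=2)
keep_below_half = select_by_threshold(probs, max_irrelevance=0.5)
```

If the probability file stores relevance rather than irrelevance, the comparison and sort order would need to be flipped.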

5. Dumping the Pruned Index

Finally, the embedding index and the database can be pruned. You can use index_pruning/inference/ to prune the embedding matrix. Adjust the paths to the full embeddings (FULL_EMBEDDINGS) and to the file from the previous step (PRUNE_FILE) directly in the file.

python -m scalingqa.index_pruning.inference.prune_embeddings

Analogously, use index_pruning/inference/ to prune the SQLite database. There, adjust the path to the database (FULL_DB_PATH) and PRUNE_FILE.

python -m scalingqa.index_pruning.inference.prune_db
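Conceptually, both pruning scripts keep only the rows selected in the previous step. The sketch below shows the idea on toy data with an in-memory SQLite database. It is an illustration under assumptions, not the repository's code: the table name `passages`, its columns, and the one-to-one correspondence between embedding row positions and database ids are all hypothetical.

```python
import sqlite3


def prune_rows(rows, keep_indices):
    """Return only the rows whose position is in keep_indices, in ascending order."""
    return [rows[i] for i in sorted(keep_indices)]


def prune_db(src, dst, keep_ids):
    """Copy rows whose id is in keep_ids from src to dst (hypothetical schema)."""
    dst.execute("CREATE TABLE passages (id INTEGER PRIMARY KEY, title TEXT, text TEXT)")
    rows = src.execute("SELECT id, title, text FROM passages ORDER BY id")
    dst.executemany(
        "INSERT INTO passages VALUES (?, ?, ?)",
        (row for row in rows if row[0] in keep_ids),
    )
    dst.commit()


# Toy data: three passages, keep the first and the last.
keep = {0, 2}
embeddings = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
pruned_embeddings = prune_rows(embeddings, keep)

full_db = sqlite3.connect(":memory:")
full_db.execute("CREATE TABLE passages (id INTEGER PRIMARY KEY, title TEXT, text TEXT)")
full_db.executemany(
    "INSERT INTO passages VALUES (?, ?, ?)",
    [(0, "A", "a"), (1, "B", "b"), (2, "C", "c")],
)
pruned_db = sqlite3.connect(":memory:")
prune_db(full_db, pruned_db, keep)
pruned_ids = [row[0] for row in pruned_db.execute("SELECT id FROM passages ORDER BY id")]
```

Whatever the actual schema is, the embedding matrix and the database must be pruned with the same PRUNE_FILE so that they stay aligned.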

See any of the configurations/pipeline/[NQ|Trivia]/*_pruned.json files in R2-D2-pipeline for links to pruned versions of NQ/Trivia index we used in the paper.


