This is the repository for the paper RECOMP: Improving Retrieval-Augmented LMs with Compression and Selective Augmentation.
Download the files from here and place them in the `data/` directory.
```
data/
- prompts/                        # prompts with uncompressed and compressed retrieved documents
- extractive_compressor_inputs/   # sentences passed to the extractive compressors
- abstractive_compressor_inputs/  # retrieved documents passed to the abstractive compressors
```
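A quick way to sanity-check the downloaded files (a minimal sketch; it assumes the released files are JSON, which holds for the compressor input files used below — the path here is one of those files, and the printed schema is whatever the file actually contains):

```python
import json

# Illustrative only: point this at any of the released JSON input files.
path = "data/extractive_compressor_inputs/flan_ul2_nq_5shot_top_5_passage_new_msmarco_sent.json"
with open(path) as f:
    data = json.load(f)

# Print the top-level structure so you can see the schema before running the compressors.
print(type(data))
first = data[0] if isinstance(data, list) else data
print(list(first.keys()) if isinstance(first, dict) else first)
```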
We release trained compressor checkpoints:
- Extractive compressor (`fangyuan/nq_extractive_compressor`)
- Abstractive compressor (`fangyuan/nq_abstractive_compressor`)
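Both checkpoints load with standard `transformers` classes. A minimal sketch, assuming (as described in the paper) the extractive compressor is a Contriever-style encoder and the abstractive compressor is a T5-style seq2seq model:

```python
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer

# Extractive compressor: a dual encoder used to score sentences against the query.
ext_tok = AutoTokenizer.from_pretrained("fangyuan/nq_extractive_compressor")
ext_enc = AutoModel.from_pretrained("fangyuan/nq_extractive_compressor")

# Abstractive compressor: a seq2seq model that summarizes retrieved documents.
abs_tok = AutoTokenizer.from_pretrained("fangyuan/nq_abstractive_compressor")
abs_model = AutoModelForSeq2SeqLM.from_pretrained("fangyuan/nq_abstractive_compressor")
```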
Follow the steps here to prepare retrieval documents. We use the command below (with this script).
```bash
python prepare_retrieval_data.py \
  --retrieval_type sparse \
  --tokenizer_name gpt2 \
  --max_length 1024 \
  --dataset_path wikitext \
  --dataset_name wikitext-103-v1 \
  --dataset_split validation \
  --index_name wikipedia-dpr \
  --forbidden_titles_path ralm/retrievers/wikitext103_forbidden_titles.txt \
  --stride 32 \
  --output_file gpt2_wikitext_validation_retrieval_files_32 \
  --num_tokens_for_query 32 \
  --num_docs 16
```
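Under the hood, this performs sparse (BM25) retrieval over the prebuilt `wikipedia-dpr` index. For illustration, here is a minimal pyserini sketch of that retrieval step (the script above additionally handles chunking, query construction from the last `--num_tokens_for_query` tokens, and forbidden-title filtering):

```python
from pyserini.search.lucene import LuceneSearcher

# Open the same prebuilt index the preparation script uses (downloads on first use).
searcher = LuceneSearcher.from_prebuilt_index("wikipedia-dpr")

hits = searcher.search("who wrote the declaration of independence", k=16)
for hit in hits[:3]:
    print(hit.docid, round(hit.score, 2))
    print(searcher.doc(hit.docid).raw()[:200])  # raw passage text
```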
Run the command below to evaluate perplexity with different sets of retrieved documents. For example, to evaluate uncompressed RALM with GPT-2 on WikiText-103, run:
```bash
python eval_lm.py \
  --model_name gpt2 \
  --dataset_path wikitext \
  --dataset_name wikitext-103-v1 \
  --dataset_split validation \
  --output_dir outputs/gpt-2 \
  --stride 32 \
  --max_length 1024 \
  --model_parallelism \
  --retrieved_file gpt2_wikitext_validation_retrieval_files_32
```
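For intuition, the core computation per stride window is: prepend the retrieved passage to the context and score only the target tokens. A minimal sketch of that computation (not the `eval_lm.py` implementation itself):

```python
import math
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

def chunk_nll(passage: str, chunk: str) -> float:
    """Mean negative log-likelihood of `chunk` given a prepended retrieved passage."""
    p_ids = tok(passage, return_tensors="pt").input_ids
    c_ids = tok(chunk, return_tensors="pt").input_ids
    ids = torch.cat([p_ids, c_ids], dim=1)
    labels = ids.clone()
    labels[:, : p_ids.size(1)] = -100  # mask passage tokens out of the loss
    with torch.no_grad():
        return model(ids, labels=labels).loss.item()

# Perplexity of the chunk under retrieval augmentation.
print(math.exp(chunk_nll("A retrieved Wikipedia passage.", " The text to score.")))
```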
Run the script below to prompt FLAN-UL2. We release compressed documents in `data/` for each dataset.
```bash
python prompt_flan.py \
  --input_data_csv_file [input_data_file] \
  --output_data_csv_file [name_of_output_data]
```
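The core of the script is a seq2seq generation call per row of the input CSV. A minimal sketch of that call (FLAN-UL2 is a ~20B-parameter model, so `device_map="auto"` with sufficient GPU memory is assumed; the prompt format here is illustrative — the actual prompts are released in `data/prompts/`):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/flan-ul2")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-ul2", device_map="auto", torch_dtype="auto"
)

# The prompt is the (optionally compressed) retrieved documents followed by the question.
prompt = (
    "Thomas Jefferson drafted the Declaration of Independence.\n\n"
    "Question: who wrote the declaration of independence?\nAnswer:"
)
inputs = tok(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=32)
print(tok.decode(out[0], skip_special_tokens=True))
```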
Run the script below to score sentences with the trained compressor by passing its path to `--model_path`. To run baseline compressors, pass `--model_type` instead.
For example, to score sentences with the extractive compressor for the top 5 retrieved documents from NQ, run:
```bash
python run_extractive_compressor.py \
  --input_data data/extractive_compressor_inputs/flan_ul2_nq_5shot_top_5_passage_new_msmarco_sent.json \
  --model_path fangyuan/nq_extractive_compressor \
  --output_file outputs/flan_ul2_nq_5shot_top_5_passage_msmarco_sent_with_scores.json \
  --top_k -1  # consider all sentences
```
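Conceptually, the extractive compressor embeds the question and each candidate sentence with a shared encoder and ranks sentences by inner product. A minimal sketch, assuming a Contriever-style encoder with mean pooling (as in the paper); the script above is the actual scoring pipeline:

```python
import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("fangyuan/nq_extractive_compressor")
enc = AutoModel.from_pretrained("fangyuan/nq_extractive_compressor").eval()

def embed(texts):
    batch = tok(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        hidden = enc(**batch).last_hidden_state
    mask = batch["attention_mask"].unsqueeze(-1).float()
    return (hidden * mask).sum(1) / mask.sum(1)  # mean-pool over non-pad tokens

query = embed(["who wrote the declaration of independence"])
sents = embed([
    "Thomas Jefferson drafted the Declaration of Independence.",
    "The Eiffel Tower is in Paris.",
])
scores = (query @ sents.T).squeeze(0)  # higher score = more useful sentence
print(scores)
```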
Run the script below to compress retrieved documents with the abstractive compressor. For example, to compress the top 5 retrieved documents from NQ, run:
```bash
python train_hf_summarization_model.py \
  --model_name_or_path fangyuan/nq_abstractive_compressor \
  --do_predict \
  --test_file data/abstractive_compressor_inputs/nq_dev_contriever_msmarco_top_5_docs.json \
  --max_target_length 512 \
  --output_dir outputs/ \
  --per_device_train_batch_size 4 \
  --per_device_eval_batch_size 16 \
  --predict_with_generate
```
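You can also call the abstractive compressor directly. A minimal sketch; the exact input format (how the question and documents are concatenated) is an assumption here — the released `abstractive_compressor_inputs/` files define the real format:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("fangyuan/nq_abstractive_compressor")
model = AutoModelForSeq2SeqLM.from_pretrained("fangyuan/nq_abstractive_compressor")

# Illustrative input: a question followed by retrieved passages to compress.
text = (
    "Question: who wrote the declaration of independence\n"
    "Document: Thomas Jefferson drafted the Declaration of Independence in 1776."
)
ids = tok(text, return_tensors="pt", truncation=True).input_ids
summary_ids = model.generate(ids, max_length=512)
print(tok.decode(summary_ids[0], skip_special_tokens=True))
```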