AkariAsai / self-rag

This includes the original implementation of SELF-RAG: Learning to Retrieve, Generate and Critique through self-reflection by Akari Asai, Zeqiu Wu, Yizhong Wang, Avirup Sil, and Hannaneh Hajishirzi.

Home page: https://selfrag.github.io/


RAG Baselines

robertgshaw2-neuralmagic opened this issue · comments

Hey @AkariAsai, great work and thank you very much for putting together a nice hugging face model, dataset, and repo for reproducing and extending the results :)

In terms of the baselines for RAG, I see the paper describes:

Baselines with retrievals. We evaluate models augmented with retrieval at test time or during training.
The first category includes standard RAG baselines, where an LM (Llama2, Alpaca) generates output
given the query prepended with the top retrieved documents using the same retriever as in our system.

I was wondering if you have the scripts available for running the evaluations of Llama-2 and Llama-2-chat with the described setup? It seems to me that run_short_form.py only works for models trained with the special reflection tokens.

Thanks in advance!

I've uploaded the script for running the baseline LMs!
https://github.com/AkariAsai/self-rag/blob/main/retrieval_lm/run_baseline_lm.py
I'll add documentation on running the baselines, but essentially you just need to specify the model name and pass the same input file as in the Self-RAG pipeline. For the retrieval baseline, use the --mode retrieval --prompt_name "prompt_no_input_retrieval" options to trigger retrieval.

e.g., Llama2-7b (pre-trained)

python run_baseline_lm.py \
  --model_name meta-llama/Llama-2-7b-hf \
  --input_file INPUT_FILE_SAME_AS_SELF_RAG \
  --max_new_tokens 100 --metric match \
  --result_fp RESULT_FILE_PATH --task qa \
  --mode retrieval --prompt_name "prompt_no_input_retrieval"

e.g., ChatGPT (March version)

python run_baseline_lm.py \
  --model_name gpt-3.5-turbo-0301 \
  --input_file INPUT_FILE_SAME_AS_SELF_RAG \
  --max_new_tokens 100 --metric match \
  --result_fp RESULT_FILE_PATH \
  --task qa \
  --api_key YOUR_OPEN_AI_API_KEY_FILE \
  --mode retrieval --prompt_name "prompt_no_input_retrieval"

For OpenAI API models, you also need to set the organization key here: https://github.com/AkariAsai/self-rag/blob/main/retrieval_lm/run_baseline_lm.py#L12
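For reference, the --api_key flag takes a path to a key file. A minimal sketch of how the key file and organization might be wired up is below; the helper name, the key-file format, and the placeholder organization ID are assumptions for illustration, and the commented lines follow the legacy openai-python (v0) module-level interface, not necessarily what the script itself does:

```python
def load_api_key(path: str) -> str:
    """Read an OpenAI API key from a text file, stripping surrounding whitespace."""
    with open(path) as f:
        return f.read().strip()

# Illustrative wiring in the style of the legacy openai-python (v0) interface;
# the organization ID below is a hypothetical placeholder:
# import openai
# openai.api_key = load_api_key("openai_key.txt")
# openai.organization = "org-YOUR_ORG_ID"
```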

I'll close this issue for now, but let me know if you have any further questions!

Thank you :)