AkariAsai / self-rag

This repository contains the original implementation of SELF-RAG: Learning to Retrieve, Generate and Critique through Self-Reflection by Akari Asai, Zeqiu Wu, Yizhong Wang, Avirup Sil, and Hannaneh Hajishirzi.

Home Page: https://selfrag.github.io/

Why is eval logic so complicated?

emrgnt-cmplxty opened this issue

Hi,

Two major questions crossed my mind while looking over this repository:

1.) Would you like access to an optimized, compressed FAISS index? I ran a simple optimization this morning and shrank my index from ~60 GB to ~5.5 GB without any obvious degradation. This might help others in replicating the retrieval pipeline.
2.) Will I be able to reproduce your results in production if I elect to retrieve only on '[Retrieval]' tokens, rather than performing this calculation:

        if threshold is not None:
            score_dict = {}
            for tok, id in ret_tokens.items():
                if id not in pred_log_probs[0]:
                    # token missing from the top-k log probs at position 0: assign a floor score
                    score_dict[tok] = -100
                    continue  # without this, the lookup below raises a KeyError
                prob = pred_log_probs[0][id]
                score_dict[tok] = float(prob)
            # soft constraint: retrieve when [Retrieval]'s share of the two scores exceeds the threshold
            do_retrieve = score_dict["[Retrieval]"] / (
                score_dict["[Retrieval]"] + score_dict["[No Retrieval]"]) > threshold

E.g., how different is this in practice from retrieving only when the retrieval token is the sampled result?
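
For concreteness, a minimal sketch of the simpler rule I have in mind is below, reusing the ret_tokens / pred_log_probs structures from the snippet above (the helper name and the greedy choice between the two control tokens are my own, not code from this repo):

    def should_retrieve_hard(pred_log_probs, ret_tokens):
        """Hard-constraint variant: retrieve iff "[Retrieval]" outscores "[No Retrieval]"
        at the first generated position, with no threshold on a normalized score."""
        top_k = pred_log_probs[0]  # dict: token id -> log prob at the first position
        ret_lp = top_k.get(ret_tokens["[Retrieval]"], float("-inf"))
        no_ret_lp = top_k.get(ret_tokens["[No Retrieval]"], float("-inf"))
        return ret_lp > no_ret_lp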

Re 1:
That sounds awesome! I tried an HNSW index to reduce the inference time (at additional storage cost), and it gave us some performance deterioration, possibly due to bad hyperparameters. Reducing the index size would be another helpful improvement. Do you mind sharing more details?

Re 2:
Section 3.3 and Appendix Section A.3 explain why we did that, but it's essentially because we want to adjust the retrieval frequency for different downstream tasks. As shown in Figure 4, retrieving more helps a lot on PopQA, which is dominated by rare entities, while it doesn't matter much for PubHealth (a claim verification task). The "Hard constraints" baselines in our Table 3 show the performance when we only retrieve when the retrieval token is generated.
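
To make the frequency-control point concrete, here is a small illustrative sketch (not code from this repo; the score dicts just mirror the snippet quoted above) of how the threshold changes how often retrieval triggers over a dataset:

    def retrieval_frequency(score_dicts, threshold):
        """Fraction of examples that would trigger retrieval at a given threshold.
        Each score dict holds the scores for "[Retrieval]" and "[No Retrieval]"."""
        triggered = 0
        for s in score_dicts:
            share = s["[Retrieval]"] / (s["[Retrieval]"] + s["[No Retrieval]"])
            if share > threshold:
                triggered += 1
        return triggered / len(score_dicts)

Lowering the threshold can only increase this fraction, so a task that benefits from retrieving more (e.g. PopQA) can use a lower threshold, while a task that doesn't (e.g. PubHealth) can use a higher one.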

Re 1:

Awesome.

Another thing I did was move away from loading all passages into memory (this, plus the original FAISS index, blew up my instance at ~80 GB of memory). Further, I saw that inference times were on the order of 20 s with the implementation in this repo; optimizing the FAISS index reduced its size by ~90% and brought inference down to ~10 ms. I haven't extensively checked the evaluation performance, but from some spot checks it all looked reasonable.
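
Roughly, the compression looks like the sketch below; this is only an illustrative recipe for product-quantizing a flat FAISS index (the factory string, nlist/nprobe values, and file paths are placeholders, not my exact settings):

    import faiss

    # Load the original uncompressed flat index and pull its vectors back out.
    flat_index = faiss.read_index("wiki_embeddings_flat.index")  # placeholder path
    vectors = flat_index.reconstruct_n(0, flat_index.ntotal)
    # (for a very large corpus, train on a subsample and add in chunks instead)

    # Build an IVF + product-quantization index: 4096 coarse cells, 64-byte PQ codes,
    # inner-product metric, assuming the original index was built for dot-product similarity.
    compressed = faiss.index_factory(flat_index.d, "IVF4096,PQ64", faiss.METRIC_INNER_PRODUCT)
    compressed.train(vectors)   # learn the coarse quantizer and PQ codebooks
    compressed.add(vectors)     # store every vector as a compact code
    faiss.extract_index_ivf(compressed).nprobe = 32  # cells probed per query (speed/recall trade-off)

    faiss.write_index(compressed, "wiki_embeddings_ivfpq.index")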

The script for generating the compressed index is here. The db implementation is here. I am planning on hosting the index + a self-rag model today that others can access once I have smoothed out a few rough edges around the infra.
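
The rough shape of the db idea: stream the passages TSV into SQLite once, then fetch only the rows whose ids the FAISS search returns. The sketch below is illustrative only; the table name, file paths, and the assumed id/text/title column order are placeholders, not the actual implementation:

    import sqlite3

    def build_passage_db(tsv_path="psgs_w100.tsv", db_path="passages.sqlite"):
        """One-time conversion of the passages TSV into an on-disk SQLite table."""
        conn = sqlite3.connect(db_path)
        conn.execute("CREATE TABLE IF NOT EXISTS passages (id INTEGER PRIMARY KEY, text TEXT, title TEXT)")
        with open(tsv_path, encoding="utf-8") as f:
            next(f)  # skip the header row
            rows = (line.rstrip("\n").split("\t") for line in f)
            conn.executemany(
                "INSERT OR REPLACE INTO passages (id, text, title) VALUES (?, ?, ?)",
                ((int(r[0]), r[1], r[2]) for r in rows),
            )
        conn.commit()
        return conn

    def fetch_passages(conn, ids):
        """Look up only the passages the index returned, instead of keeping them all in memory."""
        qmarks = ",".join("?" * len(ids))
        cur = conn.execute(f"SELECT id, text, title FROM passages WHERE id IN ({qmarks})", list(ids))
        return {row[0]: {"text": row[1], "title": row[2]} for row in cur}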

I am very inspired by this work, so I am also building a simplified approach to using models that follow this format. I have built a simple framework that lets one attach a local LLM to a remote vLLM provider + rag db server (the boilerplate implementation is in sciphi-infra), and I have worked out a simple way to do inference.

My goal is to reduce the interface for creating a self-rag LLM to that shown below:

    llm = SciPhiLLM(
        SciPhiConfig(
            server_base="http://localhost:8000/v1",            # remote vLLM provider endpoint
            rag_provider_base="http://localhost:8001/search",  # rag db / search server
            rag_provider_token="",                             # auth token, if the provider requires one
        )
    )

with the option to just use local vLLM if one doesn't want to run a server on their instance. This is something I have working in my local setup and just need to clean up / push.
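
For the local-vLLM path, the usage I have in mind looks roughly like the sketch below (the model id, prompt format, and sampling settings are placeholders for whatever Self-RAG-style checkpoint one serves; the LLM/SamplingParams calls are just vLLM's standard offline API):

    from vllm import LLM, SamplingParams

    # Local-only variant: no remote server, just vLLM's offline engine.
    llm = LLM(model="selfrag/selfrag_llama2_7b")  # placeholder checkpoint id
    params = SamplingParams(
        temperature=0.0,
        max_tokens=128,
        logprobs=5,                 # expose per-token log probs for the retrieval decision
        skip_special_tokens=False,  # keep reflection tokens like [Retrieval] visible in the output
    )

    prompt = "### Instruction:\nWho wrote The Great Gatsby?\n\n### Response:\n"
    output = llm.generate([prompt], params)[0].outputs[0]
    if "[Retrieval]" in output.text:
        pass  # query the rag db server (or a local index) here, then re-generate with the evidence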

Lastly,

I am re-running your fine-tune on Mistral to see how it impacts your findings. I have generally found Mistral to be much better than Llama-7B, so I'm excited to see how that shakes out.

Re 2:

Ah, thank you for the thoughtful explanation. I hope it is not a terrible approximation for now to simply stop on '' tokens and do the retrieval when the retrieval token precedes this. I will look into implementing the more complicated logic at a later date.

Great discussion about optimization. I'm hoping there are only limited performance regressions with the optimized code and when trying it out on Mistral!

FYI, I added an issue about the license on your inference code @emrgnt-cmplxty

Thanks, that's Apache 2.0 - I uploaded the license.

Hi @emrgnt-cmplxty, thank you so much for all of the contributions!! I will try to reduce the inference-time memory requirements for the retrieval part following your suggestions & snippets!

Closing this issue now, but feel free to reopen it if you want! Thank you so much for such an amazing follow-up, @emrgnt-cmplxty!


Hi, I guess the link is out of date, could you please update it? I would like to look into the details of the memory optimization during the wiki-reader and FAISS embedding process. @emrgnt-cmplxty