LeeSureman / ConRetriever

Use contrastive learning to train a large language model (LLM) as a retriever

ConRetriever

ConRetriever is designed to use contrastive learning to train a large language model (LLM) as a retriever. This repository provides the training and evaluation code as well as the code for generating synthetic data.

Install

  1. Clone this repository and navigate to the ConRetriever folder
git clone url_path_to_this_repo
cd ConRetriever
  2. Install the required packages
pip install --upgrade pip
pip install -r requirements.txt

Train_Eval

For simplicity, the training and evaluation code are in a single script, train_eval_llm_retriever.sh.

The script uses mistralai/Mistral-7B-v0.1 as an example. Before training, two important steps must be completed: preparing the data and determining the transformer layer class.

Data Preparation

For training data, we use the following format:

# Assuming there are data for two tasks in a folder named `demo`
demo
|--- task1.jsonl
|--- task2.jsonl

# Each task file uses the format below (if the task is not "synthetic", the instruction field can be omitted)
"""
{
    "query": "What are some popular Italian pasta recipes?",
    "positive": [
        "Italian cuisine is known for its diverse range of pasta dishes. From classic favorites like spaghetti carbonara and fettuccine alfredo to regional specialties like lasagna and ravioli, Italian pasta recipes offer a wide variety of flavors and ingredients. One popular recipe is penne arrabbiata, which is made with penne pasta, a spicy tomato sauce, garlic, and red chili flakes. Another delicious option is tortellini with pesto sauce, where homemade tortellini pasta is filled with a mixture of cheese and served with a flavorful basil pesto sauce. For seafood lovers, linguine with clams is a must-try dish, featuring linguine pasta tossed with fresh clams, garlic, white wine, and parsley. Additionally, pasta primavera is a delightful vegetarian option made with mixed vegetables, cream, and Parmesan cheese. These are just a few examples of the countless Italian pasta recipes that you can explore and enjoy."
    ],
    "negative": [
        "Italian cuisine is famous for its delectable pasta dishes. One of the most popular pasta recipes is spaghetti carbonara, which originated in Rome and features pasta tossed with a creamy egg and pancetta sauce. Another classic Italian dish is fettuccine alfredo, where fettuccine noodles are coated in a rich Parmesan cheese sauce. Lasagna is another beloved Italian pasta dish, made with layers of pasta, meat sauce, and cheese. Additionally, ravioli is a traditional Italian pasta dish consisting of stuffed pasta pockets served with various sauces. Italian pasta recipes are loved worldwide for their simplicity, fresh ingredients, and bold flavors."
    ],
    "instruction": "Given a food cuisine, retrieve recipes or restaurant reviews from that cuisine. "
}
"""

Please check the sample data for more information: demo.jsonl and synthetic.jsonl.
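
Before training, it can help to check that every record has the expected fields. The following is a minimal sketch, not part of this repository: the helper names validate_record and validate_jsonl are hypothetical, and it assumes each line of a .jsonl file holds one JSON object and that "positive" and "negative" are non-empty lists of strings, as in the example above.

```python
import json


def validate_record(record, synthetic: bool = False) -> list[str]:
    """Return a list of problems found in one training record (empty if none)."""
    if not isinstance(record, dict):
        return ["record must be a JSON object"]
    problems = []
    query = record.get("query")
    if not isinstance(query, str) or not query:
        problems.append("'query' must be a non-empty string")
    for field in ("positive", "negative"):
        value = record.get(field)
        is_str_list = isinstance(value, list) and all(isinstance(v, str) for v in value)
        # Assumption: at least one passage is expected in each list.
        if not is_str_list or not value:
            problems.append(f"'{field}' must be a non-empty list of strings")
    # The instruction field may be omitted unless the task is "synthetic".
    if synthetic and not isinstance(record.get("instruction"), str):
        problems.append("'instruction' is required for synthetic tasks")
    return problems


def validate_jsonl(path: str, synthetic: bool = False) -> dict[int, list[str]]:
    """Map 1-based line numbers to problems for every invalid line in a file."""
    errors = {}
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # skip blank lines
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                errors[line_no] = [f"invalid JSON: {exc.msg}"]
                continue
            problems = validate_record(record, synthetic)
            if problems:
                errors[line_no] = problems
    return errors
```

For example, validate_jsonl("demo/task1.jsonl") returns an empty dict when every line passes these checks.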

Then, set the sampling ratio, query type, and message type for each task in task_config.py.

Determine the transformer layer

Set the parameter fsdp_transformer_layer_cls_to_wrap in train_eval_llm_retriever.sh to the class name of the model's transformer layer. For the Mistral model, this is MistralDecoderLayer. To find the correct value, print the model with the following code and look for the layer class repeated inside the ModuleList. The same method works for other models.

from transformers import AutoModelForCausalLM

model_name_or_path = 'mistralai/Mistral-7B-v0.1'
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
print(model)

"""Output:
MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
"""
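
If you prefer not to read the printed structure by eye, the class name can be extracted from it. The following is a heuristic sketch, not part of this repository: the function guess_layer_class is hypothetical, and it assumes that the most frequently repeated numbered submodule in the printed model is the transformer layer, which holds for the Mistral output above but should be checked for other models.

```python
import re
from collections import Counter

# Matches numbered submodule entries in a printed PyTorch model, e.g.
#   "(0-31): 32 x MistralDecoderLayer("   (collapsed form)
#   "(0): MistralDecoderLayer("           (one entry per layer)
_ENTRY = re.compile(r"^\s*\((\d+)(?:-(\d+))?\): (?:\d+ x )?(\w+)\(", re.MULTILINE)


def guess_layer_class(model_repr: str) -> str:
    """Return the class name that is repeated most often as a numbered submodule."""
    counts = Counter()
    for first, last, cls_name in _ENTRY.findall(model_repr):
        # A collapsed entry "(0-31)" stands for 32 layers; a plain "(0)" for one.
        span = int(last) - int(first) + 1 if last else 1
        counts[cls_name] += span
    if not counts:
        raise ValueError("no numbered submodules found in the model repr")
    return counts.most_common(1)[0][0]


# Excerpt of the output shown above; with a loaded model, pass repr(model) instead.
SAMPLE_REPR = """MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
      )
    )
  )
)"""

print(guess_layer_class(SAMPLE_REPR))  # prints "MistralDecoderLayer"
```

The returned name is the value to use for fsdp_transformer_layer_cls_to_wrap, provided the heuristic picked the decoder layer and not some other repeated block.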

Synthesis

To train a retriever on synthetic data, first generate the data by running 1) brainstorm_task.sh and then 2) generate_examples.sh. The generation procedure follows [Wang et al., 2023].

Acknowledgement

For training, we use GradCache to enable contrastive learning with large batch sizes.

For evaluation, we use e5 to evaluate the performance of the model.
