argilla-io / distilabel-spin-dibt

Repository containing the SPIN experiments on the DIBT 10k ranked prompts


Built with Distilabel

distilabel-spin-dibt

SPIN experiments on the DIBT 10k ranked prompts.

This repository contains the instructions to run SPIN on a subset of the DIBT/10k_prompts_ranked dataset: the prompts with avg_rating>=4 and num_response>1, for a total of 1832 records (later split into 1648 for training and 184 for testing).
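The selection and split can be illustrated with a small pure-Python sketch. The field names `avg_rating` and `num_responses` are assumptions about the dataset schema, and the 10% test fraction is inferred from the counts rather than taken from the scripts: rounding 10% of 1832 up gives 184 test records and leaves 1648 for training, the same ceil-for-test convention used by scikit-learn's `train_test_split`.

```python
import math


def select_top_prompts(records):
    """Keep prompts with avg_rating >= 4 that were rated more than once.

    Field names are assumed; the real dataset columns may differ.
    """
    return [
        r for r in records
        if r["avg_rating"] >= 4 and r["num_responses"] > 1
    ]


def split_sizes(n_samples, test_size=0.1):
    """Return (n_train, n_test), rounding the test split up."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


# Hypothetical toy records, only to show the filter.
records = [
    {"prompt": "a", "avg_rating": 4.5, "num_responses": 2},
    {"prompt": "b", "avg_rating": 5.0, "num_responses": 1},  # rated once: dropped
    {"prompt": "c", "avg_rating": 3.0, "num_responses": 3},  # low rating: dropped
]
print([r["prompt"] for r in select_top_prompts(records)])  # ['a']
print(split_sizes(1832))  # (1648, 184)
```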

It references all the scripts used to generate the datasets, the configuration files used for training, and the setup used to run the model. The datasets were generated with distilabel==0.6.0.

SPIN needs the training data in a specific format, in which the "real" responses act as the reference the model learns to move towards. Since the dataset only contained prompts, we generated these real responses with mistral-large. The successive iterations of the "generated" datasets were created with distilabel and vLLM on 2 A100 GPUs. The second GPU was used only for speed: the scripts should also work with less compute by adjusting the --cuda-devices and --batch-size arguments accordingly.
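For context, the SPIN paper trains the model to raise the likelihood of the real response relative to its own generated response, both measured against the frozen model from the previous iteration, using a logistic loss log(1 + exp(-t)) on the difference of log-ratios. The sketch below is written from the paper's formulation, not taken from the SPINTrainer code; `beta` plays the role of the paper's scaling factor λ, and the inputs are assumed to be summed log-probabilities of each full response.

```python
import math


def spin_loss(policy_real_logp, ref_real_logp,
              policy_gen_logp, ref_gen_logp, beta=0.1):
    """Per-example SPIN logistic loss (sketch).

    policy_*: log p(response | prompt) under the model being trained.
    ref_*:    the same quantity under the frozen previous-iteration model.
    """
    real_margin = policy_real_logp - ref_real_logp
    gen_margin = policy_gen_logp - ref_gen_logp
    t = beta * (real_margin - gen_margin)
    # log(1 + exp(-t)), written to avoid overflow for large |t|
    if t > 0:
        return math.log1p(math.exp(-t))
    return -t + math.log1p(math.exp(t))


# With no preference either way the loss is log(2); it drops as the model
# favours the real response over its own generation.
print(spin_loss(0.0, 0.0, 0.0, 0.0))
print(spin_loss(-1.0, -2.0, -3.0, -2.0))
```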

Contribute to the DIBT prompt collective

This work shows the huge benefit of collecting high-quality prompts for LLM fine-tuning. If you want to support the OSS community with larger datasets, contribute to the Prompt Collective initiative.

Prepare the data

First, we create the reference dataset, with the real responses generated by mistral-large, using the following script:

Experiment top prompts

The following are the steps to prepare the training data for SPIN, and the resulting datasets:

SPIN iter 0
  • generate_iter_spin.py: script to generate the initial "generated" responses from the SFT model that will then be fine-tuned.

    Dataset: argilla/10k_prompts_ranked_sft_zephyr

    Run the following:

    python generate_iter_spin.py \
        --hf-apikey $HF_API_TOKEN \
        --source-dataset "DIBT/10k_prompts_ranked" \
        --new-dataset "argilla/10k_prompts_ranked_sft_zephyr" \
        --model-name "alignment-handbook/zephyr-7b-sft-full" \
        --batch-size 128 \
        --cuda-devices "0,1"
  • prepare_for_training.py: generates the dataset that is ingested directly by the SPINTrainer.

    Dataset: argilla/10k_prompts_top_SPIN_iter0

    Run the following:

    python prepare_for_training.py \
        --portion top \
        --target-dataset argilla/10k_prompts_SPIN_iter0_zephyr_top
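Adapting the generation step to other hardware only requires changing --cuda-devices and --batch-size. As a hypothetical sketch (not the repository's code), the flags shown above could be parsed and used as follows; mapping the device list to a GPU count, e.g. for vLLM's `tensor_parallel_size`, is an assumption about how the script uses it.

```python
import argparse


def build_parser():
    # Hypothetical: mirrors only the flags that appear in the commands above;
    # the real generate_iter_spin.py may define or use them differently.
    parser = argparse.ArgumentParser(description="SPIN generation (sketch)")
    parser.add_argument("--hf-apikey", required=True)
    parser.add_argument("--source-dataset", required=True)
    parser.add_argument("--new-dataset", required=True)
    parser.add_argument("--model-name", required=True)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--cuda-devices", default="0")
    return parser


def iter_batches(items, batch_size):
    """Yield consecutive chunks of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


args = build_parser().parse_args([
    "--hf-apikey", "dummy-token",
    "--source-dataset", "DIBT/10k_prompts_ranked",
    "--new-dataset", "argilla/10k_prompts_ranked_sft_zephyr",
    "--model-name", "alignment-handbook/zephyr-7b-sft-full",
    "--batch-size", "128",
    "--cuda-devices", "0,1",
])

num_gpus = len(args.cuda_devices.split(","))  # 2 GPUs
prompts = [f"prompt {i}" for i in range(1648)]  # toy prompt list
num_batches = sum(1 for _ in iter_batches(prompts, args.batch_size))
print(num_gpus, num_batches)  # 2 13
```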
SPIN iter 1
  • generate_iter_spin.py

    Regenerates the "generated" responses from the model in the previous iteration:

    python generate_iter_spin.py \
        --hf-apikey $HF_API_TOKEN \
        --source-dataset "argilla/10k_prompts_SPIN_iter0_zephyr_top" \
        --new-dataset "argilla/10k_prompts_SPIN_iter1_zephyr_top_generated" \
        --model-name "plaguss/zephyr-7b-spin-iter0-v0" \
        --batch-size 128 \
        --cuda-devices "0,1"

    Dataset: argilla/10k_prompts_top_SPIN_iter1_generated

  • transform_iter_generated.py

    The script transforms the generated responses into the format expected by the SPINTrainer:

    python transform_iter_generated.py \
        --real-dataset "argilla/10k_prompts_ranked_with_responses" \
        --generated-dataset "argilla/10k_prompts_SPIN_iter1_zephyr_top_generated" \
        --new-dataset "argilla/10k_prompts_SPIN_iter1_zephyr_top"
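Conceptually, this step joins the fixed mistral-large responses with the newly generated ones for the same prompt. A minimal pure-Python sketch, where the input column names (`prompt`, `response`) and the output layout (a `real` and a `generated` conversation in chat-message format) are assumptions rather than the script's actual schema:

```python
def to_spin_records(real_rows, generated_rows):
    """Pair the real and generated response of each prompt (sketch)."""
    generated_by_prompt = {row["prompt"]: row["response"] for row in generated_rows}
    records = []
    for row in real_rows:
        prompt = row["prompt"]
        if prompt not in generated_by_prompt:
            continue  # no generated counterpart for this prompt: skip it
        records.append({
            "real": [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": row["response"]},
            ],
            "generated": [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": generated_by_prompt[prompt]},
            ],
        })
    return records


real_rows = [{"prompt": "p1", "response": "r1"}, {"prompt": "p2", "response": "r2"}]
generated_rows = [{"prompt": "p1", "response": "g1"}]
print(to_spin_records(real_rows, generated_rows))
```

Because the real responses never change, only the generated side is refreshed at each iteration, which is why every iteration passes the same --real-dataset.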
SPIN iter 2
  • generate_iter_spin.py

    Regenerates the "generated" responses from the model in the previous iteration:

    python generate_iter_spin.py \
        --hf-apikey $HF_API_TOKEN \
        --source-dataset "argilla/10k_prompts_SPIN_iter0_zephyr_top" \
        --new-dataset "argilla/10k_prompts_SPIN_iter2_zephyr_top_generated" \
        --model-name "plaguss/zephyr-7b-spin-iter1-v0" \
        --batch-size 128 \
        --cuda-devices "0,1"

    Dataset: argilla/10k_prompts_top_SPIN_iter2_generated

  • transform_iter_generated.py

    The script transforms the generated responses into the format expected by the SPINTrainer:

    python transform_iter_generated.py \
        --real-dataset "argilla/10k_prompts_ranked_with_responses" \
        --generated-dataset "argilla/10k_prompts_SPIN_iter2_zephyr_top_generated" \
        --new-dataset "argilla/10k_prompts_SPIN_iter2_zephyr_top"
SPIN iter 3
  • generate_iter_spin.py

    Regenerates the "generated" responses from the model in the previous iteration:

    python generate_iter_spin.py \
        --hf-apikey $HF_API_TOKEN \
        --source-dataset "argilla/10k_prompts_SPIN_iter0_zephyr_top" \
        --new-dataset "argilla/10k_prompts_SPIN_iter3_zephyr_top_generated" \
        --model-name "plaguss/zephyr-7b-spin-iter2-v0" \
        --batch-size 128 \
        --cuda-devices "0,1"

    Dataset: argilla/10k_prompts_top_SPIN_iter3_generated

  • transform_iter_generated.py

    The script transforms the generated responses into the format expected by the SPINTrainer:

    python transform_iter_generated.py \
        --real-dataset "argilla/10k_prompts_ranked_with_responses" \
        --generated-dataset "argilla/10k_prompts_SPIN_iter3_zephyr_top_generated" \
        --new-dataset "argilla/10k_prompts_SPIN_iter3_zephyr_top"
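Iterations 1 to 3 differ only in the iteration number: each one generates with the model fine-tuned in the previous iteration and reuses the same source and real datasets. The helper below is hypothetical (it is not part of the repository); it rebuilds the argument lists from the dataset and model names used in the commands above.

```python
import os

SOURCE_DATASET = "argilla/10k_prompts_SPIN_iter0_zephyr_top"
REAL_DATASET = "argilla/10k_prompts_ranked_with_responses"


def iteration_commands(k, batch_size=128, cuda_devices="0,1"):
    """Return (generate_argv, transform_argv) for SPIN iteration k >= 1."""
    if k < 1:
        raise ValueError("iteration 0 uses the SFT model and prepare_for_training.py")
    generated = f"argilla/10k_prompts_SPIN_iter{k}_zephyr_top_generated"
    generate = [
        "python", "generate_iter_spin.py",
        "--hf-apikey", os.environ.get("HF_API_TOKEN", ""),
        "--source-dataset", SOURCE_DATASET,
        "--new-dataset", generated,
        "--model-name", f"plaguss/zephyr-7b-spin-iter{k - 1}-v0",
        "--batch-size", str(batch_size),
        "--cuda-devices", cuda_devices,
    ]
    transform = [
        "python", "transform_iter_generated.py",
        "--real-dataset", REAL_DATASET,
        "--generated-dataset", generated,
        "--new-dataset", f"argilla/10k_prompts_SPIN_iter{k}_zephyr_top",
    ]
    return generate, transform


for k in (1, 2, 3):
    generate, _ = iteration_commands(k)
    # Each argv could be passed to subprocess.run(argv, check=True).
    print(k, generate[generate.index("--model-name") + 1])
```

The commands cannot be chained blindly: the model from iteration k-1 must already be trained and pushed before the generation step of iteration k runs.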

Fine tune using SPIN

The following steps are almost a copy of those in the SPIN repository; see there for more information.

Runpod

We used Runpod with the following setup:

  • 4x A100 80GB GPUs.
  • 500GB container and volume.
  • Base image with CUDA 12.1.

Once the pod is running

These are the steps outlined in the SPIN repo; they can also be run all at once with scripts/setup.sh. First, install PyTorch built for CUDA 12.1:

pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cu121

Clone and install the repo from source:

git clone https://github.com/uclaml/SPIN.git && cd SPIN

Install the package and flash-attn:

python -m pip install .
python -m pip install flash-attn==2.5.3 --no-build-isolation

Log in to Hugging Face:

huggingface-cli login --token $HF_API_TOKEN

Log in to wandb:

pip install wandb
wandb login $WANDB_TOKEN

Then set the WANDB environment variables to keep track of the experiments:

export WANDB_ENTITY="argilla-io"
export WANDB_PROJECT="dibt-spin-zephyr"
export WANDB_NAME="zephyr-7b-spin-iter0-v0"
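The WANDB_NAME above matches the name of the model pushed for iteration 0 (plaguss/zephyr-7b-spin-iter0-v0), so it should change for every iteration's run. A small hypothetical helper, assuming run names keep following the zephyr-7b-spin-iter{k}-v0 pattern, could build the variables and pass them to the training script:

```python
import os
import subprocess


def wandb_env_for_iteration(k, entity="argilla-io", project="dibt-spin-zephyr"):
    """Return the W&B environment variables for SPIN iteration k (sketch)."""
    return {
        "WANDB_ENTITY": entity,
        "WANDB_PROJECT": project,
        # Hypothetical naming: keep the run name aligned with the model name.
        "WANDB_NAME": f"zephyr-7b-spin-iter{k}-v0",
    }


def run_finetune(k):
    """Launch scripts/finetune.sh with the W&B variables for iteration k."""
    env = {**os.environ, **wandb_env_for_iteration(k)}
    subprocess.run(["bash", "scripts/finetune.sh"], env=env, check=True)


print(wandb_env_for_iteration(1)["WANDB_NAME"])  # zephyr-7b-spin-iter1-v0
```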

Next, replace the SPIN repo's config file and scripts/finetune.sh with the ones for the model you want to run, and start training:

bash scripts/finetune.sh

Weights and Biases runs

DIBT 10k *Top* subset


License: Apache License 2.0

