Reason-Wang / NAT

The official implementation of paper "Learning From Failure: Integrating Negative Examples when Fine-tuning Large Language Models as Agents"

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Learning From $\textcolor{BrickRed}{\mathbf{Failure}}$: Integrating Negative Examples when Fine-tuning Large Language Models as Agents

🤗Model (NAT-math-13b) • 📄Paper (arxiv) • 🤗Dataset (math, qa)

Illustration

NAT (Negative-Aware Training) aims to push the limit of data efficiency by incorporating negative trajectories when fine-tuning Large Language Models as Agents. We find models' performance increases as we include more negative samples. Besides, our method (NAT) shows superior performance compared to directly including negative samples.

Main Results

Figure 1 and Figure 2 show the result on math and question-answering tasks. For math, our method continues to enhance the performance when incorporating more negative samples. For both math and QA, our method achieves better results.

NATInstruct

We use GPT-3.5 to collect trajectories of math and question-answering tasks. Detailed statistics can be found in Table 1. It is also available at 🤗NATInstruct.

NAT Models

Our NAT models are trained with negative-aware training method based on Llama-2-chat. Each model is trained with both positive and negative samples. All model links can be found in Table 1.

🤗Model 📂Dataset(#Num)
NAT-math-7b gsm8k positive(5,100)
NAT-math-13b gsm8k negative(12,150)
NAT-qa-7b hotpotqa positive(2,147)
NAT-qa-13b hotpotqa negative(2,425)
NAT-strategy-qa-7b strategyqa positive(1,070)
NAT-strategy-qa-13b strategyqa negative(1,263)

Table 1: Models and datasets.

Interactive

You can try our model using the interactive script:

python -m utils.interactive \
  --model_name_or_path reasonwang/NAT-math-7b \
  --task math

Training

Run the following command or scripts/train.sh to train a Llama-2-chat model on gsm8k task. Important arguments are:

  • pos_path and neg_path: specify the path of positive trajectories and negative trajectories;
  • pos_num and neg_num: number of positive and negative trajectories;
torchrun --nproc_per_node=4 --nnodes=1 --master_port=29522 train/train.py \
  --model_name_or_path meta-llama/Llama-2-7b-chat-hf \
  --task_name gsm8k --template zero-shot-target-aware \
  --pos_path data/dataset/gsm8k/gsm8k_gpt-3.5_positive.json \
  --neg_path data/dataset/gsm8k/gsm8k_gpt-3.5_negative.json \
  --neg_num 10000 --prompt_path prompts/gsm8k/gsm8k_conversation.json \
  --question_path data/dataset/gsm8k/train.json \
  --deepspeed configs/deepspeed_z3_config.json \
  --output_dir data/checkpoints/NAT-7b-math \
  --num_train_epochs 2 --per_device_train_batch_size 2 \
  --gradient_accumulation_steps 8 --save_strategy no \
  --learning_rate 5e-5 --weight_decay 0.0 --warmup_ratio 0.03 \
  --lr_scheduler_type cosine --logging_strategy steps \
  --logging_steps 4 --model_max_length 4096 --bf16 True

Evaluation

Run the following command or scripts/evaluate.sh to evaluate a model.

python -m data.generate --model reasonwang/NAT-math-7b --template zero-shot-target-aware --temperature 0.2 --task_name gsm8k --task_split test

Citation

If you find our work or code helpful, please consider citing:

@misc{wang2024learning,
      title={Learning From Failure: Integrating Negative Examples when Fine-tuning Large Language Models as Agents}, 
      author={Renxi Wang and Haonan Li and Xudong Han and Yixuan Zhang and Timothy Baldwin},
      year={2024},
      eprint={2402.11651},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

About

The official implementation of paper "Learning From Failure: Integrating Negative Examples when Fine-tuning Large Language Models as Agents"


Languages

Language:Python 97.7%Language:Shell 2.3%