Residual Prompt Tuning

This repository contains the original implementation of "Residual Prompt Tuning: Improving Prompt Tuning with Residual Reparameterization" (Findings of ACL 2023) by Anastasia Razdaibiedina, Yuning Mao, Rui Hou, Madian Khabsa, Mike Lewis, Jimmy Ba, and Amjad Almahairi.

🎊 Our work was accepted to Findings of ACL 2023!

Table of contents

  • Overview
  • Installation
  • Training
  • License

Overview

We introduce Residual Prompt Tuning, a simple and efficient method that significantly improves the performance and stability of prompt tuning. We propose to reparameterize soft prompt embeddings using a shallow network with a residual connection.

This reparameterization gives the model more flexibility to decide between using a separate embedding for each prompt token versus the representation obtained from the shared reparameterization network. After training is completed, the reparameterization network can be discarded and the original prompt embeddings can be replaced with their projections.
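
For intuition, here is a minimal PyTorch sketch of the idea. The class name, hyperparameter values, and exact MLP shape are illustrative assumptions, not this repo's actual configuration:

import torch
import torch.nn as nn

class ResidualReparamPrompt(nn.Module):
    """Soft prompt reparameterized by a shallow MLP with a skip
    connection. A sketch for intuition only; names and sizes are
    illustrative."""

    def __init__(self, prompt_len=10, embed_dim=768, bottleneck=400):
        super().__init__()
        # Trainable soft prompt embeddings, one row per prompt token.
        self.prompt = nn.Parameter(torch.randn(prompt_len, embed_dim) * 0.02)
        # Shallow reparameterization network, shared across prompt tokens.
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, embed_dim),
        )

    def forward(self):
        # Residual connection: P' = MLP(P) + P. The skip term lets each
        # token keep its own embedding when the shared projection is
        # unhelpful.
        return self.mlp(self.prompt) + self.prompt

# The reparameterized prompt is prepended to the input embeddings of a
# frozen language model; only `prompt` and `mlp` receive gradients.
soft_prompt = ResidualReparamPrompt()
print(soft_prompt().shape)  # torch.Size([10, 768])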

Our codebase includes PyTorch implementations of:

  • original prompt tuning (following Lester et al., 2021)
  • residual prompt tuning (our modification)
  • full model tuning

Installation

Clone this repo and set up the conda environment:

git clone https://github.com/arazd/ResidualPrompts
cd ResidualPrompts
conda env create -f environment.yaml
conda activate nlp

Training

An example of training a 10-token soft prompt on the WSC task, using the T5-base model and residual reparameterization with the MLP1 network type:

python train.py --task wsc --prefix_MLP MLP1 \
    --lr 0.3 --freeze_weights 1 --freeze_except xxxx \
    --model_name t5-base --early_stopping 1 \
    --test_eval_after_every_task 1 --select_k_per_class -1 \
    --batch_size 8 --num_epochs 20 --prefix_len 10 \
    --save_dir /home/%u/my_dir/ --save_name my_model_folder
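
As noted in the Overview, the reparameterization network is only needed during training. A minimal sketch of the folding step, continuing the hypothetical ResidualReparamPrompt module above (fold_prompt is an illustrative helper, not part of this repo's API):

import torch

@torch.no_grad()
def fold_prompt(module):
    # Compute P' = MLP(P) + P once and detach it. This static tensor
    # replaces the original prompt embeddings, so the shallow network
    # can be discarded at inference time.
    return module().detach().clone()

# After training: final_prompt = fold_prompt(trained_prompt_module)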

License

Apache License 2.0

