LoftQ helps you fine-tune LLMs with limited GPUs. 🚀 Given a pre-trained weight W, LoftQ finds a good quantized LoRA initialization: a quantized backbone Q and LoRA adapters A and B.
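For reference, the LoftQ paper frames this as jointly approximating W with the quantized backbone plus a low-rank term. In our reading of the paper's notation ($W \in \mathbb{R}^{d_1 \times d_2}$, $N$ is the bit width, $r$ is the LoRA rank, and $\mathbb{R}_N$ denotes N-bit quantized values), the objective is:

$$
\min_{Q,\,A,\,B}\ \bigl\lVert W - Q - AB^{\top} \bigr\rVert_F,
\qquad Q \in \mathbb{R}_N^{d_1 \times d_2},\ A \in \mathbb{R}^{d_1 \times r},\ B \in \mathbb{R}^{d_2 \times r}
$$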
This repo implements the paper 🔗: LoftQ: LoRA-Fine-Tuning-Aware Quantization for Large Language Models.
Our models are available on the Hugging Face Hub 🤗 LoftQ.
We use bitsandbytes to implement the quantization. This package only supports CUDA >= 11.0 and does not support CPU. However, we also provide fake quantization, which enables fast, parallel training if you have sufficient GPU resources.
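Fake quantization simulates low-bit weights while keeping them in a full-precision dtype, so no low-bit kernels are needed. The sketch below illustrates the general idea with a simple per-tensor absmax uniform quantizer; the function name `fake_quantize` and this quantizer choice are assumptions for illustration, not this repo's implementation.

```python
import numpy as np


def fake_quantize(x: np.ndarray, bits: int) -> np.ndarray:
    """Quantize-then-dequantize x with a symmetric absmax uniform quantizer.

    Hypothetical illustration: the result keeps x's float dtype, but every
    entry is snapped to one of at most 2**bits - 1 evenly spaced levels.
    """
    assert bits >= 2, "the symmetric grid needs at least 2 bits"
    qmax = 2 ** (bits - 1) - 1                 # e.g. 7 for 4 bits, 1 for 2 bits
    absmax = float(np.max(np.abs(x)))
    scale = absmax / qmax if absmax > 0 else 1.0
    codes = np.clip(np.round(x / scale), -qmax, qmax)  # integer codes
    return (codes * scale).astype(x.dtype)             # back to full precision


rng = np.random.default_rng(0)
w = rng.standard_normal((64, 64)).astype(np.float32)
w_fq = fake_quantize(w, bits=4)
print(w_fq.dtype, len(np.unique(w_fq)))  # float32 dtype, few distinct values
```

Because the dequantized weights are ordinary float tensors, standard full-precision training code can run on them; the trade-off is that memory is not reduced the way real low-bit storage would reduce it.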
```bash
pip install -r requirements.txt
```
Applying LoftQ takes two steps:
1. Apply LoftQ to a full-precision pre-trained weight and save the result.
2. Load the LoftQ initialization and train.
For step 1, we provide off-the-shelf LoftQ initializations (see the supported model list) on the Hugging Face Hub LoftQ. If you want to do it yourself, jump to LoftQ DIY.
For step 2, below is an example of loading 4-bit Mistral-7B with rank-64 LoRA adapters from the Hugging Face Hub.
```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Find available MODEL_IDs at https://huggingface.co/LoftQ
MODEL_ID = "LoftQ/Mistral-7B-v0.1-4bit-64rank"

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,  # you may need to change this for other models
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # bfloat16 is recommended
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type='nf4',
    ),
)
peft_model = PeftModel.from_pretrained(
    base_model,
    MODEL_ID,
    subfolder="loftq_init",  # LoftQ-initialized LoRA adapters
    is_trainable=True,       # keep the adapters trainable for fine-tuning
)

# Do training with peft_model ...
```
We provide quantize_save.py as an example to apply LoftQ with different bits (`--bits`), ranks (`--rank`), and alternating steps (`--iter`, a hyper-parameter in LoftQ; see Algorithm 1 in the LoftQ paper). Currently, this example supports `llama-2`, `falcon`, `mistral`, `bart`, `t5`, `deberta`, `bert`, and `roberta`.
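To make `--iter` concrete, here is a minimal NumPy sketch of the alternating idea in Algorithm 1 of the LoftQ paper: each step quantizes the part of W that the low-rank term does not explain, then refits A and B with a truncated SVD of the quantization residual. Assumptions: a simple per-tensor uniform quantizer stands in for the NF4 and uniform quantizers used in the repo, and the names `uniform_quantize` and `loftq_init` are hypothetical; this is not the code in quantize_save.py.

```python
import numpy as np


def uniform_quantize(x: np.ndarray, bits: int) -> np.ndarray:
    """Stand-in quantizer (assumption): symmetric absmax, quantize-dequantize."""
    qmax = 2 ** (bits - 1) - 1
    absmax = float(np.max(np.abs(x)))
    scale = absmax / qmax if absmax > 0 else 1.0
    return np.clip(np.round(x / scale), -qmax, qmax) * scale


def loftq_init(w: np.ndarray, bits: int, rank: int, num_iter: int):
    """Alternate between quantization and low-rank fitting.

    Returns (q, a, b): q is quantized and a @ b.T is a rank-`rank`
    correction, so that q + a @ b.T approximates w.
    """
    assert num_iter >= 1
    a = np.zeros((w.shape[0], rank))  # A_0 = 0
    b = np.zeros((w.shape[1], rank))  # B_0 = 0
    for _ in range(num_iter):
        # 1) Quantize what the current low-rank term does not explain.
        q = uniform_quantize(w - a @ b.T, bits)
        # 2) Best rank-r approximation of the quantization residual (SVD).
        u, s, vt = np.linalg.svd(w - q, full_matrices=False)
        sqrt_s = np.sqrt(s[:rank])
        a = u[:, :rank] * sqrt_s   # shape (d1, r)
        b = vt[:rank].T * sqrt_s   # shape (d2, r)
    return q, a, b


rng = np.random.default_rng(0)
w = rng.standard_normal((128, 96))
q, a, b = loftq_init(w, bits=2, rank=16, num_iter=5)
plain_err = np.linalg.norm(w - uniform_quantize(w, 2))
loftq_err = np.linalg.norm(w - q - a @ b.T)
print(f"quantize only: {plain_err:.2f}  LoftQ: {loftq_err:.2f}")
```

With `num_iter=1` the joint error can never exceed the quantize-only error, because the truncated SVD is the best rank-r fit to the residual (and the zero matrix is also a rank-r candidate). Extra iterations aim to lower the error further, but this simplified sketch does not guarantee that each step is monotone.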
Below is an example of obtaining 4-bit LLAMA-2-7b with rank-16 LoRA adapters using 5 alternating steps.
```bash
SAVE_DIR="model_zoo/loftq/"
# --model_name_or_path: high-precision model id on the HF Hub
# --token: your HF token if the model is private, e.g., llama-2
python quantize_save.py \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --token HF_TOKEN \
    --bits 4 \
    --iter 5 \
    --rank 16 \
    --save_dir $SAVE_DIR
```
The above command creates the model directory under `$SAVE_DIR`.
Specifically, the model directory is named as follows:
```python
MODEL_DIR = SAVE_DIR + f"{args.model_name_or_path.split('/')[-1]}-{args.bits}bit-{args.rank}rank"
```
In this example, `MODEL_DIR="model_zoo/loftq/Llama-2-7b-hf-4bit-16rank"`, where the backbone is stored in `$MODEL_DIR` and the LoRA adapters are in the subfolder `$MODEL_DIR/loftq_init`.
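If you script around the saved checkpoints, a small helper can rebuild both paths from the same arguments. This is a sketch: `loftq_model_dirs` is a hypothetical name, not a function exported by quantize_save.py, and it follows the example directory name `Llama-2-7b-hf-4bit-16rank` given in this README.

```python
def loftq_model_dirs(save_dir: str, model_name_or_path: str, bits: int, rank: int):
    """Return (backbone_dir, adapter_dir) for a LoftQ checkpoint.

    Hypothetical helper mirroring the MODEL_DIR naming in this README.
    """
    model_name = model_name_or_path.split("/")[-1]  # drop the HF org prefix
    model_dir = f"{save_dir.rstrip('/')}/{model_name}-{bits}bit-{rank}rank"
    adapter_dir = f"{model_dir}/loftq_init"          # LoRA adapters live here
    return model_dir, adapter_dir


model_dir, adapter_dir = loftq_model_dirs(
    "model_zoo/loftq/", "meta-llama/Llama-2-7b-hf", bits=4, rank=16
)
print(model_dir)    # model_zoo/loftq/Llama-2-7b-hf-4bit-16rank
print(adapter_dir)  # model_zoo/loftq/Llama-2-7b-hf-4bit-16rank/loftq_init
```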
Similar to loading from the Hugging Face Hub, we only need to change `MODEL_ID` to `MODEL_DIR`.
```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

MODEL_DIR = "model_zoo/loftq/Llama-2-7b-hf-4bit-16rank"

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type='nf4',
    ),
)
peft_model = PeftModel.from_pretrained(
    base_model,
    MODEL_DIR,
    subfolder="loftq_init",
    is_trainable=True,
)

# Do training with peft_model ...
```
We also provide an example of fine-tuning LLAMA-2-7b with LoftQ on GSM8K.
```bash
python train_gsm8k.py \
    --model_name_or_path LoftQ/Llama-2-7b-hf-4bit-64rank \
    --learning_rate 3e-4 \
    --seed 11 \
    --expt_name gsm8k_llama2_7b_4bit_64rank_loftq \
    --output_dir exp_results/ \
    --num_train_epochs 6 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "epoch" \
    --weight_decay 0.1 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --do_train \
    --report_to tensorboard
```
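With Hugging Face Trainer-style arguments, the batch flags multiply: each optimizer update sees per-device batch size × gradient accumulation steps × number of data-parallel devices. The helper below is a hypothetical illustration of that arithmetic; the GPU counts are assumptions, since the command does not fix them.

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_gpus: int) -> int:
    """Examples contributing to one optimizer update under data parallelism."""
    return per_device_batch * grad_accum_steps * num_gpus


# Values from the command above; the number of GPUs is an assumption.
for gpus in (1, 4, 8):
    print(gpus, effective_batch_size(2, 8, gpus))  # prints 1 16, then 4 64, then 8 128
```

If you change the number of GPUs, adjusting `--gradient_accumulation_steps` is one way to keep the effective batch size (and thus the tuned learning rate) comparable.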
The following command evaluates on GSM8K using adapters we have already fine-tuned. They are stored in the `gsm8k` subfolder (`subfolder='gsm8k'`) of the target model on the LoftQ Hugging Face Hub.
```bash
python test_gsm8k.py \
    --model_name_or_path LoftQ/Llama-2-7b-hf-4bit-64rank \
    --batch_size 16
```
Training scripts for other tasks:
- GLUE: `glue/run_glue.py`
- Question Answering: `glue/run_qa.py`
- Summarization: `train_summarization.py`
- WikiText-2: `train_clm.py`
- GSM8K: `train_gsm8k.py`
More example scripts are in the `scripts` directory.
LLAMA-2 on WikiText-2 and GSM8K
Bit | WikiText-2 ppl ↓ (LLAMA-2-7b) | WikiText-2 ppl ↓ (LLAMA-2-13b) | GSM8K acc % (LLAMA-2-7b) | GSM8K acc % (LLAMA-2-13b) |
---|---|---|---|---|
16 | 5.08 | 5.12 | 36.9 | 43.1 |
4 | 5.24 | 5.16 | 35.0 | 45.0 |
3 | 5.63 | 5.13 | 32.9 | 44.4 |
2.5 | 5.78 | 5.22 | 31.1 | 41.1 |
2.25 | 6.13 | 5.45 | 26.5 | 38.1 |
2 | 7.85 | 7.69 | 20.9 | 25.4 |
Models are fine-tuned through causal language modeling on training sets and are tested on validation/test sets.
BART-large on CNN/DailyMail and XSum
Bit | Rank | XSum (ROUGE-1/2/L) | CNN/DailyMail (ROUGE-1/2/L) |
---|---|---|---|
Lead-3* | - | 16.30/1.60/11.95 | 40.42/17.62/36.67 |
16 | 16 | 43.95/20.72/35.68 | 45.03/21.84/42.15 |
4 | 16 | 44.51/21.14/36.18 | 43.96/21.06/40.96 |
2 | 16 | 40.81/17.85/32.80 | 42.52/19.81/39.51 |
16 | 8 | 43.40/20.20/35.20 | 44.72/21.58/41.84 |
4 | 8 | 44.08/20.72/35.89 | 43.81/20.95/40.84 |
2 | 8 | 39.63/16.65/31.62 | 42.24/19.44/29.04 |
*: Lead-3 uses the first 3 sentences of the document as the summary.
DeBERTa-V3-base on GLUE using Normal Float Datatype
Bit | Rank | MNLI (m/mm) | QNLI (Acc) | RTE (Acc) | SST (Acc) | MRPC (Acc) | CoLA (Mcc) | QQP (Acc) | STSB (P/S Corr) | SQuAD (EM/F1) | ANLI (Acc) |
---|---|---|---|---|---|---|---|---|---|---|---|
16 | 16 | 90.5/90.6 | 94.0 | 82.0 | 95.3 | 89.5/93.3 | 69.2 | 92.4/89.8 | 91.6/91.1 | 88.5/92.8 | 59.8 |
2 | 16 | 84.7/85.1 | 86.6 | 61.4 | 90.2 | 83.8/88.6 | 37.4 | 90.3/86.9 | 87.1/86.9 | 81.5/88.6 | 47.1 |
2 | 32 | 86.0/86.1 | 89.9 | 61.7 | 92.0 | 83.6/87.2 | 47.5 | 91.0/87.9 | 87.5/87.0 | 82.9/89.8 | 49.0 |
DeBERTa-V3-base on GLUE using Uniform Quantization Datatype
Bit | Rank | MNLI (m/mm) | QNLI (Acc) | RTE (Acc) | SST (Acc) | MRPC (Acc) | CoLA (Mcc) | QQP (Acc) | STSB (P/S Corr) | SQuAD (EM/F1) |
---|---|---|---|---|---|---|---|---|---|---|
16 | 16 | 90.5/90.6 | 94.0 | 82.0 | 95.3 | 89.5/93.3 | 69.2 | 92.4/89.8 | 91.6/91.1 | 88.5/92.8 |
2 | 16 | 87.3/87.1 | 90.6 | 61.1 | 94.0 | 87.0/90.6 | 59.1 | 90.9/88.0 | 87.9/87.6 | 84.4/91.2 |
2 | 32 | 88.0/88.1 | 92.2 | 63.2 | 94.7 | 87.5/91.2 | 60.5 | 91.3/88.3 | 89.5/89.2 | 85.2/91.6 |
```bibtex
@article{li2023loftq,
  title={Loftq: Lora-fine-tuning-aware quantization for large language models},
  author={Li, Yixiao and Yu, Yifan and Liang, Chen and He, Pengcheng and Karampatziakis, Nikos and Chen, Weizhu and Zhao, Tuo},
  journal={arXiv preprint arXiv:2310.08659},
  year={2023}
}
```
Model Name | Bits | Ranks |
---|---|---|
LLAMA-2-7b | 4 | 64 |
LLAMA-2-13b | 4 | 64 |
LLAMA-2-70b | 4 | 64 |
Mistral | 4 | 64 |
Mistral | 4 | 32 |
BART-large | 4 | 8 |
BART-large | 4 | 16 |
BART-large | 4 | 32 |
BART-large | 2 | 8 |