huggingface / blog

Public repo for HF blog posts

Home page: https://hf.co/blog

Gemma fine-tuning blog: formatting_func is never called

dangbert opened this issue.

I'm trying to fine-tune Gemma as shown in this blog post: https://huggingface.co/blog/gemma-peft, which uses (and links to) this notebook (currently at version ef0c851).

After fine-tuning I wasn't getting output in the expected format, and on investigation I found that `formatting_func` was never called at all.

Here is my code. It is a direct copy of the notebook, except that I removed reading the HF token from `os.environ` and I raise an exception inside `formatting_func` to show that it is never called: fine-tuning still completes successfully.

```python
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer

model_id = "google/gemma-7b"
# Load the base model with 4-bit NF4 quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})

# Sanity check: generate from the base model before fine-tuning
text = "Quote: Imagination is more"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

from peft import LoraConfig

# LoRA adapters on the attention and MLP projection layers
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

from datasets import load_dataset

data = load_dataset("Abirate/english_quotes")
# This pre-tokenizes the dataset, adding "input_ids" and "attention_mask" columns
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

import transformers
from trl import SFTTrainer

def formatting_func(example):
    # Added for this report: fail loudly if SFTTrainer ever calls this function
    raise RuntimeError("if you can read this, formatting_func was called")
    text = f"Quote: {example['quote'][0]}\nAuthor: {example['author'][0]}"
    return [text]

trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=10,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)
# Training completes without raising, so formatting_func is never called
trainer.train()
```

My Python environment: Python 3.11.8, requirements.txt

(I have extra dependencies installed in this environment, so sorry that the file is so big.)

I believe I understand why it's not being called:

The code for `SFTTrainer._prepare_dataset` returns the dataset unchanged if `"input_ids"` is in `dataset.column_names`, and `data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)` already adds that column before the `SFTTrainer` is instantiated. So `formatting_func` is never used.
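To make the early return concrete, here is a minimal pure-Python sketch that models only that one check. `prepare_dataset_sketch` is a hypothetical stand-in, not TRL's code: the real `_prepare_dataset` also handles packing and tokenization, which are left out here. Datasets are represented as plain dicts of column lists, and the rows and token ids are placeholders.

```python
from typing import Callable, Optional


def prepare_dataset_sketch(
    columns: dict[str, list],
    formatting_func: Optional[Callable[[dict[str, list]], list[str]]],
) -> dict[str, list]:
    """Hypothetical, simplified stand-in for the early return in _prepare_dataset."""
    # Already tokenized: returned unchanged, formatting_func is never consulted.
    if "input_ids" in columns:
        return columns
    # Raw text: the formatting function builds the training texts.
    return {"text": formatting_func(columns)}


calls: list[dict[str, list]] = []


def formatting_func(batch: dict[str, list]) -> list[str]:
    # Record calls instead of raising, so both code paths can be observed.
    calls.append(batch)
    return [f"Quote: {q}\nAuthor: {a}" for q, a in zip(batch["quote"], batch["author"])]


# Placeholder rows; the token ids are made up.
raw = {"quote": ["some quote"], "author": ["someone"]}
pretokenized = {**raw, "input_ids": [[2, 101, 102]], "attention_mask": [[1, 1, 1]]}

out_pretokenized = prepare_dataset_sketch(pretokenized, formatting_func)
calls_after_pretokenized = len(calls)  # still 0: the check short-circuited
out_raw = prepare_dataset_sketch(raw, formatting_func)
calls_after_raw = len(calls)  # 1: the raw dataset reached formatting_func
```

The only thing separating the two outcomes is the presence of the `input_ids` column, which matches the behavior described above.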

Hi @dangbert,
This is correct: if you have already processed the dataset, `formatting_func` will be ignored. I opened huggingface/trl#1577 to emit a warning when users hit this case, and I will update the notebook.
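The warning itself lives in TRL. Until a version with it is installed, a user-side check along the same lines can be run before building the trainer. This is a minimal, hypothetical sketch: `warn_if_pretokenized` is not part of TRL, it is not the code from huggingface/trl#1577, and the column lists are illustrative. The fix implied by the thread is to drop the `data.map(...)` tokenization line so that `data["train"]` reaches `SFTTrainer` with only its raw text columns.

```python
import warnings
from typing import Callable, Optional, Sequence


def warn_if_pretokenized(column_names: Sequence[str], formatting_func: Optional[Callable]) -> bool:
    """Hypothetical pre-flight check; returns True when formatting_func would be ignored."""
    ignored = "input_ids" in column_names and formatting_func is not None
    if ignored:
        warnings.warn(
            "Dataset already contains an `input_ids` column, so formatting_func will be ignored. "
            "Pass the raw, untokenized dataset if you want formatting_func to run.",
            UserWarning,
            stacklevel=2,
        )
    return ignored


def formatting_func(example):
    return [f"Quote: {example['quote'][0]}\nAuthor: {example['author'][0]}"]


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Illustrative columns after the tokenizer map (other columns omitted)
    flagged = warn_if_pretokenized(["quote", "author", "input_ids", "attention_mask"], formatting_func)
    # Illustrative columns of the raw dataset
    clean = warn_if_pretokenized(["quote", "author"], formatting_func)

warning_count = len(caught)
```

Only the pre-tokenized column list triggers the warning. The check inspects column names alone and does not load or tokenize any data.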