QwenLM / Qwen2

Qwen2 is the large language model series developed by Qwen team, Alibaba Cloud.

The official fine-tuning example cannot process data in batches, which wastes a lot of GPU memory.

yangkang2318 opened this issue

That is, every sample in a batch is padded to max length instead of to the longest sequence in that batch. The cause is that tokenizer.apply_chat_template() cannot handle a batch of sentences.
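To put a number on the waste, here is a minimal pure-Python sketch with made-up token IDs. The pad ID of 0 and the max length of 32 are assumptions for illustration, not Qwen2 values. It counts how many pad tokens each strategy adds to the same batch:

```python
# Compare padding every sample to a fixed max length vs. to the longest sample in the batch.
# Token IDs are made up for illustration; 0 stands in for the pad token ID (assumption).
from typing import List

PAD_ID = 0


def pad_batch(batch: List[List[int]], target_len: int, pad_id: int = PAD_ID) -> List[List[int]]:
    """Right-pad every sequence in the batch to target_len."""
    return [seq + [pad_id] * (target_len - len(seq)) for seq in batch]


def count_pad_tokens(padded: List[List[int]], original: List[List[int]]) -> int:
    """Number of pad tokens that padding added across the whole batch."""
    return sum(len(p) - len(o) for p, o in zip(padded, original))


batch = [[11, 12, 13], [21, 22, 23, 24, 25], [31, 32]]
model_max_length = 32  # hypothetical; real settings are often in the thousands

fixed = pad_batch(batch, model_max_length)              # like padding="max_length"
longest = pad_batch(batch, max(len(s) for s in batch))  # like padding="longest"

print(count_pad_tokens(fixed, batch))    # 86
print(count_pad_tokens(longest, batch))  # 5
```

With a model_max_length in the thousands and short fine-tuning samples, the gap between the two strategies grows accordingly.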
One workaround is to first generate the text with tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False), then define a Collator class that tokenizes each batch with

tokenizer(
    text=full_texts,
    text_target=input_texts,
    return_tensors="pt",
    padding="longest",
    max_length=self.tokenizer.model_max_length,
    truncation=True,
    return_attention_mask=True,
)

so that padding only extends to the longest sequence in the batch.

For reference, here is the __getitem__ method of the dataset:

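def __getitem__(self, index):
    prompt = self.data[index]["input"]  # renamed from `input` to avoid shadowing the builtin
    output = self.data[index]["output"]
    msg = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": output},
    ]
    # Render the full conversation (prompt + response) as text, without tokenizing
    response = self.tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=False)
    # Prompt text = everything before the assistant header, plus the header itself
    input_text = response.split("<|im_start|>assistant\n")[0]
    input_text += "<|im_start|>assistant\n"
    # Both values are still strings; the Collator tokenizes them per batch.
    # "input_ids" carries the prompt text and "labels" carries the full text.
    return dict(input_ids=input_text, labels=response)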
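The prompt/response split in __getitem__ can be checked without a tokenizer. The sketch below assumes the ChatML-style markers used by Qwen2 (<|im_start|>role\n ... <|im_end|>\n); render_chatml and split_prompt are hypothetical helpers that only approximate the real template, which comes from the tokenizer's chat_template.

```python
# Hypothetical stand-in for apply_chat_template(msg, tokenize=False) using the ChatML-style
# markers that the split in __getitem__ relies on. The real template comes from the tokenizer.
ASSISTANT_TAG = "<|im_start|>assistant\n"


def render_chatml(messages):
    """Approximate a ChatML rendering: one <|im_start|>role ... <|im_end|> block per message."""
    return "".join(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages)


def split_prompt(full_text: str, tag: str = ASSISTANT_TAG) -> str:
    """Return everything up to and including the last assistant header."""
    head, sep, _ = full_text.rpartition(tag)
    if not sep:
        raise ValueError("assistant tag not found in rendered text")
    return head + sep


msg = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "1+1=?"},
    {"role": "assistant", "content": "2"},
]
full = render_chatml(msg)
prompt = split_prompt(full)
print(repr(full[len(prompt):]))  # '2<|im_end|>\n'
```

This sketch splits at the last assistant header with rpartition. For a single-turn conversation that matches split(...)[0] in __getitem__; for multi-turn data, split(...)[0] would cut the prompt at the first assistant turn instead.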

Then define the Collator class:

import copy

import torch


class Collator(object):
    """Tokenize each batch, pad it to its own longest sequence, and build the labels."""

    def __init__(self, args, tokenizer):
        self.args = args
        self.only_train_response = args.only_train_response
        self.tokenizer = tokenizer
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.unk_token_id

    def __call__(self, batch):
        # Prompt-only texts and full texts (prompt + response), as returned by __getitem__
        input_texts = [d["input_ids"] for d in batch]
        full_texts = [d["labels"] for d in batch]

        # text -> inputs["input_ids"]; text_target -> inputs["labels"] (prompt token IDs)
        inputs = self.tokenizer(
            text=full_texts,
            text_target=input_texts,
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_attention_mask=True,
        )
        labels = copy.deepcopy(inputs["input_ids"])
        # ignore padding (needed whether or not only the response is trained)
        labels[labels == self.tokenizer.pad_token_id] = -100
        if self.only_train_response:
            # ignore input text; assumes right padding, so the prompt tokens
            # line up with the start of each full text
            labels[torch.where(inputs["labels"] != self.tokenizer.pad_token_id)] = -100

        inputs["labels"] = labels
        return inputs
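The tensor indexing in Collator.__call__ is easier to follow on plain lists. This sketch assumes already-tokenized, right-padded rows with a hypothetical pad ID of 0. Note that the full texts and the prompts are padded to different lengths, and only the prompt's non-pad positions are masked:

```python
# Pure-Python sketch of the collator's label masking on already-tokenized, right-padded rows.
IGNORE_INDEX = -100
PAD_ID = 0  # assumption: 0 is the pad token ID


def build_labels(full_ids, prompt_ids, pad_id=PAD_ID, only_train_response=True):
    labels = []
    for full_row, prompt_row in zip(full_ids, prompt_ids):
        # padding never contributes to the loss
        row = [IGNORE_INDEX if tok == pad_id else tok for tok in full_row]
        if only_train_response:
            # prompt_row is right-padded, so its non-pad positions align
            # with the prompt prefix at the start of full_row
            for i, tok in enumerate(prompt_row):
                if tok != pad_id:
                    row[i] = IGNORE_INDEX
        labels.append(row)
    return labels


# full texts padded to their longest (6); prompts padded to their own longest (4)
full_ids = [[5, 6, 7, 8, 9, 0], [5, 6, 7, 8, 10, 11]]
prompt_ids = [[5, 6, 7, 0], [5, 6, 7, 8]]
print(build_labels(full_ids, prompt_ids))
# [[-100, -100, -100, 8, 9, -100], [-100, -100, -100, -100, 10, 11]]
```

If the tokenizer padded on the left instead, the prompt positions would no longer line up with the start of each full text, and this masking would hide the wrong tokens.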

Pass an instance of the Collator class, collator, to transformers.Trainer(data_collator=collator).

I'd like to ask about this part: "inputs = self.tokenizer(text = full_texts, text_target = input_texts,". Why is text set to full_texts, while text_target does not contain the answer?

@wyclike You can refer to the Hugging Face implementation (the tokenizer's __call__ method):

if text is None and text_target is None:
    raise ValueError("You need to specify either `text` or `text_target`.")
if text is not None:
    # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the
    # input mode in this case.
    if not self._in_target_context_manager:
        self._switch_to_input_mode()
    encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
if text_target is not None:
    self._switch_to_target_mode()
    target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **all_kwargs)
# Leave back tokenizer in input mode
self._switch_to_input_mode()

if text_target is None:
    return encodings
elif text is None:
    return target_encodings
else:
    encodings["labels"] = target_encodings["input_ids"]
    return encodings

When text=full_texts and text_target=input_texts, the returned encodings["input_ids"] are the input_ids of full_texts, while encodings["labels"] are the input_ids of input_texts.
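A toy version of the dispatch in the Hugging Face excerpt shows where each output ends up. The whitespace "tokenizer" (toy_encode) and toy_call are hypothetical stand-ins for illustration, not the transformers API:

```python
# Toy imitation of how tokenizer(text=..., text_target=...) combines its outputs,
# following the Hugging Face excerpt. The whitespace "tokenizer" is hypothetical.
def toy_encode(texts, vocab):
    """Map each whitespace-separated word to an ID, assigning new IDs on first sight."""
    return {"input_ids": [[vocab.setdefault(w, len(vocab) + 1) for w in t.split()] for t in texts]}


def toy_call(text=None, text_target=None, vocab=None):
    vocab = {} if vocab is None else vocab
    if text is None and text_target is None:
        raise ValueError("You need to specify either `text` or `text_target`.")
    encodings = toy_encode(text, vocab) if text is not None else None
    targets = toy_encode(text_target, vocab) if text_target is not None else None
    if targets is None:
        return encodings
    if encodings is None:
        return targets
    # Both given: the target's input_ids are attached under "labels"
    encodings["labels"] = targets["input_ids"]
    return encodings


out = toy_call(text=["Q: hi A: hello"], text_target=["Q: hi A:"])
print(out)  # {'input_ids': [[1, 2, 3, 4]], 'labels': [[1, 2, 3]]}
```

The "labels" produced this way are only the prompt's token IDs. That is why the Collator rebuilds the real labels from input_ids and uses the prompt IDs only to decide which positions to mask.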

The following code then builds the actual labels:

labels = copy.deepcopy(inputs["input_ids"])
if self.only_train_response:
    # ignore padding
    labels[labels == self.tokenizer.pad_token_id] = -100
    # ignore input text
    labels[torch.where(inputs["labels"] != self.tokenizer.pad_token_id)] = -100

inputs["labels"] = labels

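The current script pads the texts in a batch to the same length and then processes them in parallel. The script does not pass a data_collator to the Trainer, but by default (when a tokenizer is passed to it) Trainer uses DataCollatorWithPadding to pad each batch. See the Trainer source code for details.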
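For reference, dynamic padding of the kind the default collator performs can be sketched in a few lines. pad_to_longest is a simplified, hypothetical stand-in for DataCollatorWithPadding (which delegates to tokenizer.pad), with pad ID 0 assumed; each batch is right-padded only to its own longest sequence:

```python
# Simplified, hypothetical stand-in for dynamic padding: each batch is right-padded
# to its own longest sequence, with a matching attention mask. Pad ID 0 is an assumption.
def pad_to_longest(features, pad_id=0):
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": []}
    for f in features:
        ids = f["input_ids"]
        n_pad = max_len - len(ids)
        batch["input_ids"].append(ids + [pad_id] * n_pad)
        batch["attention_mask"].append([1] * len(ids) + [0] * n_pad)
    return batch


data = [{"input_ids": [1, 2]}, {"input_ids": [3, 4, 5, 6]}, {"input_ids": [7]}, {"input_ids": [8, 9, 10]}]
for start in range(0, len(data), 2):
    b = pad_to_longest(data[start:start + 2])
    print(len(b["input_ids"][0]), b["attention_mask"])
# 4 [[1, 1, 0, 0], [1, 1, 1, 1]]
# 3 [[1, 0, 0], [1, 1, 1]]
```

Per-batch padding can only shorten a batch if its examples reach the collator unpadded; sequences already padded to a fixed length during preprocessing keep that length.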

I'll close this issue for now. If you have other questions, feel free to reopen it.