scaleapi / llm-engine

Scale LLM Engine public repository

Home Page: https://llm-engine.scale.com


max token length for finetune and completion endpoints on Llama-2?

urimerhav opened this issue · comments

Great job with this repo. I was able to finetune Llama-2 and it certainly seems to have an effect.

Unfortunately, the fine-tune endpoint silently accepts all inputs, and the documentation states that you simply truncate inputs to the max length. But Llama-2's max length isn't specified anywhere. Originally, Meta released it with a bug that capped the max length at 2048, while the native max length seems to be 4096. So which is it?

Also, I tested my fine-tuned model's completion endpoint with inputs as large as 12,000 tokens and it still returns a completion. So I assume you truncate there as well? Only taking the tail of the prompt, presumably?

tldr:

  1. What is Llama-2's max token length?
  2. Is there anything we can do to affect this or get better visibility into how the input got tokenized, etc.?

BUMP!


Hi @urimerhav, thanks for reaching out. Here are some answers to your questions:

Originally, Meta released it with a bug that capped the max length at 2048, while the native max length seems to be 4096. So which is it?

Seems like 4096 as you may have found searching elsewhere.
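For what it's worth, you can sanity-check the native context window straight from the Hugging Face config. A minimal sketch, assuming you have transformers installed and access to the meta-llama/Llama-2-7b-hf checkpoint:

from transformers import AutoConfig

# max_position_embeddings is the model's native context window.
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
print(config.max_position_embeddings)  # 4096 for Llama-2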

Also, I tested my fine-tuned model's completion endpoint with inputs as large as 12,000 tokens and it still returns a completion. So I assume you truncate there as well? Only taking the tail of the prompt, presumably?

Do you have some more detailed repro steps? We just tried a long input and got an exception.
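In case it helps, here's roughly how one could check this client-side. This is a sketch, not our exact test: it assumes the llmengine Python client and the Hugging Face Llama-2 tokenizer, and "llama-2-7b" is a placeholder for your fine-tuned model's name.

from llmengine import Completion
from transformers import AutoTokenizer

# Count prompt tokens client-side with the Llama-2 tokenizer.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
prompt = "lorem ipsum dolor sit amet " * 2000  # well beyond 4096 tokens
print(len(tokenizer(prompt)["input_ids"]))

# Placeholder model name; use the name of your fine-tuned model here.
response = Completion.create(
    model="llama-2-7b",
    prompt=prompt,
    max_new_tokens=10,
    temperature=0.2,
)
print(response.output.text)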

Is there anything we can do to affect this or get better visibility into how the input got tokenized, etc.?

We haven't yet open-sourced our fine-tuning code, although we fully intend to! The issue is that, unlike the rest of LLM Engine, our fine-tuning scripts still have some internal dependencies that need to be ripped out. In the meantime, I can share a code snippet that might give some visibility into the tokenization process. We currently truncate to 1024 tokens for fine-tuning because we're running on A10s and want to avoid OOMs:

from typing import Dict, Optional, Sequence

import torch
from transformers import PreTrainedTokenizer

# Constants referenced below were not part of the original snippet; typical values are assumed here.
# -100 is the index ignored by PyTorch's cross-entropy loss; the keys match a prompt/completion data format.
IGNORE_INDEX = -100
PROMPT_KEY = "prompt"
COMPLETION_KEY = "completion"


class SFTCollator(object):
    """Collate examples for supervised fine-tuning.
    We intentionally mask out the prompt tokens to avoid training on them.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        max_sequence_length: Optional[int] = None,
    ):
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Each instance holds pre-tokenized prompt and completion ids as 1-D tensors.
        prompt_input_ids, completion_input_ids = tuple(
            [instance[key] for instance in instances] for key in (PROMPT_KEY, COMPLETION_KEY)
        )

        # Pad every example in the batch to the longest prompt + completion length.
        max_input_length = max(
            len(p) + len(c) for p, c in zip(prompt_input_ids, completion_input_ids)
        )
        input_ids = (
            torch.ones((len(prompt_input_ids), max_input_length), dtype=prompt_input_ids[0].dtype)
            * self.tokenizer.pad_token_id
        )
        # Labels default to IGNORE_INDEX so padding and prompt tokens contribute no loss.
        labels = torch.ones(input_ids.shape, dtype=prompt_input_ids[0].dtype) * IGNORE_INDEX
        attention_mask = torch.zeros(input_ids.shape, dtype=torch.bool)

        for i, (prompt_ids, completion_ids) in enumerate(
            zip(prompt_input_ids, completion_input_ids)
        ):
            sequence_ids = torch.cat([prompt_ids, completion_ids])
            if self.tokenizer.padding_side == "right":
                input_ids[i][: len(sequence_ids)] = sequence_ids
                attention_mask[i][: len(sequence_ids)] = True
                # Only the completion tokens receive labels; the prompt stays masked.
                labels[i][len(prompt_ids) : len(prompt_ids) + len(completion_ids)] = completion_ids
            else:
                input_ids[i][-len(sequence_ids) :] = sequence_ids
                attention_mask[i][-len(sequence_ids) :] = True
                labels[i][-len(completion_ids) :] = completion_ids

        # Truncate from the right to the configured maximum sequence length (1024 in our setup).
        return dict(
            input_ids=input_ids[:, : self.max_sequence_length],
            labels=labels[:, : self.max_sequence_length],
            attention_mask=attention_mask[:, : self.max_sequence_length],
        )
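
To make the masking concrete, here's a small usage sketch of the collator above. The tokenizer choice, prompt/completion strings, and encode helper are illustrative; Llama-2 has no pad token by default, so EOS is reused for padding.

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Llama-2 ships without a pad token, so reuse EOS for padding.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token


def encode(text: str) -> torch.Tensor:
    # The collator expects pre-tokenized ids as 1-D tensors.
    return torch.tensor(tokenizer(text, add_special_tokens=False)["input_ids"])


instances = [
    {PROMPT_KEY: encode("Q: What is the capital of France?\nA:"), COMPLETION_KEY: encode(" Paris")},
    {PROMPT_KEY: encode("Q: 2 + 2 = ?\nA:"), COMPLETION_KEY: encode(" 4")},
]
collator = SFTCollator(tokenizer=tokenizer, max_sequence_length=1024)
batch = collator(instances)

# Prompt and padding positions in labels stay at IGNORE_INDEX, so only completion tokens are trained on.
print(batch["input_ids"].shape, batch["labels"].shape, batch["attention_mask"].shape)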