mistralai / mistral-inference

Official inference library for Mistral models

Home Page: https://mistral.ai/


Suggested improvement of eos logic in generate.py

vvatter opened this issue

In the generate() function of generate.py, there is some curious XOR logic for updating the boolean is_finished vector:

        if eos_id is not None:
            is_finished = is_finished ^ (next_token == eos_id).cpu()

Mistral tends to keep generating even after it reaches an eos token. With large batches, this means the shortest response can hit eos, later generate a second eos, and flip back to is_finished == False before the longest response has finished; this can keep happening until max_tokens is reached. It seems to me that this should be an OR.
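The flip-back can be reproduced without running the model. Here is a minimal sketch in plain Python (lists of booleans standing in for the boolean tensor, and an arbitrary EOS id of 2) showing how XOR un-finishes a sequence that emits eos twice:

```python
EOS = 2  # hypothetical eos token id for illustration

def step_xor(is_finished, next_tokens):
    # Mirrors the current logic: XOR toggles the flag on every eos.
    return [f ^ (t == EOS) for f, t in zip(is_finished, next_tokens)]

# A single-sequence batch that emits eos at two consecutive steps.
is_finished = [False]
is_finished = step_xor(is_finished, [EOS])  # -> [True]
is_finished = step_xor(is_finished, [EOS])  # -> [False]: flipped back!
print(is_finished)
```

With OR in place of XOR, the second eos would leave the flag at True.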

Additionally, the current approach allows tokens following an eos to be included in the outputs, which, since the tokenizer decodes eos as an empty string, might contribute to confusing output sequences. This may relate to the issues discussed in #149.

To address both issues, I suggest the following modification, so that is_finished remains True after an eos token is encountered and no tokens are returned after that point.

        if eos_id is not None:
            # OR keeps is_finished latched once a sequence has emitted eos.
            is_finished = is_finished | (next_token == eos_id).cpu()
            # Overwrite any token generated after eos with eos_id itself:
            # zero out finished positions, then add eos_id back in.
            next_token = next_token * (~is_finished).to(next_token.device)
            next_token = next_token + eos_id * is_finished.to(next_token.device)
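The same multiply-and-add masking can be checked in plain Python (again with a hypothetical EOS id of 2, and lists standing in for the tensors): once a sequence is finished, every subsequent token it produces is replaced by eos, while unfinished sequences are untouched.

```python
EOS = 2  # hypothetical eos token id for illustration

def step_or(is_finished, next_tokens):
    # OR keeps the flag latched once a sequence has seen eos...
    finished = [f or (t == EOS) for f, t in zip(is_finished, next_tokens)]
    # ...and tokens emitted after eos are overwritten with eos itself,
    # the list equivalent of: tok * ~finished + EOS * finished.
    tokens = [EOS if f else t for f, t in zip(finished, next_tokens)]
    return finished, tokens

# Batch of two: sequence 0 hits eos first, sequence 1 keeps generating.
finished = [False, False]
finished, toks = step_or(finished, [EOS, 7])  # finished -> [True, False]
finished, toks = step_or(finished, [5, 9])    # seq 0's stray 5 becomes eos
print(finished, toks)
```

This mirrors the proposed tensor code above: the boolean mask both latches completion and pads the finished rows with eos_id, so decoding stops cleanly at the first eos.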