Update handling EOS token id detection #1925
Open
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
This PR updates how EOS token id detection is handled with ONNX Runtime GenAI when generating tokens. A new API called
Generator.HitEOS()is introduced to detect whether an EOS token id has been generated. Another API calledGenerator.HitMaxLength()is also introduced to detect whether the max length has been hit before the generation loop has completed.Motivation and Context
This PR is a follow-up to the issue fixed in an earlier PR. The earlier PR mentions different variations of the generation loop but all of the variations have an issue.
There are two scenarios for terminating the generation loop: 1) hitting the EOS token id and completing the generation loop or 2) hitting the max length before the generation loop has completed. However, none of the variations adequately cover the two scenarios for terminating the generation loop.
1. Original Generation Loop
Consider scenario 1 with this loop. After
GenerateToken()produces the EOS token id,GetLastToken()will attempt to retrieve that token. However, ORT GenAI does not append the EOS token id to the list of sequences returned to the user (see the earlier PR for why). Instead, the second-to-last token will still be the last token in the list of sequences. Thus,GetLastToken()andPrintLastToken()will retrieve and again print the last token that the user saw.2. Return Early Generation Loop
Consider scenario 2 with this loop. After
GenerateToken()produces a token and the max length has been reached, the generator's state is marked as done. ThenIsDone()will be true and the newest token won't be retrieved and printed since the loop is exited early.3. Infinite Generation Loop
Consider scenario 2 with this loop. The same issue as the prior loop still applies.
GenerateToken()will generate all of the tokens but once the max length is hit,IsDone()is true and the last token won't be retrieved and printed.Conclusion
The reason that none of these generation loop variants work is because
IsDone()currently covers both scenarios in one API and does not distinguish between them. One check needs to be in place in the condition of the while loop so that the loop continues, and another check needs to be after token generation to decide whether retrieving the last token should be done or not.Solution
To fix this, a new API called
Generator.HitEOS()is introduced. It returnstruewhen the EOS token id is generated. The generation loop should be modified to the following.If scenario 1 occurs in this loop,
HitEOS()istrueand the generation loop will exit early. If scenario 2 occurs in this loop,HitEOS()isfalsewhen the max length is reached. The last generated token can still be retrieved and printed. Then because the generator's state is done,IsDone()istrueand the generation loop ends.Here is a full end-to-end example demonstrating its usage.
Scenario 1
Before with loop version 1:
After with
generator.hit_eos():Scenario 2
Before with loop version 2:
After with
generator.hit_eos():