@kunal-vaishnavi kunal-vaishnavi commented Dec 18, 2025

Description

This PR updates how EOS token id detection is handled with ONNX Runtime GenAI when generating tokens. A new API called Generator.HitEOS() is introduced to detect whether an EOS token id has been generated. Another API called Generator.HitMaxLength() is also introduced to detect whether the max length has been hit before the generation loop has completed.

Motivation and Context

This PR is a follow-up to the issue fixed in an earlier PR. The earlier PR describes several variations of the generation loop, but each of them has an issue.

There are two scenarios for terminating the generation loop: 1) hitting the EOS token id and completing the generation loop or 2) hitting the max length before the generation loop has completed. None of the variations handles both scenarios correctly.

1. Original Generation Loop

while not IsDone():
    GenerateToken()
    GetLastToken()
    PrintLastToken()

Consider scenario 1 with this loop. After GenerateToken() produces the EOS token id, GetLastToken() attempts to retrieve that token. However, ORT GenAI does not append the EOS token id to the list of sequences returned to the user (see the earlier PR for why), so the last token in the list of sequences is still the one generated just before the EOS token id. As a result, GetLastToken() and PrintLastToken() retrieve and print, for a second time, the last token that the user already saw.
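The duplicate print can be reproduced without a model. The following is a minimal sketch, assuming a hypothetical FakeGenerator class that only mimics the behavior described above (the EOS token id is recorded but never appended to the sequence). It is not the ORT GenAI API, and the token ids are made up for illustration.

```python
EOS_ID = 99  # made-up EOS token id for this sketch


class FakeGenerator:
    """Hypothetical stand-in that mimics the behavior described above."""

    def __init__(self, scripted_tokens, max_length):
        self._scripted = iter(scripted_tokens)  # tokens the "model" will emit
        self._max_length = max_length
        self._hit_eos = False
        self.sequence = []

    def generate_token(self):
        token = next(self._scripted)
        if token == EOS_ID:
            self._hit_eos = True  # EOS is recorded but NOT appended
        else:
            self.sequence.append(token)

    def is_done(self):
        return self._hit_eos or len(self.sequence) >= self._max_length

    def get_last_token(self):
        return self.sequence[-1]


def run_original_loop(generator):
    printed = []
    while not generator.is_done():
        generator.generate_token()
        printed.append(generator.get_last_token())  # stands in for printing
    return printed


# Scenario 1: the model emits EOS well before the max length.
printed = run_original_loop(FakeGenerator([10, 11, 12, EOS_ID], max_length=100))
print(printed)  # [10, 11, 12, 12]: token 12 is "printed" twice
```

The final iteration generates the EOS token id, but the last element of the sequence is still 12, so it is collected a second time.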

2. Return Early Generation Loop

while not IsDone():
    GenerateToken()
    if IsDone():
        break
    GetLastToken()
    PrintLastToken()

Consider scenario 2 with this loop. After GenerateToken() produces a token and the max length has been reached, the generator's state is marked as done. IsDone() then returns true and the loop exits early, so the newest token is never retrieved or printed.

3. Infinite Generation Loop

while True:
    GenerateToken()
    if IsDone():
        break
    GetLastToken()
    PrintLastToken()

Consider scenario 2 with this loop. The same issue as the prior loop still applies. GenerateToken() will generate all of the tokens, but once the max length is hit, IsDone() is true and the last token is never retrieved or printed.
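The dropped token in loops 2 and 3 can be shown with a small simulation. This is a sketch under stated assumptions: FakeGenerator is a hypothetical stand-in (not the ORT GenAI API) that marks itself done once the sequence reaches the max length, and the token ids are made up. EOS handling is left out because scenario 2 never generates an EOS token id.

```python
class FakeGenerator:
    """Hypothetical stand-in: done as soon as max_length tokens exist."""

    def __init__(self, scripted_tokens, max_length):
        self._scripted = iter(scripted_tokens)  # tokens the "model" will emit
        self._max_length = max_length
        self.sequence = []

    def generate_token(self):
        self.sequence.append(next(self._scripted))

    def is_done(self):
        return len(self.sequence) >= self._max_length

    def get_last_token(self):
        return self.sequence[-1]


def run_return_early_loop(generator):
    printed = []
    while not generator.is_done():
        generator.generate_token()
        if generator.is_done():
            break  # exits before the newest token is collected
        printed.append(generator.get_last_token())  # stands in for printing
    return printed


def run_infinite_loop(generator):
    printed = []
    while True:
        generator.generate_token()
        if generator.is_done():
            break  # same early exit as above
        printed.append(generator.get_last_token())
    return printed


# Scenario 2: the max length (3) is reached before any EOS token id.
early_gen = FakeGenerator([10, 11, 12, 13], max_length=3)
early_printed = run_return_early_loop(early_gen)
infinite_printed = run_infinite_loop(FakeGenerator([10, 11, 12, 13], max_length=3))

print(early_gen.sequence)  # [10, 11, 12]: three tokens were generated
print(early_printed)       # [10, 11]: the third token was never printed
print(infinite_printed)    # [10, 11]
```

In both loops the generator holds three tokens but only two reach the user, which matches the truncated "Before" output shown for scenario 2.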

Conclusion

None of these generation loop variants work because IsDone() currently covers both scenarios in one API and does not distinguish between them. Two separate checks are needed: one in the condition of the while loop to decide whether the loop continues, and another after token generation to decide whether the last token should be retrieved.

Solution

To fix this, a new API called Generator.HitEOS() is introduced. It returns true when the EOS token id is generated. The generation loop should be modified to the following.

while not IsDone():
    GenerateToken()
    if HitEOS():
        break
    GetLastToken()
    PrintLastToken()

If scenario 1 occurs in this loop, HitEOS() is true and the generation loop will exit early. If scenario 2 occurs in this loop, HitEOS() is false when the max length is reached, so the last generated token can still be retrieved and printed. Then, because the generator's state is done, IsDone() is true and the generation loop ends.
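Both scenarios can be checked against the modified loop with a small simulation. This is a minimal sketch, assuming a hypothetical FakeGenerator class whose hit_eos() method only imitates the behavior described for Generator.HitEOS(). It is not the ORT GenAI implementation, and the token ids are made up.

```python
EOS_ID = 99  # made-up EOS token id for this sketch


class FakeGenerator:
    """Hypothetical stand-in that separates 'hit EOS' from 'done'."""

    def __init__(self, scripted_tokens, max_length):
        self._scripted = iter(scripted_tokens)  # tokens the "model" will emit
        self._max_length = max_length
        self._hit_eos = False
        self.sequence = []

    def generate_token(self):
        token = next(self._scripted)
        if token == EOS_ID:
            self._hit_eos = True  # EOS is recorded but NOT appended
        else:
            self.sequence.append(token)

    def hit_eos(self):
        return self._hit_eos

    def is_done(self):
        return self._hit_eos or len(self.sequence) >= self._max_length

    def get_last_token(self):
        return self.sequence[-1]


def run_fixed_loop(generator):
    printed = []
    while not generator.is_done():
        generator.generate_token()
        if generator.hit_eos():
            break  # nothing new was appended, so nothing to print
        printed.append(generator.get_last_token())  # stands in for printing
    return printed


# Scenario 1: EOS arrives before the max length.
scenario_1 = run_fixed_loop(FakeGenerator([10, 11, 12, EOS_ID], max_length=100))
# Scenario 2: the max length (3) is reached before any EOS token id.
scenario_2 = run_fixed_loop(FakeGenerator([10, 11, 12, 13], max_length=3))

print(scenario_1)  # [10, 11, 12]: no duplicate
print(scenario_2)  # [10, 11, 12]: the last token is kept
```

In this sketch the check after generation only asks whether the newest token was an EOS token id, while the loop condition alone handles the max length, so neither scenario drops or repeats a token.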

Here is a full end-to-end example demonstrating its usage.

import onnxruntime_genai as og

model = og.Model("/path/to/model/folder")
tokenizer = og.Tokenizer(model)
tokenizer_stream = tokenizer.create_stream()

params = og.GeneratorParams(model)
params.set_search_options(max_length=25)

generator = og.Generator(model, params)

tokens = tokenizer.encode("<|system|>You are a helpful AI assistant.<|end|><|user|>What color is the sky?<|end|><|assistant|>")
print(f"Prompt: {len(tokens)}")
generator.append_tokens(tokens)

count = 0
while not generator.is_done():
    generator.generate_next_token()
    count += 1
    if generator.hit_eos():
        break

    new_token = generator.get_next_tokens()[0]
    print(tokenizer_stream.decode(new_token), end="", flush=True)

print()
print(f"Generated: {count}")
print(f"Total: {len(tokens) + count}")

Scenario 1

Before with loop version 1:

Prompt: 18
The color of the sky can vary depending on the viewing conditions and the presence of particles and moisture in the atmosphere. On a clear day, the sky appears blue due to Rayleigh scattering, where the atmosphere scatters sunlight in all directions and blue wavelengths are scattered more than other colors because they travel as shorter, smaller waves. This scattering causes the sky to appear blue to an observer on the ground. However, the sky can also appear various shades of blue, gray, or even take on vibrant hues like red or orange just before or just after sunrise or sunset, due to the scattering of sunlight by particles and moisture in the atmosphere..
Generated: 128
Total: 146

After with generator.hit_eos():

Prompt: 18
The color of the sky can vary depending on the viewing conditions and the presence of particles and moisture in the atmosphere. On a clear day, the sky appears blue due to Rayleigh scattering, where the atmosphere scatters sunlight in all directions and blue wavelengths are scattered more than other colors because they travel as shorter, smaller waves. This scattering causes the sky to appear blue to an observer on the ground. However, the sky can also appear various shades of blue, gray, or even take on vibrant hues like red or orange just before or just after sunrise or sunset, due to the scattering of sunlight by particles and moisture in the atmosphere.
Generated: 128
Total: 146

Scenario 2

Before with loop version 2:

Prompt: 18
The color of the sky can
Generated: 7
Total: 25

After with generator.hit_eos():

Prompt: 18
The color of the sky can vary
Generated: 7
Total: 25

@apsonawane

In search_cuda.cpp line 173, we are checking for EOS. Do we need to set hit_eos there or not?

assert(next_tokens_.size() == eos_seen_.size());
