[Bugfix] Fix incorrect updates to num_computed_tokens in multi-step scheduling #9038
@@ -55,19 +55,21 @@ def test_num_computed_tokens_update(num_scheduler_steps: int,
        engine.step()

    if not seq.is_finished():
        assert seq.data.get_num_computed_tokens(
        ) == prompt_len + num_prompt_steps - 1

        prompt_num_computed_tokens = seq.data.get_num_computed_tokens()
        # Test correctness of num_computed_tokens after the prompt steps
        assert prompt_num_computed_tokens == \
            prompt_len + num_prompt_steps - 1

        decode_step_counter = 0
        while not seq.is_finished():
            # Test correctness of num_computed_tokens after the decode steps
            assert seq.data.get_num_computed_tokens(
            ) == prompt_num_computed_tokens + decode_step_counter
            for _ in range(num_scheduler_steps):
QQ: why do we need the for loop here?

when … Also, multi-step doesn't provide any guarantees that output processing will happen every step. The only guarantee is that after the completion of …
                # decode step
                engine.step()
                decode_step_counter += 1

    # Test correctness of num_computed_tokens after the sequence finish.
    assert seq.data.get_num_computed_tokens(
    ) == prompt_len + num_output_tokens - 1
nit: merge this call with line 58 so we only call get_num_computed_tokens once.
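
The two review points above (checking the counter only between groups of `num_scheduler_steps` steps, and reading it once into a local) can be illustrated with a toy model. This is a minimal sketch and not vLLM code: `ToyEngine` and `run_check` are hypothetical names, and the rule that the reported counter is refreshed only at the end of a step group (or on finish) is an assumption made for illustration, since the reply in the thread is cut off.

```python
class ToyEngine:
    """Hypothetical stand-in for an engine plus one sequence (NOT vLLM's API)."""

    def __init__(self, prompt_len, num_output_tokens, num_scheduler_steps):
        self.prompt_len = prompt_len
        self.num_output_tokens = num_output_tokens
        self.num_scheduler_steps = num_scheduler_steps
        self._true_computed = 0     # tokens actually computed so far
        self._visible_computed = 0  # value reported to the caller
        self._generated = 0         # output tokens sampled so far
        self._steps_in_group = 0

    def is_finished(self):
        return self._generated >= self.num_output_tokens

    def get_num_computed_tokens(self):
        return self._visible_computed

    def step(self):
        if self.is_finished():
            return  # extra steps after finishing are no-ops in this toy
        if self._true_computed == 0:
            # Prompt step: the whole prompt is computed, first token sampled.
            self._true_computed = self.prompt_len
            self._generated = 1
            self._visible_computed = self._true_computed
            return
        # Decode step: one more token computed, one more token sampled.
        self._true_computed += 1
        self._generated += 1
        self._steps_in_group += 1
        # ASSUMPTION: the reported counter is refreshed only at a group
        # boundary or when the sequence finishes.
        if (self._steps_in_group == self.num_scheduler_steps
                or self.is_finished()):
            self._visible_computed = self._true_computed
            self._steps_in_group = 0


def run_check(prompt_len, num_output_tokens, num_scheduler_steps):
    engine = ToyEngine(prompt_len, num_output_tokens, num_scheduler_steps)
    engine.step()  # single prompt step in this toy

    # Read the counter once and reuse the local, as the nit suggests.
    prompt_num_computed_tokens = engine.get_num_computed_tokens()
    assert prompt_num_computed_tokens == prompt_len

    decode_step_counter = 0
    while not engine.is_finished():
        # Only checked between full groups of num_scheduler_steps steps.
        assert (engine.get_num_computed_tokens()
                == prompt_num_computed_tokens + decode_step_counter)
        for _ in range(num_scheduler_steps):
            engine.step()
            decode_step_counter += 1

    final_num_computed_tokens = engine.get_num_computed_tokens()
    assert final_num_computed_tokens == prompt_len + num_output_tokens - 1
    return final_num_computed_tokens
```

In this toy, an assertion placed inside the inner `for` loop would fail whenever `num_scheduler_steps > 1`, because the reported value lags the true one until the group completes.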