[generate] beam search -- fix output cropping #37080
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the
  dct = tok(ARTICLE, return_tensors="pt")
  generated_ids = hf.generate(**dct, num_beams=4)
- result = tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ result = tok.batch_decode(generated_ids)[0]
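The diff above stops skipping special tokens when decoding, so that stray pad tokens become visible to the test. A minimal stand-alone sketch (a toy vocabulary and decode helper, not the real transformers API) of why this matters:

```python
# Toy illustration: decoding with skip_special_tokens=True hides trailing
# pad tokens, so a test using it cannot catch an incorrectly cropped output.
PAD, EOS = 0, 2
vocab = {0: "<pad>", 1: "hello", 2: "</s>", 3: "world"}

def decode(ids, skip_special_tokens=False):
    """Minimal stand-in for tokenizer.decode."""
    special = {PAD, EOS}
    tokens = [vocab[i] for i in ids if not (skip_special_tokens and i in special)]
    return " ".join(tokens)

# An incorrectly cropped output carries extra right padding:
bad_output = [1, 3, 2, 0, 0]   # "hello world </s> <pad> <pad>"
good_output = [1, 3, 2]        # "hello world </s>"

# With skip_special_tokens=True both decode identically, so the bug is invisible.
assert decode(bad_output, skip_special_tokens=True) == decode(good_output, skip_special_tokens=True)

# Without skipping, the extra pad tokens show up and the comparison fails.
assert decode(bad_output) != decode(good_output)
```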
Tests: update beam search tests to also print special tokens
e.g. this updated test fails on main because it returns extra pad tokens, due to the incorrect crop
LGTM, thanks for digging in and fixing this quickly!
* handle jagged beams
* better comment
* bart -- beam search tests print special tokens
* more bart test updates
* more tests!
* better comment
What does this PR do?
vLLM is seeing some output differences in their CI when beam search is used. The difference can be traced to the beam search refactor (#35802).
Inspecting the outputs, we can see that there are a few additional pad tokens on the right. This is because the output was not being cropped correctly when the selected beam is shorter than the generation length (i.e. when the highest-scoring beam is NOT from the latest decoding iteration, but rather some previously completed beam).
After #35802: output length = input length + number of decoding iterations
Before #35802 and in this PR: output length = length of the longest selected beam
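The fix described above can be sketched in plain Python (this is an illustrative stand-in, not the actual transformers implementation): selected beams may be shorter than the number of decoding iterations when a previously completed beam wins, so the output should be padded only up to the longest *selected* beam rather than to the iteration count.

```python
PAD = 0

def crop_to_longest_beam(selected_beams, pad_id=PAD):
    """Pad each selected beam only up to the longest selected beam,
    not up to the total number of decoding iterations."""
    max_len = max(len(b) for b in selected_beams)
    return [b + [pad_id] * (max_len - len(b)) for b in selected_beams]

# Suppose 6 decoding iterations ran, but the winning beams finished earlier:
beams = [[5, 6, 7, 2], [5, 6, 2]]          # lengths 4 and 3
out = crop_to_longest_beam(beams)
assert all(len(row) == 4 for row in out)   # longest selected beam, not 6
assert out[1] == [5, 6, 2, 0]              # shorter beam padded to length 4
```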
This PR also changes a few beam search tests to check their special tokens, which would have caught this bug.