Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix prefix caching + speculative decoding #2711

Merged
merged 1 commit into from
Nov 4, 2024
Merged

Conversation

tgaddair
Copy link
Contributor

What does this PR do?

Following #2402, it was remarked by @ZTianle that enabling speculative decoding when using prefix caching was resulting in incorrect results. We also internally noticed that it sometimes caused CUDA device-side assertion failures.

The root cause was determined to be incorrect slot assignments that were occurring when calling slots.expand and incrementing the slots by arange_int. This assumed that the next slots in the allocation were the next ordinal slot IDs, but when using prefix caching, this often is not the case. Due to the radix tree format, slot IDs are frequently discontiguous, and so we need to expand the slot_indices and use those to index into the allocated slots rather than expanding the slots directly.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc @OlivierDehaene @Narsil @danieldk

@Narsil
Copy link
Collaborator

Narsil commented Nov 4, 2024

Thanks for the PR, accepting this one since the fixed seems to work (will patch main directly after)

@Narsil Narsil merged commit aadc9cb into huggingface:main Nov 4, 2024
13 of 23 checks passed
@tgaddair tgaddair deleted the patch-2 branch November 4, 2024 19:46
@ZTianle
Copy link

ZTianle commented Nov 11, 2024

Thank you @tgaddair for fixing the speculative decoding issue with prefix caching. I've noticed that while this fix works well with flash decoding+prefix caching, there still appear to be compatibility issues when using speculative decoding together with flashinfer - specifically, it continues to generate incorrect outputs.
Has there been any investigation into making speculative decoding work properly with the flashinfer and flashinfer + prefix caching combination? Would be helpful to understand if this is on the roadmap or if there are known technical challenges preventing this integration. @Narsil @danieldk

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants