
Add FlashInfer support #2354

Merged: 1 commit merged into main from feature/flash-infer on Aug 9, 2024

Conversation

@danieldk (Member) commented on Aug 2, 2024

What does this PR do?

This change adds support for FlashInfer. FlashInfer can be enabled with `FLASH_INFER=1` and is currently only implemented in `FlashCausalLM`. Since this functionality is intended for testing only, FlashInfer is not installed anywhere yet.
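
As a rough illustration of how the flag is read (the real backend selection lives inside the server's attention code; the helper below is hypothetical):

```python
import os

# Hypothetical sketch: the actual backend selection in text-generation-inference
# differs; names here are illustrative only.
FLASH_INFER = os.getenv("FLASH_INFER", "").lower() in {"1", "true"}

def attention_backend() -> str:
    if FLASH_INFER:
        return "flashinfer"       # use the FlashInfer prefill/decode kernels
    return "flash-attention"      # default FlashAttention/vLLM paged attention path
```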

The FlashInfer API is quite different from FlashAttention/vLLM in that it requires more global bookkeeping:

  • A wrapper class needs to be constructed (which we just call *state*). Since this is fairly expensive (due to pinned host memory allocation), we only do this once per FlashCausalLM instance, or once per CUDA Graph size.
  • Each model forward call needs to be wrapped in `begin_forward` and `end_forward`. This sets up data structures that can be reused by all attention calls within that forward pass (see the sketch after this list).
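
A minimal sketch of that wrapper lifecycle, assuming a hypothetical `FlashInferState` class (the real code wraps FlashInfer's own prefill/decode wrappers, so the names and the workspace size below are illustrative):

```python
import torch

class FlashInferState:
    """Hypothetical stand-in for the wrapper ("state") described above."""

    def __init__(self, device: torch.device):
        # Expensive one-time setup: the real FlashInfer wrappers allocate pinned
        # host memory, which is why we construct the state once per FlashCausalLM
        # instance (or once per CUDA Graph size) rather than per forward call.
        # The 128 MiB workspace size is illustrative only.
        self.workspace_buffer = torch.empty(
            128 * 1024 * 1024, dtype=torch.uint8, device=device
        )

    def begin_forward(self, **metadata):
        # Build the per-forward-call data structures (e.g. paged KV cache
        # indices) that all attention calls in this forward pass reuse.
        ...

    def end_forward(self):
        # Release the per-forward-call data structures.
        ...
```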

When calling attention, we need access to the state object. To avoid passing an argument down the call chain (which would require changes to all models), we use a context variable.
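
Concretely, this can be done with Python's contextvars module; a minimal sketch (the variable and helper names are illustrative, not the exact ones used in the PR):

```python
from contextvars import ContextVar

# Module-level context variable holding the FlashInfer state for the forward
# call currently in flight (None outside of a forward call).
flashinfer_state: ContextVar = ContextVar("flashinfer_state", default=None)

def get_flashinfer_state():
    # Called from inside the attention implementation, so individual models
    # never have to thread the state object through their forward() signatures.
    state = flashinfer_state.get()
    if state is None:
        raise RuntimeError("attention called outside of a FlashInfer forward context")
    return state
```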

Each model forward call is wrapped using a context manager that does all the bookkeeping for such a call (sketched after the list):

  • Set the context variable to the forward call's state.
  • Call `begin_forward` on the state.
  • Yield.
  • Call `end_forward` on the state.
  • Reset the context variable.
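
Put together, the context manager could look roughly like this (again a hedged sketch that reuses the hypothetical `FlashInferState` and `flashinfer_state` names from the sketches above, not the PR's exact code):

```python
from contextlib import contextmanager

@contextmanager
def use_flashinfer_state(state, **forward_metadata):
    # 1. Make the state visible to attention calls via the context variable.
    token = flashinfer_state.set(state)
    # 2. Build FlashInfer's per-forward-call data structures.
    state.begin_forward(**forward_metadata)
    try:
        # 3. Run the model's forward pass.
        yield state
    finally:
        # 4./5. Tear down the per-call structures and reset the context
        # variable, even if the forward pass raised an exception.
        state.end_forward()
        flashinfer_state.reset(token)
```

A forward pass would then be wrapped as `with use_flashinfer_state(state, ...): logits = model(...)`, and the attention kernels fetch the state through the context variable without any changes to model signatures.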

We cannot use a single shared global variable for this, since e.g. CUDA Graphs of different sizes each have their own state.
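
For example, the per-graph bookkeeping might be organized as below (purely illustrative; the PR's actual CUDA Graph capture code is more involved):

```python
# Illustrative only: one FlashInfer state per captured CUDA Graph batch size,
# in addition to the regular state owned by the FlashCausalLM instance.
cuda_graph_states = {}  # padded batch size -> FlashInferState

def state_for_cuda_graph(batch_size: int, device):
    if batch_size not in cuda_graph_states:
        cuda_graph_states[batch_size] = FlashInferState(device)
    return cuda_graph_states[batch_size]
```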

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Narsil (Collaborator) left a comment

LGTM

@Narsil merged commit 7830de1 into main on Aug 9, 2024
9 checks passed
@Narsil deleted the feature/flash-infer branch on August 9, 2024 09:42
yuanwu2017 pushed a commit to yuanwu2017/tgi-gaudi that referenced this pull request Sep 26, 2024