Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic FP8 KV cache support #2603

Merged
merged 2 commits into from
Oct 4, 2024
Merged

Add basic FP8 KV cache support #2603

merged 2 commits into from
Oct 4, 2024

Conversation

danieldk
Copy link
Member

@danieldk danieldk commented Oct 2, 2024

What does this PR do?

This change adds rudimentary FP8 KV cache support. The support is enabled by passing --kv-cache-dtype fp8_e5m2 to the launcher. Doing so uses this type for the KV cache. However support is still limited:

  • Only the fp8_e5m2 type is supported.
  • The KV cache layout is the same as float16/bfloat16 (HND).
  • The FP8 KV cache is only supported for FlashInfer.
  • Loading of scales is not yet supported.

This PR is intentionally small to keep things reviewable. I'll follow it up with PRs that add more functionality.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@danieldk danieldk force-pushed the feature/fp8-kv-cache branch 2 times, most recently from 2628268 to 37df2ff Compare October 3, 2024 11:12
Narsil
Narsil previously approved these changes Oct 4, 2024
Copy link
Collaborator

@Narsil Narsil left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great and actually relatively simple Change !

server/text_generation_server/layers/attention/kv_cache.py Outdated Show resolved Hide resolved
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"auto", 1.0 ? What are those flags ? They didn't seem to be used before, aren't they defaulted in paged ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifying them here breaks IPEX no ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes, the others have these arguments, but not IPEX, so I reverted this part of the PR (the KV cache now uses the existing reshape_and_cache wrappers.

This change adds rudimentary FP8 KV cache support. The support is
enabled by passing `--kv-cache-dtype fp8_e5m2` to the launcher. Doing so
uses this type for the KV cache. However support is still limited:

* Only the `fp8_e5m2` type is supported.
* The KV cache layout is the same as `float16`/`bfloat16` (HND).
* The FP8 KV cache is only supported for FlashInfer.
* Loading of scales is not yet supported.
@danieldk danieldk force-pushed the feature/fp8-kv-cache branch from 4cc5405 to ed5c2fb Compare October 4, 2024 13:25
@danieldk danieldk mentioned this pull request Oct 4, 2024
5 tasks
Comment on lines +314 to +315
kv_cache.store(
key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not important, but its strange we need the .contiguous() calls here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I wasn't sure if it was needed, these seems to be key/value striding in other places that is non-contiguous, but I also didn't want to touch it.

Copy link
Collaborator

@drbh drbh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Great addition

@danieldk danieldk merged commit 2358c2b into main Oct 4, 2024
12 of 13 checks passed
@danieldk danieldk deleted the feature/fp8-kv-cache branch October 4, 2024 15:51
@Narsil Narsil mentioned this pull request Oct 8, 2024
4 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants