Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hardware][Intel] fp8 kv cache support for CPU #5492

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jikunshang
Copy link
Contributor

FILL IN THE PR DESCRIPTION HERE

Add FP8 kv cache for Intel CPU.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@@ -43,7 +43,7 @@ if (AVX512_FOUND)
"-mavx512dq")

find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
if (AVX512BF16_FOUND AND ENABLE_AVX512BF16)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want to compile for AVX512BF16 if we simply find it, not if we find it and force it


constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x;
constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP;
constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4;

static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why aren't these asserts valid anymore?

@@ -22,6 +22,16 @@ struct KernelVecType<float> {
using v_load_vec_type = vec_op::FP32Vec16;
};

template <>
struct KernelVecType<float, cpu_fp8> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't immediately clear that cpu_fp8 is treated as a type since the PyTorch types use PascalCase - maybe at least adding a _t such as cpu_fp8_t or CpuFP8 or FP8Cpu?

});
} else if (kv_cache_dtype == "fp8") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(float, cpu_fp8, true);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need this IS_FP8_KV_CACHE boolean in paged_attention_v1 but not in paged_attention_v2? It seems like it could be deduced by checking the cache_t within the kernel and this would bode better for future cache types being used.

@@ -5,6 +5,10 @@
#include <immintrin.h>
#include <torch/all.h>

#include "fp8_utils.h"

typedef uint8_t cpu_fp8;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling out again that I think the name of this type doesn't seem like a type.

return _mm512_slli_epi16(_mm512_cvtepi8_epi16(a), 8);
}

static inline __m256i _mm256_cvte5m2_fp16(__m128i a) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For all these function signatures that have just e5m2 to signify fp8e5m2, I think it would be more clear to include the fp8

Comment on lines +48 to +68
const __m512i a_ = _mm512_inserti64x4(
_mm512_inserti64x4(_mm512_setzero_si512(), b, 0), a, 1);
const __mmask32 maska1_ = _mm512_cmp_epi16_mask(_mm512_and_si512(a_, vnaninf),
vnaninf, _MM_CMPINT_NE);
const __mmask32 maska2_ = _mm512_cmp_epi16_mask(
_mm512_and_si512(a_, vfixupmask), vfixupmask, _MM_CMPINT_EQ);
const __mmask32 maska3_ =
_mm512_cmp_epi16_mask(_mm512_and_si512(a_, _mm512_set1_epi16(0x7FFF)),
vsatuval, _MM_CMPINT_NLE);
__m512i vExp_ = _mm512_sub_epi16(
_mm512_srli_epi16(_mm512_and_si512(a_, vnaninf), 10), vExp_fp16);
vExp_ = _mm512_slli_epi16(_mm512_add_epi16(vExp_, vExp_e5m2), 10);
__m512i a_rne_ = _mm512_or_si512(vExp_, _mm512_and_si512(a_, vsMant));
a_rne_ = _mm512_mask_add_epi16(
a_rne_, maska1_, a_rne_,
_mm512_mask_add_epi16(vrneadd, maska2_, vrneadd, vfixup));
a_rne_ = _mm512_mask_mov_epi16(
a_rne_, maska3_,
_mm512_or_si512(_mm512_and_si512(a_rne_, vinfval), vsatuval));
a_rne_ = _mm512_mask_mov_epi16(a_rne_, ~maska1_, vinfval);
return _mm512_cvtepi16_epi8(_mm512_srli_epi16(a_rne_, 8));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any additional function header or inline comments would be appreciated!

return _mm512_mask_mov_epi16(a_, mask1_, vinfval);
}

static inline void cvt_fp16_e5m2_noINF_rne_intrinsic(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the difference between noinf and noINF?

Comment on lines +136 to +144
static inline void cast_fp16xn_to_fp8xn(const short* __restrict__ in,
unsigned char* out, int n) {
cvt_fp16_e5m2_noINF_rne_intrinsic(in, out, n);
}

static inline void cast_fp32xn_to_fp8xn(const float* __restrict__ in,
float* out, int n) {
cvt_fp32_e5m2_noinf_rne_intrinsic(in, out, n, 0);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems you aren't getting any reuse out of these functions, why not inline the body?

Comment on lines -127 to -138
if kv_cache_dtype != "auto":
raise NotImplementedError(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only implemented for CPU usage of the Torch SPDA backend, right? Would fp8 kv cache work if we used CUDA with the SPDA backend, for instance?

Copy link
Collaborator

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also cc @comaniac !

CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t, scalar_t, false);
CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl)
});
} else if (kv_cache_dtype == "fp8") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be better to align with the general interface. We now support fp8, fp8_e4m3 and fp8_e5m2. Note that fp8 defaults to fp8_e4m3. Please cover all cases and throw an error for unsupported cases.

@jikunshang jikunshang marked this pull request as draft June 17, 2024 02:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants