fp8 kv cache support
jikunshang committed Jul 1, 2024
1 parent 614aa51 commit 174c369
Showing 7 changed files with 526 additions and 139 deletions.
2 changes: 1 addition & 1 deletion cmake/cpu_extension.cmake
@@ -57,7 +57,7 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
"-mavx512dq")

find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
-    if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
+    if (AVX512BF16_FOUND AND ENABLE_AVX512BF16)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
290 changes: 182 additions & 108 deletions csrc/cpu/attention.cpp

Large diffs are not rendered by default.

93 changes: 68 additions & 25 deletions csrc/cpu/cache.cpp
@@ -32,13 +32,33 @@ void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
}
}

-template <typename scalar_t>
-void reshape_and_cache_cpu_impl(
-    const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
-    scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
-    const int64_t* __restrict__ slot_mapping, const int num_tokens,
-    const int key_stride, const int value_stride, const int num_heads,
-    const int head_size, const int block_size, const int x) {
+template <typename scalar_t, typename cache_t = scalar_t>
+cache_t assign_cache_value(const scalar_t* src) {
+  return *src;
+}
+
+template <>
+uint8_t assign_cache_value<float, uint8_t>(const float* src) {
+  uint8_t res = cast_fp32x1_to_fp8x1(*src);
+  return res;
+}
+
+template <>
+uint8_t assign_cache_value<int16_t, uint8_t>(const int16_t* src) {
+  uint8_t res = cast_bf16x1_to_fp8x1(*src);
+  return res;
+}
+
+template <typename scalar_t, typename cache_t = scalar_t, bool use_fp8 = false>
+void reshape_and_cache_cpu_impl(const scalar_t* __restrict__ key,
+                                const scalar_t* __restrict__ value,
+                                cache_t* __restrict__ key_cache,
+                                cache_t* __restrict__ value_cache,
+                                const int64_t* __restrict__ slot_mapping,
+                                const int num_tokens, const int key_stride,
+                                const int value_stride, const int num_heads,
+                                const int head_size, const int block_size,
+                                const int kv_cache_stride, const int x) {
const int block_elem_num = num_heads * head_size * block_size;

#pragma omp parallel for collapse(2)
@@ -53,19 +73,20 @@ void reshape_and_cache_cpu_impl(
const scalar_t* src_value_head_ptr = value + src_value_head_idx;
const int64_t block_index = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
-      scalar_t* target_key_head_ptr = key_cache +
-                                      block_elem_num * block_index +
-                                      head_idx * block_size * head_size;
-      scalar_t* target_value_head_ptr = value_cache +
-                                        block_elem_num * block_index +
-                                        head_idx * block_size * head_size;
+      cache_t* target_key_head_ptr = key_cache +
+                                     kv_cache_stride * block_index +
+                                     head_idx * block_size * head_size;
+      cache_t* target_value_head_ptr = value_cache +
+                                       kv_cache_stride * block_index +
+                                       head_idx * block_size * head_size;

for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
const int64_t target_offset =
src_key_idx * block_size + block_offset * x;
for (int i = 0; i < x; ++i) {
target_key_head_ptr[target_offset + i] =
-            src_key_head_ptr[src_key_idx + i];
+            assign_cache_value<scalar_t, cache_t>(src_key_head_ptr +
+                                                  src_key_idx + i);
}
}

@@ -74,7 +95,8 @@ void reshape_and_cache_cpu_impl(
const int64_t target_offset =
src_value_idx * block_size + block_offset;
target_value_head_ptr[target_offset] =
-            src_value_head_ptr[src_value_idx];
+            assign_cache_value<scalar_t, cache_t>(src_value_head_ptr +
+                                                  src_value_idx);
}
}
}
@@ -104,6 +126,17 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
});
}

+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE)               \
+  CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)                            \
+  reshape_and_cache_cpu_impl<KV_T, CACHE_T, IS_FP8_KV_CACHE>(                \
+      reinterpret_cast<KV_T*>(key.data_ptr()),                               \
+      reinterpret_cast<KV_T*>(value.data_ptr()),                             \
+      reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                      \
+      reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                    \
+      slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride, \
+      num_heads, head_size, block_size, kv_cache_stride, x);                 \
+  CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)

void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
@@ -115,20 +148,30 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
+  int kv_cache_stride = key_cache.stride(0);

int key_stride = key.stride(0);
int value_stride = value.stride(0);

-  VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
-        CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
-        reshape_and_cache_cpu_impl<scalar_t>(
-            key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
-            key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
-            slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
-            value_stride, num_heads, head_size, block_size, x);
-        CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
-      });
+  if (kv_cache_dtype == "auto") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, float, false);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      TORCH_CHECK(false, "Unsupported data type: Half");
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(int16_t, int16_t, false);
+    }
+  } else if (kv_cache_dtype == "fp8") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      TORCH_CHECK(false, "Unsupported data type: Half");
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(int16_t, uint8_t, true);
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
}

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
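Note: the scalar conversion helpers used by the fp8 path above (cast_fp32x1_to_fp8x1, cast_bf16x1_to_fp8x1) come from the new fp8_utils.h header, which is not rendered on this page. As a rough sketch only, assuming an E5M2-style format obtained by truncating the fp16 bit pattern (which may not match the commit's actual implementation), a scalar fp32-to-fp8 cast could look like this:

// Hypothetical sketch, not code from this commit. The real
// cast_fp32x1_to_fp8x1() in fp8_utils.h may differ (rounding mode,
// saturation behavior, E4M3 vs. E5M2 layout).
#include <cstdint>
#include <immintrin.h>

static inline uint8_t sketch_cast_fp32x1_to_fp8x1(float v) {
  // fp32 -> fp16 with round-to-nearest-even, then keep the high byte:
  // sign(1) + exponent(5) + mantissa(2) == E5M2.
  uint16_t h = _cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
  return static_cast<uint8_t>(h >> 8);
}

The dequantization direction is handled vectorized, via the FP32Vec16(const FP8Vec16&) constructor added to cpu_types_x86.hpp below.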
1 change: 0 additions & 1 deletion csrc/cpu/cpu_types.hpp
@@ -1,4 +1,3 @@

#ifndef CPU_TYPES_HPP
#define CPU_TYPES_HPP

19 changes: 19 additions & 0 deletions csrc/cpu/cpu_types_x86.hpp
@@ -5,6 +5,10 @@
#include <immintrin.h>
#include <torch/all.h>

#include "fp8_utils.h"

typedef uint8_t cpu_fp8;

#ifndef __AVX2__
static_assert(false, "AVX2 must be supported for the current implementation.");
#endif
@@ -50,6 +54,19 @@ template <typename T> struct Vec {
struct FP32Vec8;
struct FP32Vec16;

+struct FP8Vec16 : public Vec<FP8Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+  union AliasReg {
+    __m128 reg;
+    cpu_fp8 values[VEC_ELEM_NUM];
+  };
+  __m128 reg;
+
+  explicit FP8Vec16() : reg(_mm_set1_ps(0)) {}
+  explicit FP8Vec16(const cpu_fp8 *ptr) : reg((__m128)_mm_loadu_epi8(ptr)) {}
+
+};

#ifdef __AVX512FP16__
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
@@ -279,6 +296,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {

explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}

+  explicit FP32Vec16(const FP8Vec16 &data) : reg(cast_fp8x16_to_fp32x16((__m128)data.reg)) {}

explicit FP32Vec16(const FP32Vec4 &data)
: reg((__m512)_mm512_inserti32x4(
_mm512_inserti32x4(
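For context, a hedged usage sketch (not code from this commit) of the new vector types: FP8Vec16 loads 16 packed fp8 bytes from the cache, and the new FP32Vec16 constructor widens them for fp32 arithmetic via cast_fp8x16_to_fp32x16, declared in the unrendered fp8_utils.h. The reduce_sum() call assumes the reduction helper FP32Vec16 already provides upstream.

// Hypothetical usage sketch, not part of the commit.
#include "cpu_types.hpp"  // pulls in cpu_types_x86.hpp on x86 builds

float sum_fp8_group(const cpu_fp8 *cache_ptr) {
  FP8Vec16 packed(cache_ptr);   // load 16 packed fp8 cache entries
  FP32Vec16 widened(packed);    // widen to fp32 via cast_fp8x16_to_fp32x16
  return widened.reduce_sum();  // accumulate in fp32, as the attention kernel would
}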