Skip to content

Commit

Permalink
[Hardware][Intel] Support CPU inference with AVX2 ISA (vllm-project#5452
Browse files Browse the repository at this point in the history
)
  • Loading branch information
DamonFool authored Jun 13, 2024
1 parent 50eed24 commit cd9c0d6
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 2 deletions.
6 changes: 5 additions & 1 deletion cmake/cpu_extension.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ function (find_isa CPUINFO TARGET OUT)
endif()
endfunction()

find_isa(${CPUINFO} "avx2" AVX2_FOUND)
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)

if (AVX512_FOUND)
Expand All @@ -53,8 +54,11 @@ if (AVX512_FOUND)
else()
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
endif()
elseif (AVX2_FOUND)
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
message(WARNING "vLLM CPU backend using AVX2 ISA")
else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.")
message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 ISA support.")
endif()

message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
Expand Down
165 changes: 164 additions & 1 deletion csrc/cpu/cpu_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
#include <immintrin.h>
#include <torch/all.h>

#ifndef __AVX2__
static_assert(false, "AVX2 must be supported for the current implementation.");
#endif

namespace vec_op {

// FIXME: FP16 is not fully supported in Torch-CPU
Expand Down Expand Up @@ -104,6 +108,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
};

#ifdef __AVX512F__
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;

Expand All @@ -123,6 +128,34 @@ struct BF16Vec32 : public Vec<BF16Vec32> {

void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
};
#else
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;

__m256i reg_low;
__m256i reg_high;

explicit BF16Vec32(const void *ptr)
: reg_low(_mm256_loadu_si256((__m256i const *)ptr)),
reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {}

explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low),
reg_high(high) {}

explicit BF16Vec32(BF16Vec8 &vec8_data)
: reg_low((__m256i)_mm256_inserti32x4(
_mm256_castsi128_si256((__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1)),
reg_high((__m256i)_mm256_inserti32x4(
_mm256_castsi128_si256((__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1)) {}

void save(void *ptr) const {
*reinterpret_cast<__m256i *>(ptr) = reg_low;
*reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high;
}
};
#endif

struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
Expand Down Expand Up @@ -226,6 +259,7 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
};

#ifdef __AVX512F__
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
Expand Down Expand Up @@ -290,6 +324,114 @@ struct FP32Vec16 : public Vec<FP32Vec16> {

void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
};
#else
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;

union AliasReg {
__m256 reg;
float values[8];
};

__m256 reg_low;
__m256 reg_high;

explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)),
reg_high(_mm256_set1_ps(v)) {}

explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)),
reg_high(_mm256_set1_ps(0.0)) {}

explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)),
reg_high(_mm256_loadu_ps(ptr + 8)) {}

explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}

explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low),
reg_high(data.reg_high) {}

explicit FP32Vec16(const FP32Vec4 &data)
: reg_low((__m256)_mm256_inserti128_si256(
_mm256_castsi128_si256((__m128i)data.reg),
(__m128i)data.reg, 1)),
reg_high((__m256)_mm256_inserti128_si256(
_mm256_castsi128_si256((__m128i)data.reg),
(__m128i)data.reg, 1)) {}

explicit FP32Vec16(const FP32Vec8 &data)
: reg_low(data.reg), reg_high(data.reg) {}

explicit FP32Vec16(const BF16Vec16 &v) {
__m128i low = _mm256_extractf128_si256(v.reg, 0);
__m128i high = _mm256_extractf128_si256(v.reg, 1);

__m256i v_low_epi32 = _mm256_cvtepu16_epi32(low);
__m256i v_high_epi32 = _mm256_cvtepu16_epi32(high);

__m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2);
__m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2);

reg_low = _mm256_castsi256_ps(v_low_shifted);
reg_high = _mm256_castsi256_ps(v_high_shifted);
}

explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}

FP32Vec16 operator*(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low),
_mm256_mul_ps(reg_high, b.reg_high));
}

FP32Vec16 operator+(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low),
_mm256_add_ps(reg_high, b.reg_high));
}

FP32Vec16 operator-(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low),
_mm256_sub_ps(reg_high, b.reg_high));
}

FP32Vec16 operator/(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low),
_mm256_div_ps(reg_high, b.reg_high));
}

float reduce_sum() const {
FP32Vec8 low = FP32Vec8(reg_low);
FP32Vec8 high = FP32Vec8(reg_high);
return low.reduce_sum() + high.reduce_sum();
}

template <int group_size> float reduce_sub_sum(int idx) {
float sum = 0.0;
static_assert(VEC_ELEM_NUM % group_size == 0);
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
uint32_t mask = base_mask << (idx * group_size);

AliasReg ar;

auto func = [&sum, &mask, &ar](int i) {
int flag = mask & 0x1;
mask = mask >> 1;
if (flag != 0) sum += ar.values[i];
};

ar.reg = reg_low;
unroll_loop<int, 8>(func);

ar.reg = reg_high;
unroll_loop<int, 8>(func);

return sum;
}

void save(float *ptr) const {
_mm256_storeu_ps(ptr, reg_low);
_mm256_storeu_ps(ptr + 8, reg_high);
}
};
#endif

template <typename T> struct VecType { using vec_type = void; };

Expand Down Expand Up @@ -336,14 +478,35 @@ template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
*ptr = *(v_ptr + 1);
}

#ifdef __AVX512F__
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg(_mm256_cvtepi32_epi16(
_mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}

inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
: reg(_mm512_cvtepi32_epi16(
_mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
#endif
#else
namespace{
__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
__m256i ai = _mm256_castps_si256(a);
ai = _mm256_srli_epi32(ai, 16);
ai = _mm256_packus_epi32(ai, ai);
ai = _mm256_permute4x64_epi64(ai, 0b00111001);
return _mm256_extracti128_si256(ai, 0);
}
}

inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {}

inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low));
BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high));
reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1);
}
#endif // __AVX512F__
#endif // __AVX512BF16__

inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }

Expand Down

0 comments on commit cd9c0d6

Please sign in to comment.