From cd9c0d65d98f86fbd2235ee41b80107097a57f77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jie=20Fu=20=28=E5=82=85=E6=9D=B0=29?= Date: Fri, 14 Jun 2024 07:22:24 +0800 Subject: [PATCH] [Hardware][Intel] Support CPU inference with AVX2 ISA (#5452) --- cmake/cpu_extension.cmake | 6 +- csrc/cpu/cpu_types.hpp | 165 +++++++++++++++++++++++++++++++++++++- 2 files changed, 169 insertions(+), 2 deletions(-) diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 61d4843838ba0..a644e5b6a8b21 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -33,6 +33,7 @@ function (find_isa CPUINFO TARGET OUT) endif() endfunction() +find_isa(${CPUINFO} "avx2" AVX2_FOUND) find_isa(${CPUINFO} "avx512f" AVX512_FOUND) if (AVX512_FOUND) @@ -53,8 +54,11 @@ if (AVX512_FOUND) else() message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") endif() +elseif (AVX2_FOUND) + list(APPEND CXX_COMPILE_FLAGS "-mavx2") + message(WARNING "vLLM CPU backend using AVX2 ISA") else() - message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.") + message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 ISA support.") endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index 034c406a532d5..d7621aaae81c9 100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -5,6 +5,10 @@ #include #include +#ifndef __AVX2__ +static_assert(false, "AVX2 must be supported for the current implementation."); +#endif + namespace vec_op { // FIXME: FP16 is not fully supported in Torch-CPU @@ -104,6 +108,7 @@ struct BF16Vec16 : public Vec { void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } }; +#ifdef __AVX512F__ struct BF16Vec32 : public Vec { constexpr static int VEC_ELEM_NUM = 32; @@ -123,6 +128,34 @@ struct BF16Vec32 : public Vec { void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } }; +#else +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + __m256i reg_low; + __m256i reg_high; + + explicit BF16Vec32(const void *ptr) + : reg_low(_mm256_loadu_si256((__m256i const *)ptr)), + reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {} + + explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low), + reg_high(high) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg_low((__m256i)_mm256_inserti32x4( + _mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)), + reg_high((__m256i)_mm256_inserti32x4( + _mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)) {} + + void save(void *ptr) const { + *reinterpret_cast<__m256i *>(ptr) = reg_low; + *reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high; + } +}; +#endif struct FP32Vec4 : public Vec { constexpr static int VEC_ELEM_NUM = 4; @@ -226,6 +259,7 @@ struct FP32Vec8 : public Vec { void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } }; +#ifdef __AVX512F__ struct FP32Vec16 : public Vec { constexpr static int VEC_ELEM_NUM = 16; union AliasReg { @@ -290,6 +324,114 @@ struct FP32Vec16 : public Vec { void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } }; +#else +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + union AliasReg { + __m256 reg; + float values[8]; + }; + + __m256 reg_low; + __m256 reg_high; + + explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)), + reg_high(_mm256_set1_ps(v)) {} + + explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)), + reg_high(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)), + reg_high(_mm256_loadu_ps(ptr + 8)) {} + + explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {} + + explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low), + reg_high(data.reg_high) {} + + explicit FP32Vec16(const FP32Vec4 &data) + : reg_low((__m256)_mm256_inserti128_si256( + _mm256_castsi128_si256((__m128i)data.reg), + (__m128i)data.reg, 1)), + reg_high((__m256)_mm256_inserti128_si256( + _mm256_castsi128_si256((__m128i)data.reg), + (__m128i)data.reg, 1)) {} + + explicit FP32Vec16(const FP32Vec8 &data) + : reg_low(data.reg), reg_high(data.reg) {} + + explicit FP32Vec16(const BF16Vec16 &v) { + __m128i low = _mm256_extractf128_si256(v.reg, 0); + __m128i high = _mm256_extractf128_si256(v.reg, 1); + + __m256i v_low_epi32 = _mm256_cvtepu16_epi32(low); + __m256i v_high_epi32 = _mm256_cvtepu16_epi32(high); + + __m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2); + __m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2); + + reg_low = _mm256_castsi256_ps(v_low_shifted); + reg_high = _mm256_castsi256_ps(v_high_shifted); + } + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low), + _mm256_mul_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low), + _mm256_add_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low), + _mm256_sub_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low), + _mm256_div_ps(reg_high, b.reg_high)); + } + + float reduce_sum() const { + FP32Vec8 low = FP32Vec8(reg_low); + FP32Vec8 high = FP32Vec8(reg_high); + return low.reduce_sum() + high.reduce_sum(); + } + + template float reduce_sub_sum(int idx) { + float sum = 0.0; + static_assert(VEC_ELEM_NUM % group_size == 0); + constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); + uint32_t mask = base_mask << (idx * group_size); + + AliasReg ar; + + auto func = [&sum, &mask, &ar](int i) { + int flag = mask & 0x1; + mask = mask >> 1; + if (flag != 0) sum += ar.values[i]; + }; + + ar.reg = reg_low; + unroll_loop(func); + + ar.reg = reg_high; + unroll_loop(func); + + return sum; + } + + void save(float *ptr) const { + _mm256_storeu_ps(ptr, reg_low); + _mm256_storeu_ps(ptr + 8, reg_high); + } +}; +#endif template struct VecType { using vec_type = void; }; @@ -336,6 +478,7 @@ template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { *ptr = *(v_ptr + 1); } +#ifdef __AVX512F__ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) : reg(_mm256_cvtepi32_epi16( _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} @@ -343,7 +486,27 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) : reg(_mm512_cvtepi32_epi16( _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} -#endif +#else +namespace{ +__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) { + __m256i ai = _mm256_castps_si256(a); + ai = _mm256_srli_epi32(ai, 16); + ai = _mm256_packus_epi32(ai, ai); + ai = _mm256_permute4x64_epi64(ai, 0b00111001); + return _mm256_extracti128_si256(ai, 0); +} +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { + BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low)); + BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high)); + reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1); +} +#endif // __AVX512F__ +#endif // __AVX512BF16__ inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }