Cfu fp16 (#14538)
### Description
FP16 GEMM, including hardware-agnostic driver code, a slow C++ kernel,
and an ARM64 NEON kernel.


### Motivation and Context
First step in creating native support for fp16 model inference on ARM64
and AMD64 platforms.

---------

Co-authored-by: Chen Fu <fuchen@microsoft.com>
chenfucn and Chen Fu authored Feb 15, 2023
1 parent b1abb8c commit 733ca85
Showing 15 changed files with 3,091 additions and 27 deletions.
7 changes: 7 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -19,6 +19,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/platform.cpp
${MLAS_SRC_DIR}/threading.cpp
${MLAS_SRC_DIR}/sgemm.cpp
${MLAS_SRC_DIR}/halfgemm.cpp
${MLAS_SRC_DIR}/qgemm.cpp
${MLAS_SRC_DIR}/qdwconv.cpp
${MLAS_SRC_DIR}/convolve.cpp
@@ -59,6 +60,7 @@ function(setup_mlas_source_for_windows)

if(onnxruntime_target_platform STREQUAL "ARM64")
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
@@ -73,6 +75,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm
${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm
${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm
${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm
${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm
${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm
${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm
@@ -305,6 +308,7 @@ else()
${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S
${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S
@@ -314,10 +318,13 @@ else()
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S
${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs})
set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64")
38 changes: 38 additions & 0 deletions onnxruntime/core/common/cpuid_info.cc
@@ -143,6 +143,7 @@ void CPUIDInfo::ArmLinuxInit() {
if (pytorch_cpuinfo_init_) {
is_hybrid_ = cpuinfo_get_uarchs_count() > 1;
has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot();
has_fp16_ = cpuinfo_has_arm_neon_fp16_arith();
const uint32_t core_cnt = cpuinfo_get_cores_count();
core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
is_armv8_narrow_ld_.resize(core_cnt, false);
@@ -165,6 +166,7 @@ void CPUIDInfo::ArmLinuxInit() {
}
} else {
has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0);
has_fp16_ |= has_arm_neon_dot_;
}
}

@@ -220,9 +222,45 @@ void CPUIDInfo::ArmWindowsInit() {
lastUarch = uarch;
}
}

switch (lastUarch) {
case cpuinfo_uarch_cortex_a55:
case cpuinfo_uarch_cortex_a55r0:
case cpuinfo_uarch_cortex_a76:
case cpuinfo_uarch_neoverse_n1:
case cpuinfo_uarch_cortex_a77:
case cpuinfo_uarch_exynos_m4:
case cpuinfo_uarch_exynos_m5:
has_fp16_ = true;
break;
default:
break;
}
if (!has_fp16_) {
/*
* Detect fp16 support. All cores are expected to implement the same
* instruction set, so we only check the first core's ID_AA64PFR0_EL1 register:
* Op0(0b11), Op1(0b000), CRn(0b0000), CRm(0b0100), Op2(0b000).
*/
uint64_t ID_AA64PFR0_EL1;
unsigned long valsize = sizeof(uint64_t);
auto retCode = ::RegGetValueA(
HKEY_LOCAL_MACHINE,
"HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0",
"CP 4020", RRF_RT_REG_QWORD, nullptr,
&ID_AA64PFR0_EL1, &valsize);
if (retCode == ERROR_SUCCESS) {
// AdvSIMD, bits [23:20]
auto advSimd = ID_AA64PFR0_EL1 >> 20;
if ((advSimd & 0xfULL) == 1) {
has_fp16_ = true;
}
}
}
#endif /* Application Family or OneCore Family */

has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
has_fp16_ |= has_arm_neon_dot_;
}

#endif /* (arm or arm64) and windows */
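For reference, the registry probe above reduces to a single field decode: the AdvSIMD field of ID_AA64PFR0_EL1 occupies bits [23:20], and a field value of 0b0001 indicates support for half-precision arithmetic (FEAT_FP16). A minimal standalone sketch of the same check (the helper name is illustrative, not part of this commit):

    #include <cstdint>

    // Decode the AdvSIMD field (bits [23:20]) of ID_AA64PFR0_EL1.
    // A field value of 0b0001 means fp16 arithmetic is implemented.
    inline bool AdvSimdHasFp16(uint64_t id_aa64pfr0_el1) {
        return ((id_aa64pfr0_el1 >> 20) & 0xF) == 1;
    }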
6 changes: 5 additions & 1 deletion onnxruntime/core/common/cpuid_info.h
@@ -21,7 +21,7 @@ class CPUIDInfo {
bool HasAVX512f() const { return has_avx512f_; }
bool HasAVX512_BF16() const {return has_avx512_bf16_;}
bool HasAVX512Skylake() const { return has_avx512_skylake_; }
bool HasF16C() const { return has_f16c_; }
bool HasF16C() const { return has_f16c_; } /*fp16 conversion inst*/
bool HasSSE3() const { return has_sse3_; }
bool HasSSE4_1() const { return has_sse4_1_; }
bool IsHybrid() const { return is_hybrid_; }
@@ -85,6 +85,9 @@ class CPUIDInfo {
return is_armv8_narrow_ld_[coreIdx];
}

bool HasFp16VectorAcceleration() const {
return has_fp16_;
}

private:
CPUIDInfo() {
@@ -118,6 +121,7 @@ class CPUIDInfo {
std::vector<bool> is_armv8_narrow_ld_;

bool has_arm_neon_dot_{false};
bool has_fp16_{false};

#ifdef CPUIDINFO_ARCH_X86

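With the new accessor, dispatch code can gate fp16 fast paths on this capability bit. A hypothetical call site (not part of this diff; it assumes the existing CPUIDInfo::GetCPUIDInfo() singleton accessor):

    // Hypothetical guard around an fp16 fast path.
    if (onnxruntime::CPUIDInfo::GetCPUIDInfo().HasFp16VectorAcceleration()) {
        // dispatch to the fp16 NEON kernel
    }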
178 changes: 175 additions & 3 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -90,7 +90,7 @@ typedef enum { CblasLeft=141, CblasRight=142} CBLAS_SIDE;
#endif

//
// Forward declare the thread pool implementation class.
// Forward declare the thread pool implementation class and half precision floating point.
//
// N.B. Avoid including ONNX Runtime headers here to keep the dependencies for
// standalone MLAS test executables smaller.
@@ -100,10 +102,12 @@ namespace onnxruntime {
namespace concurrency {
class ThreadPool;
};
};
struct MLFloat16;
}; // namespace onnxruntime

using MLAS_THREADPOOL = onnxruntime::concurrency::ThreadPool;


//
// Platform routines.
//
@@ -613,7 +615,7 @@ MlasGemm(
// Currently only supported in ARM64
//
#if defined(MLAS_TARGET_ARM64)
constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 15;
constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 30;
#else
constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 0;
#endif
@@ -1367,3 +1369,173 @@ MlasQLinearMul(
size_t N,
bool IsScalarB
);

//
// Half precision routines
//

// Any type with size=2 should work
using MLAS_FP16 = onnxruntime::MLFloat16;

constexpr size_t FP16_SIZE = sizeof(uint16_t);


bool MLASCALL
MlasFp16AccelerationSupported();

/**
* @brief Interface for half gemm post processors.
*
* Example implementations of this interface include activations,
* conversion from half precision to single precision, etc.
*
* Half GEMM is computed tile by tile. When a tile of result matrix
* is produced, the method Process() is called to process this tile.
* Parameters of this method describe the location and shape of the
* tile.
*/
class MLAS_HALF_GEMM_POSTPROCESSOR {
public:
virtual
void
Process(
MLAS_FP16*, /**< the address of matrix to process */
size_t, /**< the start row index of matrix */
size_t, /**< the start col index of matrix */
size_t, /**< the element count per row to process */
size_t, /**< the element count per col to process */
size_t /**< the leading dimension of matrix */
) const = 0;

virtual ~MLAS_HALF_GEMM_POSTPROCESSOR() {}
};

/**
* @brief Convert half gemm result matrix to single precision float matrix
*/
class MLAS_HALF_GEMM_2FLOAT_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR {
public:
MLAS_HALF_GEMM_2FLOAT_PROCESSOR(
float* Output, /**< address of the output matrix, row major */
size_t RowStride /**< row stride of the output matrix */
) :
Output_(Output),
RowStride_(RowStride)
{}

void
Process(
MLAS_FP16* C,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN,
size_t ldc
) const override;

private:
float* Output_;
size_t RowStride_;
};


/**
* @brief Data parameters for half precision GEMM routine
* All except C are [in] parameters
*/
struct MLAS_HALF_GEMM_DATA_PARAMS {
const void* A = nullptr; /**< address of A */
const void* B = nullptr; /**< address of B */
const MLAS_FP16* Bias = nullptr; /**< address of Bias, vector size N */
MLAS_FP16* C = nullptr; /**< address of result matrix */
size_t lda = 0; /**< leading dimension of A */
size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/
size_t ldc = 0; /**< leading dimension of C*/
const MLAS_HALF_GEMM_POSTPROCESSOR* OutputProcessor = nullptr;
bool AIsfp32 = false; /**< matrix A is fp32, needs to be cast to fp16 */
bool BIsfp32 = false; /**< matrix B is fp32, needs to be cast to fp16 */
};

/**
* @brief Half precision Batched GEMM: C = A * B + Bias
* Either A or B can be fp32 or fp16
*
* Note: We only support uniform batching, so shapes and types of the
* inputs must be the same across all parameter blocks.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] ThreadPool
* @return
*/
void
MLASCALL
MlasHalfGemmBatch(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_HALF_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool = nullptr
);

/**
* @brief For half precision GEMM, returns size of the
* packing buffer needed for right hand side
* @param[in] N Number of columns
* @param[in] K Number of rows
* @param[in] float2half Whether the input is float that
* needs to be converted to half precision
* @return size of the packing buffer,
* 0 if operation not supported
*/
size_t
MLASCALL
MlasHalfGemmPackBSize(
size_t N,
size_t K,
bool float2half
);

/**
* @brief For half precision GEMM, pack the right hand
* side matrix B
*
* @param[in] N Number of columns
* @param[in] K Number of rows
* @param[in] B Address of matrix B
* @param[in] ldb leading dimension of input matrix B
* @param[out] PackedB Address of the packed matrix
*/
void
MLASCALL
MlasHalfGemmPackB(
size_t N,
size_t K,
const MLAS_FP16* B,
size_t ldb,
void* PackedB
);

/**
* @brief For half precision GEMM, convert the float matrix B
* to half precision and pack it into a packing buffer
*
* @param[in] N Number of columns
* @param[in] K Number of rows
* @param[in] B Address of matrix B
* @param[in] ldb leading dimension of input matrix B
* @param[out] PackedB Address of the packed matrix
*/
void
MLASCALL
MlasHalfGemmConvertPackB(
size_t N,
size_t K,
const float* B,
size_t ldb,
void* PackedB
);
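Taken together, a minimal single-batch sketch of the API above (not from this commit; it assumes fp32 A and B, the fp32 output path via MLAS_HALF_GEMM_2FLOAT_PROCESSOR, and the documented ldb == 0 convention for a pre-packed B):

    #include <vector>

    void HalfGemmSketch(const float* A, const float* B, float* C,
                        size_t M, size_t N, size_t K) {
        // Convert B to fp16 and pack it once; a size of 0 means unsupported.
        const size_t PackedSize = MlasHalfGemmPackBSize(N, K, /*float2half=*/true);
        std::vector<char> PackedB(PackedSize);
        MlasHalfGemmConvertPackB(N, K, B, /*ldb=*/N, PackedB.data());

        // fp16 scratch for the result; the postprocessor widens tiles into C.
        std::vector<MLAS_FP16> CHalf(M * N);
        MLAS_HALF_GEMM_2FLOAT_PROCESSOR ToFloat(C, /*RowStride=*/N);

        MLAS_HALF_GEMM_DATA_PARAMS Params;
        Params.A = A;
        Params.AIsfp32 = true;       // driver converts A to fp16
        Params.B = PackedB.data();
        Params.ldb = 0;              // 0 signals a pre-packed B
        Params.C = CHalf.data();
        Params.lda = K;
        Params.ldc = N;
        Params.OutputProcessor = &ToFloat;

        MlasHalfGemmBatch(M, N, K, /*BatchN=*/1, &Params, /*ThreadPool=*/nullptr);
    }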