12 changes: 6 additions & 6 deletions onnxruntime/core/optimizer/insert_cast_transformer.cc
@@ -221,12 +221,12 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
}

static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) {
for (auto& node : graph.Nodes()) {
if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) {
// unassign the node so that NeedInsertCast will return true for it, forcing it to fp32
node.SetExecutionProviderType("");
}
}
// for (auto& node : graph.Nodes()) {
// if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) {
// // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32
// node.SetExecutionProviderType("");
// }
// }
Comment on lines +224 to +229 (Contributor):
Remove, or block comment using /* ... */ and add a comment as to why we've commented it out.
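A minimal sketch of the block-comment form being suggested; the rationale text is a placeholder assumption, not wording from this PR:

/* Disabled while the fp16 CPU MatMul kernels are evaluated, so isolated fp16
   nodes are not forced back to fp32. Re-enable or delete once the fp16 path
   is settled. (placeholder rationale)
for (auto& node : graph.Nodes()) {
  if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) {
    // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32
    node.SetExecutionProviderType("");
  }
}
*/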


return Status::OK();
}
30 changes: 29 additions & 1 deletion onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -144,6 +144,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Aco
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Atan);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm);
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, MatMul);
#endif
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax);
@@ -344,6 +347,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm);
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, MatMul);
#endif
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul);
Expand Down Expand Up @@ -514,6 +520,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Sp
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm);
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, MatMul);
#endif
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift);
Expand Down Expand Up @@ -620,6 +629,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm);
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul);
#endif
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul);
@@ -2814,7 +2826,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {


BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kQuadricDomain, 1, QuadricCustomOp)>,

};

for (auto& function_table_entry : function_table) {
@@ -2827,6 +2839,13 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
return Status::OK();
}

#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
#pragma message("MLAS_F16VEC_INTRINSICS_SUPPORTED is defined")
#else
#pragma message("MLAS_F16VEC_INTRINSICS_SUPPORTED is NOT defined")
#endif


#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
@@ -2853,6 +2872,14 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
MLFloat16, LeakyRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, MLFloat16,
LeakyRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8,
MLFloat16, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
MLFloat16, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
MLFloat16, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16,
MatMul)>
};

for (auto& function_table_entry : function_table) {
@@ -3104,6 +3131,7 @@ Status RegisterCPUKernels(KernelRegistry& kernel_registry) {
ORT_RETURN_IF_ERROR(RegisterOnnxOperatorKernels(kernel_registry));
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
if (MlasFp16AccelerationSupported()) {
#pragma message("calling RegisterFp16Kernels")
ORT_RETURN_IF_ERROR(RegisterFp16Kernels(kernel_registry));
}
#endif
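One thing worth noting about the trace above: #pragma message is emitted when the translation unit is compiled, not when RegisterCPUKernels actually runs, so "calling RegisterFp16Kernels" shows up in the build log regardless of whether the branch is taken. If a runtime trace is the intent, a sketch (assuming the LOGS_DEFAULT macro from core/common/logging is available in this file):

#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
  if (MlasFp16AccelerationSupported()) {
    // Runtime trace instead of a compile-time pragma (assumed logging macro).
    LOGS_DEFAULT(INFO) << "Registering fp16 CPU kernels";
    ORT_RETURN_IF_ERROR(RegisterFp16Kernels(kernel_registry));
  }
#endif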
35 changes: 34 additions & 1 deletion onnxruntime/core/providers/cpu/math/matmul.cc
@@ -88,6 +88,34 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL(
.TypeConstraint("T", BuildKernelDefConstraints<int64_t, uint64_t>()),
MatMul<int64_t>);

ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
MatMul,
7, 8,
MLFloat16,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
MatMul<MLFloat16>);

ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
MatMul,
9, 10,
MLFloat16,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
MatMul<MLFloat16>);

ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
MatMul,
11, 12,
MLFloat16,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
MatMul<MLFloat16>);

ONNX_CPU_OPERATOR_TYPED_KERNEL(
MatMul,
13,
MLFloat16,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
MatMul<MLFloat16>);

template <typename T>
Status MatMul<T>::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
@@ -108,7 +136,12 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {
// be filled out with zeros.
EigenMatrixMapRowMajor<T> dest(y->MutableData<T>(),
narrow<Eigen::Index>(helper.M()), narrow<Eigen::Index>(helper.N()));
dest.setZero();
if constexpr (std::is_same<T, MLFloat16>::value) {
dest.setConstant(MLFloat16(0.0f));
} else {
dest.setZero();
}

return Status::OK();
}

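The registrations above cover opsets 7-8, 9-10, 11-12, and 13, matching the class declarations added in cpu_execution_provider.cc; the setConstant(MLFloat16(0.0f)) branch avoids having Eigen construct the zero from an integer literal for MLFloat16 (an assumption about the motivation, not stated in the PR). A minimal unit-test sketch that would exercise the opset-13 kernel, assuming the repo's OpTester helper and the MLFloat16(float) constructor:

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"  // path assumed

namespace onnxruntime {
namespace test {

TEST(MathOpTest, MatMulFloat16_Opset13) {
  OpTester test("MatMul", 13);
  std::vector<MLFloat16> a{MLFloat16(1.0f), MLFloat16(2.0f),
                           MLFloat16(3.0f), MLFloat16(4.0f)};
  // B is the 2x2 identity, so the expected output equals A.
  std::vector<MLFloat16> b{MLFloat16(1.0f), MLFloat16(0.0f),
                           MLFloat16(0.0f), MLFloat16(1.0f)};
  test.AddInput<MLFloat16>("A", {2, 2}, a);
  test.AddInput<MLFloat16>("B", {2, 2}, b);
  test.AddOutput<MLFloat16>("Y", {2, 2}, a);
  test.Run();
}

}  // namespace test
}  // namespace onnxruntime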
76 changes: 76 additions & 0 deletions onnxruntime/core/util/math_cpu.cc
@@ -50,6 +50,30 @@ EIGEN_MATMUL_FUNCTION(uint32_t)
EIGEN_MATMUL_FUNCTION(int64_t)
EIGEN_MATMUL_FUNCTION(uint64_t)




// template <>
// void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool* thread_pool) {
Contributor:
Same comment here.

// // // Convert MLFloat16* to Eigen::half* using reinterpret_cast
// // const Eigen::half* A_half = reinterpret_cast<const Eigen::half*>(A);
// // const Eigen::half* B_half = reinterpret_cast<const Eigen::half*>(B);
// // Eigen::half* C_half = reinterpret_cast<Eigen::half*>(C);

// // // Perform matrix multiplication using Eigen
// // auto C_mat = Eigen::Map<Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>(C_half, M, N);
// // C_mat.noalias() = Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>(A_half, M, K) *
// // Eigen::Map<const Eigen::Matrix<Eigen::half, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>(B_half, K, N);

// // Optionally, handle threading with thread_pool if needed (not shown here).

// math::Gemm<Eigen::half>(CblasNoTrans, CblasNoTrans, M, N, K, *reinterpret_cast<Eigen::half*>(&alpha),
// reinterpret_cast<const Eigen::half*>(A), reinterpret_cast<const Eigen::half*>(B), *reinterpret_cast<Eigen::half*>(&beta), reinterpret_cast<Eigen::half*>(C), thread_pool);
// }

// template void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, concurrency::ThreadPool*);


////////////////////////////////////////////////////////////////////////////////
// BLAS alternatives.
// Depending on whether we have specified an external BLAS library or not, we
@@ -185,6 +209,58 @@ void MatMul<float>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const float* A, const
MlasGemm(CblasNoTrans, CblasNoTrans, M, N, K, 1.f, A, K, B, N, 0.f, C, N, threadpool);
}


template <>
void MatMul<MLFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* a_data, const MLFloat16* b_data, MLFloat16* y_data, concurrency::ThreadPool* thread_pool) {
  MLFloat16 alpha = MLFloat16(1.0f);
  MLFloat16 beta = MLFloat16(0.0f);
  // if input is empty tensor, return directly as nothing need to be calculated.
  if (M == 0 || N == 0)
    return;

#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#endif
  memset(&beta, 0, sizeof(MLFloat16));
#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
#pragma GCC diagnostic pop
#endif

#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
  MLAS_HALF_GEMM_DATA_PARAMS data;
  data.A = a_data;
  data.lda = K;
  data.B = b_data;
  data.ldb = N;
  data.C = y_data;
  data.ldc = N;
  // if (c_shape != nullptr) {
  //   data.Bias = c_data;
  // }
  MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool);
  return;
#endif

  // Fallback to Eigen
  // // Broadcast the bias as needed if bias is given
  // GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);
#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
  math::Gemm<Eigen::half>(CblasNoTrans, CblasNoTrans, M, N, K, *reinterpret_cast<Eigen::half*>(&alpha),
                          reinterpret_cast<const Eigen::half*>(a_data), reinterpret_cast<const Eigen::half*>(b_data),
                          *reinterpret_cast<Eigen::half*>(&beta), reinterpret_cast<Eigen::half*>(y_data), thread_pool);
#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
}
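For reference, a minimal call sketch for the specialization above; the shapes, header paths, and the null thread pool (single-threaded execution) are illustrative assumptions rather than part of the change:

#include <vector>
#include "core/framework/float16.h"  // path assumed for MLFloat16
#include "core/util/math.h"          // path assumed for math::MatMul

void RunFp16MatMulExample() {
  const ptrdiff_t M = 2, N = 3, K = 4;
  std::vector<onnxruntime::MLFloat16> a(M * K, onnxruntime::MLFloat16(1.0f));
  std::vector<onnxruntime::MLFloat16> b(K * N, onnxruntime::MLFloat16(0.5f));
  std::vector<onnxruntime::MLFloat16> y(M * N);
  // nullptr thread pool runs the GEMM single-threaded.
  onnxruntime::math::MatMul<onnxruntime::MLFloat16>(M, N, K, a.data(), b.data(), y.data(), nullptr);
  // Each element of y should be 0.5f * K = 2.0f, rounded to half precision.
}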


#ifdef MLAS_SUPPORTS_GEMM_DOUBLE
template <>
void MatMul<double>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const double* A, const double* B, double* C, ThreadPool* threadpool) {