diff --git a/include/onnxruntime/core/common/spin_pause.h b/include/onnxruntime/core/common/spin_pause.h index 49b71e5567d3e..163a99d598ac6 100644 --- a/include/onnxruntime/core/common/spin_pause.h +++ b/include/onnxruntime/core/common/spin_pause.h @@ -11,6 +11,10 @@ #include #endif +#if defined(_M_AMD64) || defined(__x86_64__) +#include +#endif + namespace onnxruntime { namespace concurrency { @@ -23,6 +27,13 @@ inline void SpinPause() { #endif } +inline void SpinTPAUSE() { +#if defined(_M_AMD64) || defined(__x86_64__) + const std::uint64_t spin_delay_cycles = 2000; + _tpause(0x0, __rdtsc() + spin_delay_cycles); +#endif +} + } // namespace concurrency } // namespace onnxruntime diff --git a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h index f9b694efb936f..a5be5711f7837 100644 --- a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h +++ b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h @@ -54,6 +54,10 @@ #include "core/platform/ort_spin_lock.h" #include "core/platform/Barrier.h" +#if defined(CPUINFO_SUPPORTED) && !defined(__APPLE__) && !defined(__ANDROID__) && !defined(__wasm__) && !defined(_AIX) +#include +#endif + // ORT thread pool overview // ------------------------ // @@ -1535,6 +1539,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter constexpr int log2_spin = 20; const int spin_count = allow_spinning_ ? (1ull << log2_spin) : 0; const int steal_count = spin_count / 100; + const bool tpause = CPUIDInfo::GetCPUIDInfo().HasTPAUSE(); SetDenormalAsZero(set_denormal_as_zero_); profiler_.LogThreadId(thread_id); @@ -1554,8 +1559,14 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter if (spin_loop_status_.load(std::memory_order_relaxed) == SpinLoopStatus::kIdle) { break; } - onnxruntime::concurrency::SpinPause(); - } + + if (tpause) { + onnxruntime::concurrency::SpinTPAUSE(); + } + else { + onnxruntime::concurrency::SpinPause(); + } + } // Attempt to block if (!t) { diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index be881f6bc4bc2..4b8bf14ecd53d 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -42,6 +42,7 @@ #include "Windows.h" + #define HAS_WINDOWS_DESKTOP WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) #ifndef PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE @@ -66,7 +67,7 @@ void decodeMIDR(uint32_t midr, uint32_t uarch[1]); namespace onnxruntime { #ifdef CPUIDINFO_ARCH_X86 - +#include #include #if defined(_MSC_VER) #include @@ -102,6 +103,15 @@ static inline int XGETBV() { #endif } +int filter(uint32_t code) { + if(code == STATUS_ILLEGAL_INSTRUCTION || code == STATUS_PRIVILEGED_INSTRUCTION) { + return EXCEPTION_EXECUTE_HANDLER; + } + else { + return EXCEPTION_CONTINUE_SEARCH; + } +} + void CPUIDInfo::X86Init() { int data[4] = {-1}; GetCPUID(0, data); @@ -131,6 +141,19 @@ void CPUIDInfo::X86Init() { // avx512_skylake = avx512f | avx512vl | avx512cd | avx512bw | avx512dq has_avx512_skylake_ = has_avx512 && (data[1] & ((1 << 16) | (1 << 17) | (1 << 28) | (1 << 30) | (1 << 31))); is_hybrid_ = (data[3] & (1 << 15)); + + // Check WAITPKG support + if((data[2] & (1 << 5))) { + // Some CPUs report TPAUSE support incorrectly, so a test is needed. + __try { + _tpause(0x0, __rdtsc() + 1000); + has_tpause_ = true; + } + __except(filter(GetExceptionCode())) { + has_tpause_ = false; + } + } + if (max_SubLeaves >= 1) { GetCPUID(7, 1, data); has_avx512_bf16_ = has_avx512 && (data[0] & (1 << 5)); diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index a3936b4bd11a6..7c7c86899b9c9 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -25,6 +25,7 @@ class CPUIDInfo { bool HasSSE3() const { return has_sse3_; } bool HasSSE4_1() const { return has_sse4_1_; } bool IsHybrid() const { return is_hybrid_; } + bool HasTPAUSE() const { return has_tpause_; } // ARM bool HasArmNeonDot() const { return has_arm_neon_dot_; } @@ -104,6 +105,7 @@ class CPUIDInfo { bool has_sse3_{false}; bool has_sse4_1_{false}; bool is_hybrid_{false}; + bool has_tpause_{false}; std::vector core_uarchs_; // micro-arch of each core