Skip to content

Commit

Permalink
TPAUSE implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
kleiti committed Apr 18, 2024
1 parent da86f6f commit be6d515
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 3 deletions.
11 changes: 11 additions & 0 deletions include/onnxruntime/core/common/spin_pause.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
#include <xmmintrin.h>
#endif

#if defined(_M_AMD64) || defined(__x86_64__)
#include <cstdint>
#endif

namespace onnxruntime {

namespace concurrency {
Expand All @@ -23,6 +27,13 @@ inline void SpinPause() {
#endif
}

inline void SpinTPAUSE() {
#if defined(_M_AMD64) || defined(__x86_64__)
const std::uint64_t spin_delay_cycles = 2000;
_tpause(0x0, __rdtsc() + spin_delay_cycles);
#endif
}

} // namespace concurrency

} // namespace onnxruntime
15 changes: 13 additions & 2 deletions include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@
#include "core/platform/ort_spin_lock.h"
#include "core/platform/Barrier.h"

#if defined(CPUINFO_SUPPORTED) && !defined(__APPLE__) && !defined(__ANDROID__) && !defined(__wasm__) && !defined(_AIX)
#include <cpuinfo.h>
#endif

// ORT thread pool overview
// ------------------------
//
Expand Down Expand Up @@ -1535,6 +1539,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
constexpr int log2_spin = 20;
const int spin_count = allow_spinning_ ? (1ull << log2_spin) : 0;
const int steal_count = spin_count / 100;
const bool tpause = CPUIDInfo::GetCPUIDInfo().HasTPAUSE();

SetDenormalAsZero(set_denormal_as_zero_);
profiler_.LogThreadId(thread_id);
Expand All @@ -1554,8 +1559,14 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
if (spin_loop_status_.load(std::memory_order_relaxed) == SpinLoopStatus::kIdle) {
break;
}
onnxruntime::concurrency::SpinPause();
}

if (tpause) {
onnxruntime::concurrency::SpinTPAUSE();
}
else {
onnxruntime::concurrency::SpinPause();
}
}

// Attempt to block
if (!t) {
Expand Down
25 changes: 24 additions & 1 deletion onnxruntime/core/common/cpuid_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

#include "Windows.h"


#define HAS_WINDOWS_DESKTOP WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)

#ifndef PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE
Expand All @@ -66,7 +67,7 @@ void decodeMIDR(uint32_t midr, uint32_t uarch[1]);
namespace onnxruntime {

#ifdef CPUIDINFO_ARCH_X86

#include <excpt.h>
#include <memory>
#if defined(_MSC_VER)
#include <intrin.h>
Expand Down Expand Up @@ -102,6 +103,15 @@ static inline int XGETBV() {
#endif
}

int filter(uint32_t code) {
if(code == STATUS_ILLEGAL_INSTRUCTION || code == STATUS_PRIVILEGED_INSTRUCTION) {
return EXCEPTION_EXECUTE_HANDLER;
}
else {
return EXCEPTION_CONTINUE_SEARCH;
}
}

void CPUIDInfo::X86Init() {
int data[4] = {-1};
GetCPUID(0, data);
Expand Down Expand Up @@ -131,6 +141,19 @@ void CPUIDInfo::X86Init() {
// avx512_skylake = avx512f | avx512vl | avx512cd | avx512bw | avx512dq
has_avx512_skylake_ = has_avx512 && (data[1] & ((1 << 16) | (1 << 17) | (1 << 28) | (1 << 30) | (1 << 31)));
is_hybrid_ = (data[3] & (1 << 15));

// Check WAITPKG support
if((data[2] & (1 << 5))) {
// Some CPUs report TPAUSE support incorrectly, so a test is needed.
__try {
_tpause(0x0, __rdtsc() + 1000);
has_tpause_ = true;
}
__except(filter(GetExceptionCode())) {
has_tpause_ = false;
}
}

if (max_SubLeaves >= 1) {
GetCPUID(7, 1, data);
has_avx512_bf16_ = has_avx512 && (data[0] & (1 << 5));
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/common/cpuid_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class CPUIDInfo {
bool HasSSE3() const { return has_sse3_; }
bool HasSSE4_1() const { return has_sse4_1_; }
bool IsHybrid() const { return is_hybrid_; }
bool HasTPAUSE() const { return has_tpause_; }

// ARM
bool HasArmNeonDot() const { return has_arm_neon_dot_; }
Expand Down Expand Up @@ -104,6 +105,7 @@ class CPUIDInfo {
bool has_sse3_{false};
bool has_sse4_1_{false};
bool is_hybrid_{false};
bool has_tpause_{false};

std::vector<uint32_t> core_uarchs_; // micro-arch of each core

Expand Down

0 comments on commit be6d515

Please sign in to comment.