Skip to content

Commit

Permalink
All LearningModelSessions created from a common LearningModelDevice should share the same thread pool (#11457)
Browse files Browse the repository at this point in the history

* Share thread pools between devices

* make tests reuse device

* Change cpu thread pool options for dml sessions to use 1 thread with no spinning

* fix test failure

* Update missing type constraints for dft

* Add comment and rename inference session parameter

* Fix a missing default value that caused inconsistent test behavior

Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
  • Loading branch information
smk2007 and Sheil Kumar authored May 13, 2022
1 parent 5709ed2 commit 6255194
Show file tree
Hide file tree
Showing 26 changed files with 422 additions and 98 deletions.
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/cpu/signal/dft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@ ONNX_OPERATOR_KERNEL_EX(
kMSExperimentalDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraints<float, double>()),
KernelDefBuilder().TypeConstraint("T1", BuildKernelDefConstraints<float, double>())
.TypeConstraint("T2", BuildKernelDefConstraints<int64_t>()),
DFT);

// Registers the experimental IDFT kernel for the CPU execution provider.
// The pasted diff retained the stale single-constraint ("T") KernelDefBuilder
// line next to the new T1/T2 form, which is not valid C++; only the updated
// two-constraint registration is kept.
ONNX_OPERATOR_KERNEL_EX(
    IDFT,
    kMSExperimentalDomain,
    1,
    kCpuExecutionProvider,
    KernelDefBuilder().TypeConstraint("T1", BuildKernelDefConstraints<float, double>())
        .TypeConstraint("T2", BuildKernelDefConstraints<int64_t>()),
    IDFT);

ONNX_OPERATOR_KERNEL_EX(
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/signal_ops/signal_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ void RegisterSignalSchemas() {
"If axis=N-1 and onesided is 1, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[floor(signal_dimN/2)+1][2].",
"T1")
// The stripped diff left the old `"T1", ` line (trailing-whitespace fix)
// alongside the new one, producing a duplicated argument; TypeConstraint
// takes exactly (name, allowed-types, description).
.TypeConstraint(
    "T1",
    {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
    "Constrain input and output types to float tensors.")
.TypeConstraint(
Expand Down
146 changes: 85 additions & 61 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,69 +268,75 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
if (use_per_session_threads_) {
LOGS(*session_logger_, INFO) << "Creating and using per session threadpools since use_per_session_threads_ is true";
{
bool allow_intra_op_spinning =
session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowIntraOpSpinning, "1") == "1";
OrtThreadPoolParams to = session_options_.intra_op_param;
std::basic_stringstream<ORTCHAR_T> ss;
if (to.name) {
ss << to.name << ORT_TSTR("-");
}
ss << ORT_TSTR("session-") << session_id_ << ORT_TSTR("-intra-op");
thread_pool_name_ = ss.str();
to.name = thread_pool_name_.c_str();
to.set_denormal_as_zero = set_denormal_as_zero;
// If the thread pool can use all the processors, then
// we set affinity of each thread to each processor.
to.auto_set_affinity = to.thread_pool_size == 0 &&
session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL &&
to.affinity_vec_len == 0;
to.allow_spinning = allow_intra_op_spinning;
to.dynamic_block_base_ = std::stoi(session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigDynamicBlockBase, "0"));
LOGS(*session_logger_, INFO) << "Dynamic block base set to " << to.dynamic_block_base_;

// Set custom threading functions
to.custom_create_thread_fn = session_options_.custom_create_thread_fn;
to.custom_thread_creation_options = session_options.custom_thread_creation_options;
to.custom_join_thread_fn = session_options_.custom_join_thread_fn;

if (to.custom_create_thread_fn) {
ORT_ENFORCE(to.custom_join_thread_fn, "custom join thread function not set for intra op thread pool");
if (!external_intra_op_thread_pool_)
{
bool allow_intra_op_spinning =
session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowIntraOpSpinning, "1") == "1";
OrtThreadPoolParams to = session_options_.intra_op_param;
std::basic_stringstream<ORTCHAR_T> ss;
if (to.name) {
ss << to.name << ORT_TSTR("-");
}
ss << ORT_TSTR("session-") << session_id_ << ORT_TSTR("-intra-op");
thread_pool_name_ = ss.str();
to.name = thread_pool_name_.c_str();
to.set_denormal_as_zero = set_denormal_as_zero;
// If the thread pool can use all the processors, then
// we set affinity of each thread to each processor.
to.auto_set_affinity = to.thread_pool_size == 0 &&
session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL &&
to.affinity_vec_len == 0;
to.allow_spinning = allow_intra_op_spinning;
to.dynamic_block_base_ = std::stoi(session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigDynamicBlockBase, "0"));
LOGS(*session_logger_, INFO) << "Dynamic block base set to " << to.dynamic_block_base_;

// Set custom threading functions
to.custom_create_thread_fn = session_options_.custom_create_thread_fn;
to.custom_thread_creation_options = session_options.custom_thread_creation_options;
to.custom_join_thread_fn = session_options_.custom_join_thread_fn;

if (to.custom_create_thread_fn) {
ORT_ENFORCE(to.custom_join_thread_fn, "custom join thread function not set for intra op thread pool");
}
thread_pool_ =
concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP);
}
thread_pool_ =
concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP);
}
if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL) {
bool allow_inter_op_spinning =
session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowInterOpSpinning, "1") == "1";
OrtThreadPoolParams to = session_options_.inter_op_param;
// If the thread pool can use all the processors, then
// we set thread affinity.
to.auto_set_affinity =
to.thread_pool_size == 0 && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL;
std::basic_stringstream<ORTCHAR_T> ss;
if (to.name) {
ss << to.name << ORT_TSTR("-");
}
ss << ORT_TSTR("session-") << session_id_ << ORT_TSTR("-inter-op");
inter_thread_pool_name_ = ss.str();
to.name = inter_thread_pool_name_.c_str();
to.set_denormal_as_zero = set_denormal_as_zero;
to.allow_spinning = allow_inter_op_spinning;
to.dynamic_block_base_ = std::stoi(session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigDynamicBlockBase, "0"));

// Set custom threading functions
to.custom_create_thread_fn = session_options_.custom_create_thread_fn;
to.custom_thread_creation_options = session_options.custom_thread_creation_options;
to.custom_join_thread_fn = session_options_.custom_join_thread_fn;

if (to.custom_create_thread_fn) {
ORT_ENFORCE(to.custom_join_thread_fn, "custom join thread function not set for inter op thread pool");
}
inter_op_thread_pool_ =
concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTER_OP);
if (inter_op_thread_pool_ == nullptr) {
LOGS(*session_logger_, INFO) << "Failed to create the inter-op thread pool for the parallel executor, setting ExecutionMode to SEQUENTIAL";
session_options_.execution_mode = ExecutionMode::ORT_SEQUENTIAL;
if (!external_inter_op_thread_pool_)
{
bool allow_inter_op_spinning =
session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowInterOpSpinning, "1") == "1";
OrtThreadPoolParams to = session_options_.inter_op_param;
// If the thread pool can use all the processors, then
// we set thread affinity.
to.auto_set_affinity =
to.thread_pool_size == 0 && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL;
std::basic_stringstream<ORTCHAR_T> ss;
if (to.name) {
ss << to.name << ORT_TSTR("-");
}
ss << ORT_TSTR("session-") << session_id_ << ORT_TSTR("-inter-op");
inter_thread_pool_name_ = ss.str();
to.name = inter_thread_pool_name_.c_str();
to.set_denormal_as_zero = set_denormal_as_zero;
to.allow_spinning = allow_inter_op_spinning;
to.dynamic_block_base_ = std::stoi(session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigDynamicBlockBase, "0"));

// Set custom threading functions
to.custom_create_thread_fn = session_options_.custom_create_thread_fn;
to.custom_thread_creation_options = session_options.custom_thread_creation_options;
to.custom_join_thread_fn = session_options_.custom_join_thread_fn;

if (to.custom_create_thread_fn) {
ORT_ENFORCE(to.custom_join_thread_fn, "custom join thread function not set for inter op thread pool");
}
inter_op_thread_pool_ =
concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTER_OP);
if (inter_op_thread_pool_ == nullptr) {
LOGS(*session_logger_, INFO) << "Failed to create the inter-op thread pool for the parallel executor, setting ExecutionMode to SEQUENTIAL";
session_options_.execution_mode = ExecutionMode::ORT_SEQUENTIAL;
}
}
}
} else {
Expand Down Expand Up @@ -363,6 +369,24 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const
ConstructorCommon(session_options, session_env);
}

// Constructor overload that lets an embedder (WinML) supply its own intra-op and
// inter-op thread pools. The pointers are stored as non-owning members so that
// ConstructorCommon skips creating per-session pools when they are set (see the
// `!external_intra_op_thread_pool_` / `!external_inter_op_thread_pool_` checks
// there). Lifetime of the supplied pools is the caller's responsibility —
// presumably they must outlive the session; verify against the WinML adapter.
InferenceSession::InferenceSession(const SessionOptions& session_options,
const Environment& session_env,
onnxruntime::concurrency::ThreadPool* external_intra_op_thread_pool,
onnxruntime::concurrency::ThreadPool* external_inter_op_thread_pool)
:
#if !defined(ORT_MINIMAL_BUILD)
graph_transformation_mgr_(session_options.max_num_graph_transformation_steps),
insert_cast_transformer_("CastFloat16Transformer"),
#endif
logging_manager_(session_env.GetLoggingManager()),
external_intra_op_thread_pool_(external_intra_op_thread_pool),
external_inter_op_thread_pool_(external_inter_op_thread_pool),
environment_(session_env)
{
// Initialize assets of this session instance
ConstructorCommon(session_options, session_env);
}

#if !defined(ORT_MINIMAL_BUILD)
InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env,
const std::string& model_uri)
Expand Down Expand Up @@ -1153,7 +1177,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph,
std::unordered_map<std::string, HashValue> compiled_kernel_hashes;

GraphPartitioner partitioner(kernel_registry_manager, providers);
ORT_RETURN_IF_ERROR(partitioner.Partition(graph,
ORT_RETURN_IF_ERROR(partitioner.Partition(graph,
session_state.GetMutableFuncMgr(),
layout_transformer::TransformLayoutForCompilingEP,
GraphPartitioner::Mode::kOrtFormatLoad,
Expand Down
38 changes: 35 additions & 3 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,20 @@ class InferenceSession {
explicit InferenceSession(const SessionOptions& session_options,
const Environment& session_env);

/**
Create a new InferenceSession that accepts thread pools for intra and inter op thread execution.
Used by WinML only!
@param session_options Session options.
@param session_env This represents the context for the session and contains the logger and the global threadpools.
@param external_intra_op_thread_pool This represents the intra op threadpool.
@param external_inter_op_thread_pool This represents the inter op threadpool.
*/
explicit InferenceSession(const SessionOptions& session_options,
                          const Environment& session_env,
                          onnxruntime::concurrency::ThreadPool* external_intra_op_thread_pool,
                          onnxruntime::concurrency::ThreadPool* external_inter_op_thread_pool);

#if !defined(ORT_MINIMAL_BUILD)
/**
Create a new InferenceSession
@param session_options Session options.
Expand Down Expand Up @@ -489,11 +501,27 @@ class InferenceSession {
// specific flags in session options
// These methods assume that session options have been finalized before the call.
// Resolves the intra-op pool this session should run on: the global (env)
// pool unless per-session threads are enabled, in which case an externally
// supplied pool takes precedence over the session-owned one.
onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPoolToUse() const {
  if (!session_options_.use_per_session_threads) {
    return intra_op_thread_pool_from_env_;
  }
  return external_intra_op_thread_pool_ ? external_intra_op_thread_pool_
                                        : thread_pool_.get();
}

// Resolves the inter-op pool this session should run on: the global (env)
// pool unless per-session threads are enabled, in which case an externally
// supplied pool takes precedence over the session-owned one.
onnxruntime::concurrency::ThreadPool* GetInterOpThreadPoolToUse() const {
  if (!session_options_.use_per_session_threads) {
    return inter_op_thread_pool_from_env_;
  }
  return external_inter_op_thread_pool_ ? external_inter_op_thread_pool_
                                        : inter_op_thread_pool_.get();
}

/// convenience pointer to logger. should always be the same as session_state_.Logger();
Expand Down Expand Up @@ -649,6 +677,10 @@ class InferenceSession {
onnxruntime::concurrency::ThreadPool* intra_op_thread_pool_from_env_{};
onnxruntime::concurrency::ThreadPool* inter_op_thread_pool_from_env_{};

// External threadpools supplied at construction (non-owning; nullptr when the
// session creates its own). When set, GetIntraOpThreadPoolToUse /
// GetInterOpThreadPoolToUse return these instead of the per-session pools.
onnxruntime::concurrency::ThreadPool* external_intra_op_thread_pool_{};
onnxruntime::concurrency::ThreadPool* external_inter_op_thread_pool_{};

// initialized from session options
// Determines which threadpools will be intialized and used for the duration of this session.
// If true, use the per session ones, or else the global threadpools.
Expand Down
9 changes: 8 additions & 1 deletion winml/adapter/winml_adapter_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace AI {
namespace MachineLearning {
namespace Adapter {

ORT_API(void, ReleaseThreadPool, OrtThreadPool*);  // releases a pool created by CreateThreadPool (declared below)
ORT_API(void, ReleaseModel, OrtModel*);
ORT_API(void, ReleaseExecutionProvider, OrtExecutionProvider*);

Expand Down Expand Up @@ -44,7 +45,7 @@ ORT_API_STATUS(SaveModel, _In_ const OrtModel* in, _In_ const wchar_t* const fil
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options, _In_ ID3D12Device* d3d_device, _In_ ID3D12CommandQueue* cmd_queue, bool metacommands_enabled);

// OrtSession methods
// Creates a session without parsing a model protobuf; the caller supplies the
// (possibly shared) inter/intra op thread pools the session should use.
// The pasted diff retained the old 3-parameter declaration next to the new
// 5-parameter one — only the updated signature is valid.
ORT_API_STATUS(CreateSessionWithoutModel, _In_ OrtEnv* env, _In_ const OrtSessionOptions* options, _In_ OrtThreadPool* inter_op_thread_pool, _In_ OrtThreadPool* intra_op_thread_pool, _Outptr_ OrtSession** session);

//Do not release provider... as there is no release method available
ORT_API_STATUS(SessionGetExecutionProvider, _In_ OrtSession* session, _In_ size_t index, _Out_ OrtExecutionProvider** provider);
Expand Down Expand Up @@ -133,6 +134,12 @@ ORT_API_STATUS(JoinModels,
size_t num_linkages,
bool promote_unlinked_outputs,
_In_ const char* const join_node_prefix);

// Creates a standalone OrtThreadPool of the requested flavor (intra- or
// inter-op, per `type`) configured from `params`; the caller frees it with
// ReleaseThreadPool (declared above).
ORT_API_STATUS(CreateThreadPool,
ThreadPoolType type,
OrtThreadPoolOptions* params,
_Outptr_ OrtThreadPool** out);

// maps and sequences???
//ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange().Map().at(ONNX_NAMESPACE::ONNX_DOMAIN).second

Expand Down
4 changes: 3 additions & 1 deletion winml/adapter/winml_adapter_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,11 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = {
&winmla::OperatorGetNumOutputs,
&winmla::OperatorGetOutputName,
&winmla::JoinModels,
&winmla::CreateThreadPool,

// Release
&winmla::ReleaseModel
&winmla::ReleaseModel,
&winmla::ReleaseThreadPool,
};

const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ uint32_t ort_version) NO_EXCEPTION {
Expand Down
41 changes: 40 additions & 1 deletion winml/adapter/winml_adapter_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

ORT_RUNTIME_CLASS(Model);
ORT_RUNTIME_CLASS(ExecutionProvider);
ORT_RUNTIME_CLASS(ThreadPool);

struct WinmlAdapterApi;
typedef struct WinmlAdapterApi WinmlAdapterApi;
Expand Down Expand Up @@ -44,6 +45,37 @@ struct OrtProfilerEventRecord {

typedef void(ORT_API_CALL* OrtProfilingFunction)(const OrtProfilerEventRecord* event_record);

// Kind of pool CreateThreadPool builds; mirrors the
// onnxruntime::concurrency::ThreadPoolType values passed to
// concurrency::CreateThreadPool in inference_session.cc.
enum class ThreadPoolType : uint8_t {
INTRA_OP,
INTER_OP
};

// C-ABI thread pool configuration passed to CreateThreadPool.
// NOTE(review): field names track onnxruntime's OrtThreadPoolParams
// (thread_pool_size, auto_set_affinity, allow_spinning, dynamic_block_base_,
// affinity_vec_len, name, set_denormal_as_zero are all consumed in
// inference_session.cc) — presumably the two must stay in sync; verify.
struct OrtThreadPoolOptions {
//0: Use default setting. (All the physical cores or half of the logical cores)
//1: Don't create thread pool
//n: Create a thread pool with n threads.
int thread_pool_size = 0;
//If it is true and thread_pool_size = 0, populate the thread affinity information in ThreadOptions.
//Otherwise if the thread_options has affinity information, we'll use it and set it.
//In the other case, don't set affinity
bool auto_set_affinity = false;
//If it is true, the thread pool will spin for a while after the queue becomes empty.
bool allow_spinning = true;
//If it is non-negative, the thread pool will split a task by a decreasing block size
//of remaining_of_total_iterations / (num_of_threads * dynamic_block_base_)
int dynamic_block_base_ = 0;

// Per-thread stack size in bytes; 0 means the platform default.
unsigned int stack_size = 0;
//Index is thread id, value is processor ID
//If the vector is empty, no explicit affinity binding
size_t* affinity_vec = nullptr;
size_t affinity_vec_len = 0;
// Optional pool name used for thread naming/diagnostics.
const ORTCHAR_T* name = nullptr;

// Set or unset denormal as zero
bool set_denormal_as_zero = false;
};

struct WinmlAdapterApi {
/**
* OverrideSchema
Expand Down Expand Up @@ -237,7 +269,8 @@ struct WinmlAdapterApi {
* c-abi, WinML uses this so that it can perform optimizations prior to loading the model, and initializing.
* Moreover, WinML needs a new api to support the OrtModel type, and prevent the parsing model protobufs again on session creation.
*/
// The stripped diff retained the old 3-parameter function-pointer member next
// to the new one, duplicating the member name; only the updated signature
// (matching winml_adapter_apis.h) is kept.
OrtStatus*(ORT_API_CALL* CreateSessionWithoutModel)(_In_ OrtEnv* env, _In_ const OrtSessionOptions* options,
    _In_ OrtThreadPool* inter_op_thread_pool, _In_ OrtThreadPool* intra_op_thread_pool, _Outptr_ OrtSession** session)NO_EXCEPTION;

/**
* SessionGetExecutionProvider
Expand Down Expand Up @@ -471,5 +504,11 @@ struct WinmlAdapterApi {
bool promote_unlinked_outputs,
_In_ const char* const join_node_prefix)NO_EXCEPTION;

/**
 * CreateThreadPool
 * Builds an OrtThreadPool of the given flavor (`type`) from `params`;
 * released via the ThreadPool release member of this API table.
 */
OrtStatus*(ORT_API_CALL* CreateThreadPool)(
_In_ ThreadPoolType type,
_In_ OrtThreadPoolOptions* params,
_Outptr_ OrtThreadPool** out)NO_EXCEPTION;

ORT_CLASS_RELEASE(Model);
ORT_CLASS_RELEASE(ThreadPool);
};
Loading

0 comments on commit 6255194

Please sign in to comment.