Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Fence checking performance #1593

Merged
merged 2 commits
Aug 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions onnxruntime/core/framework/allocation_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ class PlannerImpl {
// Initialize execution plan:
plan_.execution_plan.reserve(num_graph_nodes);

// Initialize node_has_fence.
plan_.node_has_fence.resize(graph_viewer_.MaxNodeIndex());

// Initialize allocation plan:
plan_.allocation_plan.resize(num_ml_values);
}
Expand Down Expand Up @@ -585,6 +588,51 @@ class PlannerImpl {
return Status::OK();
}

// Whether a given NodeArg has fence or not.
// If the buffer is reused, need to check whether original OrtValue has fence or not.
bool HasFence(const onnxruntime::NodeArg* arg) {
  // A missing or non-existent arg can never carry a fence.
  if (!arg || !arg->Exists()) {
    return false;
  }

  const OrtValueIndex value_index = Index(arg->Name());
  const AllocPlanPerValue& value_plan = AllocPlan(value_index);

  // The value may own a fence directly.
  if (value_plan.create_fence_if_async) {
    return true;
  }

  // A reused buffer shares the fence of the OrtValue it aliases.
  return value_plan.alloc_kind == AllocKind::kReuse &&
         AllocPlan(value_plan.reused_buffer).create_fence_if_async;
}

// Compute fence check. Set has_fence flag if either one of inputs, implicit inputs
// or outputs of a given node has fence. Precomputing this per node lets executors
// skip the per-value fence scan entirely for fence-free nodes.
Status ComputeFenceCheck() {
  // Scan a list of NodeArgs, stopping at the first one that has a fence.
  // (The original code called HasFence on every def even after a fence was
  // already found; short-circuiting avoids those redundant lookups.)
  auto any_has_fence = [this](const auto& defs) {
    for (auto def : defs) {
      if (HasFence(def)) return true;
    }
    return false;
  };

  for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) {
    auto pnode = graph_viewer_.GetNode(step.node_index);
    if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index);

    // || short-circuits, so implicit inputs/outputs are only scanned when no
    // fence has been found yet.
    plan_.node_has_fence[step.node_index] = any_has_fence(pnode->InputDefs()) ||
                                            any_has_fence(pnode->ImplicitInputDefs()) ||
                                            any_has_fence(pnode->OutputDefs());
  }

  return Status::OK();
}

// Convert information in a freelist (about which ml-value becomes free when) into
// a deallocation plan in the format required in an ExecutionPlan
void GenerateDeallocationPlan() {
Expand Down Expand Up @@ -642,6 +690,9 @@ Status PlannerImpl::CreatePlan() {
// determine sharing/reuse among ml-values
ORT_RETURN_IF_ERROR(ComputeReusePlan());

// Determine nodes that need fence check. This needs to be done after ComputeUseCounts and ComputeReusePlan.
ORT_RETURN_IF_ERROR(ComputeFenceCheck());

// convert information in the freelist_ into a deallocation plan in required format
GenerateDeallocationPlan();

Expand Down
76 changes: 41 additions & 35 deletions onnxruntime/core/framework/parallel_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ Status ParallelExecutor::RunNodeAsync(size_t p_node_index,
TimePoint sync_time_begin;
TimePoint kernel_begin_time;
const bool f_profiler_enabled = session_state.Profiler().IsEnabled();
const SequentialExecutionPlan& exec_plan = *session_state.GetExecutionPlan();

// Avoid context switching if possible.
while (keep_running) {
Expand Down Expand Up @@ -149,33 +150,34 @@ Status ParallelExecutor::RunNodeAsync(size_t p_node_index,
}
// sync before compute
int queue_id = p_op_kernel->KernelDef().ExecQueueId();

for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
if (exec_plan.NodeHasFence(node_index)) {
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
}

for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
}

for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id);
for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id);
}
}
}

Expand Down Expand Up @@ -209,32 +211,36 @@ Status ParallelExecutor::RunNodeAsync(size_t p_node_index,
sync_time_begin = session_state.Profiler().StartTime();
}
// sync after compute for outputs
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
if (exec_plan.NodeHasFence(node_index)) {
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
}
}
}

for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
}
}
}

for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->AfterUsedAsOutput(queue_id);
for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->AfterUsedAsOutput(queue_id);
}
}
}

if (f_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
p_op_kernel->Node().Name() + "_fence_after",
sync_time_begin,
{{"op_name", p_op_kernel->KernelDef().OpName()}});
}

//std::cout << "Run async node finish: " << p_node_index << std::endl;

keep_running = false;
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/core/framework/sequential_execution_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ struct SequentialExecutionPlan : public ExecutionPlanBase {
// Execution_plan: represents the nodes in the sequential order to be executed
std::vector<NodeExecutionPlan> execution_plan;

// Records whether a given node has fence on its input or output, key is node index.
std::vector<bool> node_has_fence;

// to_be_freed: vector elements represent indices of ml-values to be freed (as described above)
std::vector<OrtValueIndex> to_be_freed;

Expand All @@ -84,6 +87,12 @@ struct SequentialExecutionPlan : public ExecutionPlanBase {
}
return locations;
}

// Whether a given node needs fence check or not.
// node_has_fence is sized from GraphViewer::MaxNodeIndex() at plan-build time;
// guard the index so an out-of-range node (e.g. a plan queried before fence
// computation ran) reads as "no fence" instead of invoking undefined behavior.
bool NodeHasFence(onnxruntime::NodeIndex node_index) const {
  return node_index < node_has_fence.size() && node_has_fence[node_index];
}

};

// Output details of an execution plan:
Expand Down
72 changes: 38 additions & 34 deletions onnxruntime/core/framework/sequential_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,32 +71,34 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:

// sync before compute
int queue_id = p_op_kernel->KernelDef().ExecQueueId();
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
if (seq_exec_plan.NodeHasFence(node_index)) {
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
}

for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
}

for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id);
for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id);
}
}
}

Expand Down Expand Up @@ -138,24 +140,26 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
}

// sync after compute for outputs
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
if (seq_exec_plan.NodeHasFence(node_index)) {
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
}
}
}

for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
}
}
}

for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->AfterUsedAsOutput(queue_id);
for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->AfterUsedAsOutput(queue_id);
}
}
}

Expand Down