
Optimize CPU time spent in inference path (continued) #695

Merged: 20 commits, Jun 19, 2025
38 changes: 17 additions & 21 deletions onnxruntime/core/providers/openvino/backend_manager.cc
@@ -44,10 +44,6 @@ BackendManager::BackendManager(SessionContext& session_context,
shared_context_{shared_context} {
subgraph_context_.is_ep_ctx_graph = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(subgraph);

bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos ||
session_context_.device_type.find("GPU") != std::string::npos;
bool npu = session_context_.device_type.find("NPU") != std::string::npos;

subgraph_context_.model_precision = [&](const GraphViewer& graph_viewer) {
// return empty if graph has no inputs or if types are not one of FP32/FP16
// else assume the type of the first input
@@ -112,8 +108,7 @@ BackendManager::BackendManager(SessionContext& session_context,
if (ModelHasSymbolicInputDims(subgraph)) {
subgraph_context_.has_dynamic_input_shape = true;
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims";
if (cpu_or_gpu || (npu && session_context_.enable_causallm) &&
!session_context_.disable_dynamic_shapes) {
if (!session_context_.disable_dynamic_shapes) {
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. "
<< "Creating backend Dynamic Shapes";
try {
@@ -579,30 +574,34 @@ void BackendManager::ValidateInputShapes(const reshape_t& shapes,
void BackendManager::Compute(OrtKernelContext* context) {
Ort::KernelContext ctx(context);
std::chrono::high_resolution_clock::time_point start_compute, end_compute;
bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos ||
session_context_.device_type.find("GPU") != std::string::npos;
bool npu = session_context_.device_type.find("NPU") != std::string::npos;

#ifdef OPENVINO_FIL_ENABLED
static bool fil_enabled = true;
if (fil_enabled) {
start_compute = std::chrono::high_resolution_clock::now();
LOGS_DEFAULT(INFO) << "Start Compute";
}
#endif
// OV NPU doesn't support dynamic shaped model inference.

// if disable_dynamic_shapes is set to true then execution of dynamic model is done
// by rewriting the model to static shaped model at runtime based on input shape.
// disable_dynamic_shapes is always set to true for OV NPU plugin.
if (subgraph_context_.has_dynamic_input_shape &&
!session_context_.disable_dynamic_shapes &&
(cpu_or_gpu || (npu && session_context_.enable_causallm))) {
// disable_dynamic_shapes should be set for devices that don't support dynamic shapes.
bool need_dynamic_backend = subgraph_context_.has_dynamic_input_shape &&
session_context_.disable_dynamic_shapes;

if (!need_dynamic_backend) {
concrete_backend_->Infer(context);
} else if (subgraph_context_.has_dynamic_input_shape) {
} else {
std::vector<std::vector<int64_t>> tensor_shapes = GetInputTensorShapes(ctx);
auto key = MakeMapKeyString(tensor_shapes, session_context_.device_type);
std::shared_ptr<IBackend> dynamic_backend;
auto search = backend_map_.find(key);
if (search == backend_map_.end()) {

{
std::unique_lock<std::mutex> lock(mutex_);
dynamic_backend = backend_map_[key];
}

if (!dynamic_backend) {
ptr_stream_t model_stream;
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] "
<< "Creating dynamic backend for key: " << key;
@@ -643,14 +642,11 @@ void BackendManager::Compute(OrtKernelContext* context) {
}
#endif
}
std::unique_lock<std::mutex> lock(mutex_);
backend_map_.insert({key, dynamic_backend});
} else {
dynamic_backend = search->second;
}

dynamic_backend->Infer(context);
} else {
concrete_backend_->Infer(context);
}
#ifdef OPENVINO_FIL_ENABLED
if (fil_enabled) {
1 change: 1 addition & 0 deletions onnxruntime/core/providers/openvino/backend_manager.h
@@ -54,6 +54,7 @@ class BackendManager {

std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto_;
std::shared_ptr<IBackend> concrete_backend_;
std::mutex mutex_;
std::map<std::string, std::shared_ptr<IBackend>> backend_map_;
SubGraphContext subgraph_context_;
EPCtxHandler& ep_ctx_handle_;
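For context, a minimal sketch of the caching pattern that the Compute() change and the new mutex_ member above introduce: look the key up under the lock, build the backend outside the lock, then insert it under the lock. The types and the factory work are hypothetical simplifications; this mirrors the structure of the diff rather than reproducing it verbatim.

```cpp
#include <map>
#include <memory>
#include <mutex>
#include <string>

// Hypothetical stand-in for the real IBackend; only the caching pattern matters here.
struct Backend {};

class BackendCacheSketch {
 public:
  std::shared_ptr<Backend> GetOrCreate(const std::string& key) {
    std::shared_ptr<Backend> backend;
    {
      // Short critical section: only the map lookup happens under the lock.
      std::unique_lock<std::mutex> lock(mutex_);
      auto it = backend_map_.find(key);
      if (it != backend_map_.end()) backend = it->second;
    }
    if (!backend) {
      // The expensive work (model compilation in the real code) stays outside
      // the lock, so concurrent calls are not serialized on it.
      backend = std::make_shared<Backend>();
      std::unique_lock<std::mutex> lock(mutex_);
      backend_map_.insert({key, backend});  // if another thread won the race, its entry is kept
    }
    return backend;
  }

 private:
  std::mutex mutex_;
  std::map<std::string, std::shared_ptr<Backend>> backend_map_;
};
```

The trade-off of this pattern is that two threads hitting the same new shape at once may both build a backend; the map keeps one of them and the duplicate is discarded after its first use, which is usually cheaper than holding the lock across compilation.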
44 changes: 2 additions & 42 deletions onnxruntime/core/providers/openvino/backend_utils.cc
@@ -179,32 +179,6 @@ CreateOVModel(std::string&& model,
}
}

Ort::UnownedValue
GetOutputTensor(Ort::KernelContext& context, size_t batch_size,
OVInferRequestPtr infer_request,
std::string output_name,
const SubGraphContext::string_index_map_t& output_names) {
auto graph_output_blob = infer_request->GetTensor(output_name);

auto graph_output_dims = graph_output_blob->get_shape();

if (batch_size > 1) {
// Add the batch size as dim 0.
graph_output_dims.insert(graph_output_dims.begin(), batch_size);
}
size_t num_dims = graph_output_dims.size();
std::unique_ptr<int64_t[]> output_shape(new int64_t[num_dims]);
for (size_t j = 0; j < num_dims; j++) {
output_shape[j] = static_cast<int64_t>(graph_output_dims[j]);
}
auto it = output_names.find(output_name);
if (it == output_names.end()) {
ORT_THROW(log_tag + "Output names mismatch between OpenVINO and ONNX");
}
int index = it->second;
return context.GetOutput(index, output_shape.get(), num_dims);
}

Ort::UnownedValue
GetOutputTensor(Ort::KernelContext& context,
std::string output_name,
@@ -220,14 +194,9 @@ GetOutputTensor(Ort::KernelContext& context,
ORT_THROW(log_tag + "Output names mismatch between OpenVINO and ONNX");
}
int index = it->second;
auto shape = node->get_shape();
auto output_shape = ParameterShape::ToOrtShape(node->get_shape());

size_t num_dims = shape.size();
std::unique_ptr<int64_t[]> output_shape(new int64_t[num_dims]);
for (size_t j = 0; j < num_dims; j++) {
output_shape[j] = static_cast<int64_t>(shape[j]);
}
return context.GetOutput(index, output_shape.get(), num_dims);
return context.GetOutput(index, output_shape);
}

int GetFirstAvailableDevice(SessionContext& session_context) {
@@ -312,15 +281,6 @@ void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx,
std::memcpy(input_data, batch_memory_offset, input_data_size);
}

void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor,
size_t batch_slice_idx) {
auto output_data = outputBlob->data();
size_t output_data_size = outputBlob->get_byte_size();
char* tensor_data = output_tensor.GetTensorMutableData<char>();
char* batch_memory_offset = tensor_data + output_data_size * batch_slice_idx;
std::memcpy(batch_memory_offset, output_data, output_data_size);
}

void printPerformanceCounts(const std::vector<OVProfilingInfo>& performanceMap,
std::ostream& stream, std::string deviceName) {
int64_t totalTime = 0;
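As a rough illustration of the simplification in GetOutputTensor() above (not itself part of the diff): the removed hand-rolled copy into a raw int64_t array is replaced by a single ParameterShape::ToOrtShape() call, whose std::vector<int64_t> result is accepted directly by the Ort::KernelContext::GetOutput overload that takes a dimension vector. The snippet assumes the same node, index, and context variables as the function above.

```cpp
// Before: new int64_t[num_dims] plus an element-by-element cast loop.
// After: one conversion to the vector form that GetOutput() accepts.
std::vector<int64_t> ort_dims = ParameterShape::ToOrtShape(node->get_shape());
Ort::UnownedValue output = context.GetOutput(index, ort_dims);
```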
51 changes: 41 additions & 10 deletions onnxruntime/core/providers/openvino/backend_utils.h
@@ -27,8 +27,48 @@

namespace onnxruntime {
namespace openvino_ep {
constexpr std::string log_tag = "[OpenVINO-EP] ";

struct ParameterShape {
using ort_shape_t = std::vector<int64_t>;

static ov::PartialShape ToOvPartialShape(const ort_shape_t& ort_shape) {
std::vector<ov::Dimension> ov_shape(ort_shape.size());
std::transform(ort_shape.begin(), ort_shape.end(), ov_shape.begin(), [](int64_t dim) {
return dim == -1 ? ov::Dimension::dynamic() : ov::Dimension(dim);
});
return ov::PartialShape(ov_shape);
}

static ort_shape_t ToOrtShape(const ov::PartialShape& ov_shape) {
ort_shape_t ort_shape(ov_shape.size());
std::transform(ov_shape.begin(), ov_shape.end(), ort_shape.begin(), [](const auto& dim) {
return dim.is_dynamic() ? -1 : dim.get_length();
});
return ort_shape;
}

static ort_shape_t ToOrtShape(const ov::Shape& ov_shape) {
ort_shape_t ort_shape(ov_shape.size());
std::transform(ov_shape.begin(), ov_shape.end(), ort_shape.begin(), [](const auto& dim) {

return narrow<int64_t>(dim);
});
return ort_shape;
}

operator ov::Shape() const { return ov_.get_shape(); }
operator const ov::PartialShape&() const { return ov_; }
operator const ort_shape_t&() const { return ort_; }

explicit ParameterShape(const ort_shape_t& ort_shape) : ort_(ort_shape), ov_(ToOvPartialShape(ort_shape)) {}
explicit ParameterShape(const ov::PartialShape& ov_partial_shape) : ov_(ov_partial_shape), ort_(ToOrtShape(ov_partial_shape)) {}

private:
ort_shape_t ort_;
ov::PartialShape ov_;
};

namespace backend_utils {
const std::string log_tag = "[OpenVINO-EP] ";

bool IsDebugEnabled();

@@ -48,19 +88,10 @@
const SubGraphContext::string_index_map_t& output_names,
std::shared_ptr<ov::Node> node);

Ort::UnownedValue
GetOutputTensor(Ort::KernelContext& context, size_t batch_size,
OVInferRequestPtr infer_request,
std::string output_name,
const SubGraphContext::string_index_map_t& output_names);

void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx,
std::string input_name, Ort::KernelContext& context,
const SubGraphContext& subgraph_context);

void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor,
size_t batch_slice_idx);

std::shared_ptr<const OVNetwork>
CreateOVModel(std::string&& model,
const SessionContext& session_context,
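A short usage sketch of the new ParameterShape helper; the shapes are made up, and -1 marks a dynamic dimension on the ORT side. (The PR's CI also flagged, via cpplint, that backend_utils.h should add #include <algorithm> for the std::transform calls in this struct.)

```cpp
#include <cstdint>
#include <vector>

using onnxruntime::openvino_ep::ParameterShape;

// An ORT-style shape with a dynamic batch dimension.
std::vector<int64_t> ort_shape{-1, 3, 224, 224};

// One wrapper keeps both representations; the conversion happens once in the constructor.
ParameterShape shape(ort_shape);

// Implicit conversion to the OpenVINO partial shape: {?, 3, 224, 224}.
const ov::PartialShape& partial = shape;

// Converting back preserves -1 for dynamic dimensions.
std::vector<int64_t> round_trip = ParameterShape::ToOrtShape(partial);  // {-1, 3, 224, 224}

// Note: the ov::Shape conversion calls get_shape() and is only valid when the
// shape is fully static (e.g. a ParameterShape built from {1, 3, 224, 224}).
```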