Skip to content

Commit 1748e06

Browse files
committed
feat: Enable EpContext OVIR Encapsulation
1 parent 9b245a4 commit 1748e06

File tree

6 files changed

+73
-6
lines changed

6 files changed

+73
-6
lines changed

onnxruntime/core/providers/openvino/backend_utils.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,33 @@ void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map)
440440
metadata_map.clear();
441441
}
442442

443+
bool IsModelStreamXML(std::istream& model_stream) {
444+
std::streampos originalPos = model_stream.tellg();
445+
446+
// first, get the total size of model_stream in bytes
447+
model_stream.seekg(0, std::ios::end);
448+
auto end_pos = model_stream.tellg();
449+
// Restore the stream position
450+
model_stream.seekg(originalPos);
451+
auto total_size = end_pos - originalPos;
452+
453+
// Choose 32 bytes to hold content of:
454+
// '<?xml version-"1.0"?> <net '
455+
const std::streamsize header_check_len = 32;
456+
ORT_ENFORCE(total_size > header_check_len);
457+
458+
// read 32 bytes into header
459+
std::string header(header_check_len, '\0');
460+
model_stream.read(&header[0], header_check_len);
461+
// Clear any read errors
462+
model_stream.clear();
463+
// Restore the stream position
464+
model_stream.seekg(originalPos);
465+
466+
// return true if the header starts with '<?xml' and also includes '<net '
467+
return ((header.rfind("<?xml", 0) == 0) && (header.find("<net ") != std::string::npos));
468+
}
469+
443470
} // namespace backend_utils
444471
} // namespace openvino_ep
445472
} // namespace onnxruntime

onnxruntime/core/providers/openvino/backend_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ void printPerformanceCounts(const std::vector<OVProfilingInfo>& performanceMap,
7676

7777
void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std::string deviceName);
7878

79+
bool IsModelStreamXML(std::istream& model_stream);
80+
7981
} // namespace backend_utils
8082
} // namespace openvino_ep
8183
} // namespace onnxruntime

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
7676
exe_network_ = OVCore::Get()->ImportModel(*model_stream,
7777
hw_target,
7878
device_config,
79-
subgraph_context_.subgraph_name);
79+
enable_causallm,
80+
session_context_.onnx_model_path_name.string());
8081
model_stream.reset(); // Delete stream after it is no longer needed
8182
} else if (!session_context_.has_external_weights &&
8283
!subgraph_context_.has_dynamic_input_shape &&

onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <algorithm>
88

99
#include "core/providers/openvino/onnx_ctx_model_helper.h"
10+
#include "core/providers/openvino/backend_utils.h"
1011

1112
namespace onnxruntime {
1213
namespace openvino_ep {
@@ -123,6 +124,16 @@ std::unique_ptr<std::istream> EPCtxHandler::GetModelBlobStream(const std::filesy
123124
ORT_ENFORCE(std::filesystem::exists(blob_filepath), "Blob file not found: ", blob_filepath.string());
124125
result.reset((std::istream*)new std::ifstream(blob_filepath, std::ios_base::binary | std::ios_base::in));
125126
}
127+
128+
bool isXML = backend_utils::IsModelStreamXML(*result);
129+
if (!isXML) {
130+
// If the model stream is not an XML (i.e. precompiled blob), the OpenVINO SDK version that it was
131+
// exported with must match the version that is currently running.
132+
ORT_ENFORCE((attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_),
133+
"EPCtx blob was exported / is compatible with with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() +
134+
", but OpenVINO SDK version currently in use is " + openvino_sdk_version_);
135+
}
136+
126137
LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node";
127138
return result;
128139
}
@@ -142,7 +153,6 @@ bool EPCtxHandler::CheckForOVEPCtxNode(const Node& node) const {
142153
if (node.OpType() == EPCONTEXT_OP) {
143154
auto& attrs = node.GetAttributes();
144155
bool result = (attrs.count(SOURCE) == 1) && (attrs.at(SOURCE).s() == kOpenVINOExecutionProvider);
145-
result &= (attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_);
146156
result &= attrs.count(EMBED_MODE) == 1;
147157
result &= attrs.count(EP_CACHE_CONTEXT) == 1;
148158
return result;

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
9595
LogBasicModelInfo(model);
9696
}
9797

98-
LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
9998
bool model_status = IsStateful(model);
10099
LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False");
101100
if (!model_status) {
101+
LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
102102
PatchStatefulDecoder(model);
103103
}
104104

@@ -193,14 +193,40 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model,
193193
OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
194194
std::string hw_target,
195195
const ov::AnyMap& device_config,
196+
bool enable_causallm,
196197
std::string name) {
197198
try {
198-
ov::CompiledModel obj;
199-
obj = core.import_model(model_stream, hw_target, device_config);
199+
OVExeNetwork exe;
200+
201+
bool isXML = backend_utils::IsModelStreamXML(model_stream);
202+
203+
if (!isXML) {
204+
auto obj = core.import_model(model_stream, hw_target, device_config);
205+
exe = OVExeNetwork(obj, hw_target);
206+
} else {
207+
// If the model is XML, we need to load it with the XML content in read_model()
208+
// where weights from bin file is directly consumed
209+
std::string xml_file_name = name;
210+
if (name.size() >= 5 && name.substr(name.size() - 5) == ".onnx") {
211+
xml_file_name.replace(name.size() - 5, 5, ".xml");
212+
} else {
213+
throw std::runtime_error("Invalid model name. Make sure *.onnx, *.xml, and *.bin carry the same name.");
214+
}
215+
216+
// Load the model explicitly with XML contents
217+
std::shared_ptr<ov::Model> model = core.read_model(xml_file_name);
218+
219+
if (enable_causallm) {
220+
exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);
221+
} else {
222+
auto obj = core.compile_model(model, hw_target, device_config);
223+
exe = OVExeNetwork(obj, hw_target);
224+
}
225+
}
226+
200227
#ifndef NDEBUG
201228
printDebugInfo(exe.Get());
202229
#endif
203-
OVExeNetwork exe(obj, hw_target);
204230
return exe;
205231
} catch (const Exception& e) {
206232
ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ struct OVCore : WeakSingleton<OVCore> {
8282
OVExeNetwork ImportModel(std::istream& model_stream,
8383
std::string hw_target,
8484
const ov::AnyMap& device_config,
85+
bool enable_causallm,
8586
std::string name);
8687
std::vector<std::string> GetAvailableDevices() const;
8788
std::vector<std::string> GetAvailableDevices(const std::string& device_type) const;

0 commit comments

Comments
 (0)