Skip to content

Commit 9ee331a

Browse files
authored
Revert "Add support for session option ep.stop_context_sharing (#655)" (#674)
This reverts commit 269f6fe.
1 parent e59c069 commit 9ee331a

File tree

10 files changed

+99
-152
lines changed

10 files changed

+99
-152
lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 16 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -83,23 +83,22 @@ BackendManager::BackendManager(SessionContext& session_context,
8383
}
8484
std::string device_type = session_context_.device_type;
8585

86-
// Check if model is using external weights
87-
if (auto filename = backend_utils::GetExternalWeightFilename(subgraph)) {
88-
std::filesystem::path weights_filepath = session_context_.onnx_model_path_name.parent_path() / filename.value();
89-
90-
// Initialize external weights with fully qualified path
91-
if (!std::filesystem::exists(weights_filepath)) {
92-
ORT_THROW("Error: Failed to locate weight file at ", weights_filepath.string());
86+
auto& sw = shared_context_.shared_weights;
87+
if (session_context_.so_share_ep_contexts) {
88+
std::filesystem::path weight_filename = session_context_.onnx_model_path_name.parent_path();
89+
if (sw.external_weight_filename.empty() && !sw.metadata.empty()) {
90+
// Reasonable assumption that all metadata entries have the same external file location
91+
sw.external_weight_filename = sw.metadata.begin()->second.location;
9392
}
93+
weight_filename /= sw.external_weight_filename;
94+
std::ifstream weight_file(weight_filename);
9495

95-
external_weights_.emplace(weights_filepath);
96-
}
97-
98-
if (session_context_.so_share_ep_contexts) {
99-
ORT_ENFORCE(external_weights_.has_value(), "Expected external weight object to be valid");
100-
backend_utils::CreateOVTensors(session_context_.device_type,
101-
shared_context_.shared_weights.metadata,
102-
external_weights_.value());
96+
if (weight_file) {
97+
if (!sw.mapped_weights) {
98+
sw.mapped_weights = std::make_unique<SharedContext::SharedWeights::WeightsFile>(weight_filename);
99+
}
100+
backend_utils::CreateOVTensors(session_context_.device_type, sw.metadata, *sw.mapped_weights);
101+
}
103102
}
104103

105104
if (ModelHasSymbolicInputDims(subgraph)) {
@@ -325,7 +324,7 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) {
325324
static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name,
326325
[[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto,
327326
[[maybe_unused]] const onnxruntime::Node& fused_node) {
328-
#ifdef NOT_RELEASE
327+
#ifndef RELEASE
329328
if (openvino_ep::backend_utils::IsDebugEnabled()) {
330329
auto model_name = onnx_model_path_name.empty() ? "unknown.onnx" : onnx_model_path_name.filename();
331330

@@ -385,12 +384,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
385384
if (session_context_.device_type.find("NPU") != std::string::npos &&
386385
(enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) {
387386
std::unique_ptr<onnxruntime::Model> model;
388-
Status status = CreateModelWithStrippedQDQNodes(subgraph,
389-
logger,
390-
session_context_.so_share_ep_contexts,
391-
enable_ovep_qdq_optimizer,
392-
model,
393-
shared_context_.shared_weights.metadata);
387+
Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_.shared_weights);
394388
auto model_proto = model->ToProto();
395389
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
396390
print_model_proto_duration();

onnxruntime/core/providers/openvino/backend_manager.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,6 @@ class BackendManager {
5454
EPCtxHandler& ep_ctx_handle_;
5555
SessionContext& session_context_;
5656
SharedContext& shared_context_;
57-
std::optional<fs::path> external_weights_;
5857
};
5958

6059
} // namespace openvino_ep

onnxruntime/core/providers/openvino/backend_utils.cc

Lines changed: 24 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
#include <sstream>
55
#include <fstream>
66
#include <utility>
7-
#include <string>
87

98
#include <filesystem>
109
#include <stdexcept>
@@ -21,7 +20,22 @@ using Exception = ov::Exception;
2120
namespace onnxruntime {
2221
namespace openvino_ep {
2322

24-
std::ostream& operator<<(std::ostream& stream, const Metadata::Map& metadata) {
23+
SharedContext::SharedWeights::WeightsFile::WeightsFile(std::filesystem::path filename) : file_(filename, std::ios::in | std::ios::binary) {
24+
try {
25+
file_.exceptions(std::ifstream::failbit | std::ifstream::badbit);
26+
weights_size_ = file_.seekg(0, std::ios::end).tellg();
27+
} catch (std::ifstream::failure& e) {
28+
ORT_THROW("Error: Failed to open weight file at ", filename.string(), " ", e.what());
29+
}
30+
}
31+
32+
void SharedContext::SharedWeights::WeightsFile::load_weights(size_t file_offset, void* data, size_t size) {
33+
ORT_ENFORCE(file_offset < weights_size_ && size <= weights_size_ && (file_offset <= weights_size_ - size), "Error: File offset is out of bounds.");
34+
file_.seekg(file_offset);
35+
file_.read(reinterpret_cast<char*>(data), size);
36+
}
37+
38+
std::ostream& operator<<(std::ostream& stream, const SharedContext::SharedWeights::Metadata::Map& metadata) {
2539
try {
2640
stream << metadata.size();
2741

@@ -55,14 +69,14 @@ std::ostream& operator<<(std::ostream& stream, const Metadata::Map& metadata) {
5569
return stream;
5670
}
5771

58-
std::istream& operator>>(std::istream& stream, Metadata::Map& metadata) {
72+
std::istream& operator>>(std::istream& stream, SharedContext::SharedWeights::Metadata::Map& metadata) {
5973
size_t map_size{0};
6074
try {
6175
stream >> map_size;
6276

6377
while (!stream.eof()) {
64-
Metadata::Key key;
65-
Metadata::Value value;
78+
SharedContext::SharedWeights::Metadata::Key key;
79+
SharedContext::SharedWeights::Metadata::Value value;
6680
stream >> key.name;
6781
stream >> value.location;
6882
stream >> value.data_offset;
@@ -385,19 +399,8 @@ ov::element::Type GetOpenVINOElementType(ONNX_NAMESPACE::TensorProto_DataType dt
385399

386400
// Function to handle tensor creation from external data
387401
void CreateOVTensors(const std::string& device_name,
388-
Metadata::Map& metadata_map,
389-
std::filesystem::path& weights_filepath) {
390-
// File is guaranteed to exist at this point
391-
std::ifstream file(weights_filepath, std::ios::in | std::ios::binary);
392-
file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
393-
size_t weights_size = std::filesystem::file_size(weights_filepath);
394-
395-
const auto load_weights = [&file, weights_size](size_t file_offset, void* data, size_t size) {
396-
ORT_ENFORCE(file_offset < weights_size && size <= weights_size && (file_offset <= weights_size - size), "Error: File offset is out of bounds.");
397-
file.seekg(file_offset);
398-
file.read(reinterpret_cast<char*>(data), size);
399-
};
400-
402+
SharedContext::SharedWeights::Metadata::Map& metadata_map,
403+
SharedContext::SharedWeights::WeightsFile& weights) {
401404
for (auto& [key, value] : metadata_map) {
402405
if (value.tensor) continue;
403406

@@ -413,18 +416,18 @@ void CreateOVTensors(const std::string& device_name,
413416
auto&& remote_tensor = npu_context.create_l0_host_tensor(ov_elementType, value.dimensions, ov::intel_npu::TensorType::INPUT);
414417

415418
// Copy data to remote tensor
416-
load_weights(value.data_offset, remote_tensor.get(), value.size);
419+
weights.load_weights(value.data_offset, remote_tensor.get(), value.size);
417420
value.tensor = std::make_shared<ov::Tensor>(remote_tensor);
418421
} else {
419422
// Use vanilla tensors
420423
value.tensor = std::make_shared<ov::Tensor>(ov_elementType, value.dimensions);
421-
load_weights(value.data_offset, value.tensor->data(), value.size);
424+
weights.load_weights(value.data_offset, value.tensor->data(), value.size);
422425
}
423426
ORT_ENFORCE(value.tensor->get_byte_size() == value.size, "Unexpected tensor size mismatch");
424427
}
425428
}
426429

427-
void DestroyOVTensors(Metadata::Map& metadata_map) {
430+
void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map) {
428431
for (auto& [key, value] : metadata_map) {
429432
if (value.tensor) {
430433
value.tensor.reset();
@@ -433,51 +436,6 @@ void DestroyOVTensors(Metadata::Map& metadata_map) {
433436
metadata_map.clear();
434437
}
435438

436-
std::optional<std::string> GetExternalWeightFilename(const GraphViewer& graph) {
437-
auto get_external_location = [](const ONNX_NAMESPACE::TensorProto& proto) -> std::optional<std::string> {
438-
using mutable_proto_t = ONNX_NAMESPACE::TensorProto*;
439-
auto& mutable_proto = *const_cast<mutable_proto_t>(&proto);
440-
auto* entry_protos = mutable_proto.mutable_external_data();
441-
442-
if (proto.has_data_location() && proto.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
443-
for (int i = 0; i < entry_protos->size(); i++) {
444-
auto& string_entry_proto{entry_protos->at(i)};
445-
const auto& pb_key{*(string_entry_proto.mutable_key())};
446-
const auto& pb_value{*(string_entry_proto.mutable_value())};
447-
if (pb_key == "location") {
448-
return std::make_optional<std::string>(pb_value);
449-
}
450-
}
451-
}
452-
453-
return std::nullopt;
454-
};
455-
456-
// Handle constant initializers
457-
auto& initializers = graph.GetAllInitializedTensors();
458-
for (const auto& it : initializers) {
459-
if (auto result = get_external_location(*it.second)) {
460-
return result;
461-
}
462-
}
463-
464-
// Handle outer-scope constant initializers
465-
for (auto& node_idx : graph.GetNodesInTopologicalOrder()) {
466-
const auto& node = graph.GetNode(node_idx);
467-
for (const auto& input : node->InputDefs()) {
468-
if (graph.IsConstantInitializer(input->Name(), true)) {
469-
const auto& initializer_tensor = *graph.GetConstantInitializer(input->Name(), true);
470-
471-
if (auto result = get_external_location(initializer_tensor)) {
472-
return result;
473-
}
474-
}
475-
}
476-
}
477-
478-
return std::nullopt;
479-
}
480-
481439
} // namespace backend_utils
482440
} // namespace openvino_ep
483441
} // namespace onnxruntime

onnxruntime/core/providers/openvino/backend_utils.h

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -67,18 +67,15 @@ CreateOVModel(std::string&& model,
6767
std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);
6868

6969
void CreateOVTensors(const std::string& device_name,
70-
Metadata::Map& metadata_map,
71-
std::filesystem::path& weights_filepath);
72-
void DestroyOVTensors(Metadata::Map& metadata_map);
70+
SharedContext::SharedWeights::Metadata::Map& metadata_map,
71+
SharedContext::SharedWeights::WeightsFile& weights);
72+
void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map);
7373

7474
void printPerformanceCounts(const std::vector<OVProfilingInfo>& performanceMap,
7575
std::ostream& stream, std::string deviceName);
7676

7777
void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std::string deviceName);
7878

79-
// Returns the location string from the first external initializer nodes found or nullopt if none found
80-
std::optional<std::string> GetExternalWeightFilename(const GraphViewer& graph);
81-
8279
} // namespace backend_utils
8380
} // namespace openvino_ep
8481
} // namespace onnxruntime

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -125,12 +125,10 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
125125
std::function<void(OVInferRequestPtr)> initializer = [](OVInferRequestPtr) {};
126126
auto metadata = shared_context_.shared_weights.metadata;
127127
if (session_context_.so_share_ep_contexts) {
128-
// When shared ep contexts is set external weight references are transformed to model inputs. This
129-
// creates an initializer to populate/bind input weight tensors to each inference request
130128
initializer = [&metadata](OVInferRequestPtr ir_ptr) {
131129
const auto input_count = ir_ptr->GetNumInputs();
132130
for (auto i = 0u; i < input_count; i++) {
133-
using Key = Metadata::Key;
131+
using Key = SharedContext::SharedWeights::Metadata::Key;
134132
const auto tensor_key = Key{ir_ptr->GetInputTensorName(i)};
135133
if (metadata.contains(tensor_key)) {
136134
auto& value = metadata.at(tensor_key);
@@ -139,8 +137,6 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
139137
}
140138
};
141139
}
142-
143-
// Create inference request queue and initialize according to passed function
144140
inferRequestsQueue_ = std::unique_ptr<InferRequestsQueue>(new InferRequestsQueue(exe_network_, num_infer_req, std::move(initializer)));
145141
}
146142

onnxruntime/core/providers/openvino/contexts.h

Lines changed: 37 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -18,42 +18,52 @@ namespace openvino_ep {
1818

1919
namespace fs = std::filesystem;
2020

21-
struct Metadata {
22-
struct Key {
23-
std::string name;
24-
bool operator==(const Key&) const = default;
25-
};
26-
struct Hash {
27-
std::size_t operator()(const Key& key) const noexcept {
28-
return std::hash<std::string>()(key.name);
29-
}
30-
};
31-
struct Value {
32-
std::string location;
33-
unsigned int data_offset;
34-
unsigned int size;
35-
std::vector<size_t> dimensions;
36-
std::int32_t element_type;
37-
std::shared_ptr<ov::Tensor> tensor;
38-
};
39-
using Map = std::unordered_map<Key, Value, Hash>;
40-
friend std::ostream& operator<<(std::ostream& right, const Metadata::Map& metadata);
41-
friend std::istream& operator>>(std::istream& right, Metadata::Map& metadata);
42-
};
43-
4421
class SharedContext : public WeakSingleton<SharedContext> {
4522
// Keep the core alive as long as the shared SharedContext are alive.
4623
std::shared_ptr<OVCore> OVCore_;
4724

4825
public:
4926
SharedContext() : OVCore_(OVCore::Get()) {}
5027
struct SharedWeights {
28+
struct Metadata {
29+
struct Key {
30+
std::string name;
31+
bool operator==(const Key&) const = default;
32+
};
33+
struct Hash {
34+
std::size_t operator()(const Key& key) const noexcept {
35+
return std::hash<std::string>()(key.name);
36+
}
37+
};
38+
struct Value {
39+
std::string location;
40+
unsigned int data_offset;
41+
unsigned int size;
42+
std::vector<size_t> dimensions;
43+
std::int32_t element_type;
44+
std::shared_ptr<ov::Tensor> tensor;
45+
};
46+
using Map = std::unordered_map<Key, Value, Hash>;
47+
friend std::ostream& operator<<(std::ostream& right, const Metadata::Map& metadata);
48+
friend std::istream& operator>>(std::istream& right, Metadata::Map& metadata);
49+
};
50+
51+
struct WeightsFile {
52+
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightsFile);
53+
WeightsFile() = delete;
54+
explicit WeightsFile(std::filesystem::path filename);
55+
56+
void load_weights(size_t file_offset, void* data, size_t size);
57+
58+
private:
59+
std::ifstream file_;
60+
size_t weights_size_;
61+
};
62+
63+
fs::path external_weight_filename;
64+
std::unique_ptr<WeightsFile> mapped_weights;
5165
Metadata::Map metadata;
5266
} shared_weights;
53-
54-
void clear() { // Deletes the data stored in the SharedContext
55-
shared_weights.metadata.clear();
56-
}
5767
};
5868

5969
using config_t = std::map<std::string, ov::AnyMap>;
@@ -92,7 +102,6 @@ struct ProviderInfo {
92102
bool so_context_embed_mode{false}; // ORT session option
93103
bool so_share_ep_contexts{false}; // ORT session option
94104
fs::path so_context_file_path{}; // ORT session option
95-
bool so_stop_share_ep_contexts{false}; // ORT session option
96105
const ConfigOptions* config_options{NULL};
97106
const std::unordered_set<std::string> valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision",
98107
"load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer",

onnxruntime/core/providers/openvino/openvino_execution_provider.cc

Lines changed: 1 addition & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,6 @@ OpenVINOExecutionProvider::~OpenVINOExecutionProvider() {
6565
backend_manager.ShutdownBackendManager();
6666
}
6767
backend_managers_.clear();
68-
shared_context_.reset();
6968
}
7069

7170
std::vector<std::unique_ptr<ComputeCapability>>
@@ -107,12 +106,7 @@ common::Status OpenVINOExecutionProvider::Compile(
107106
auto& metadata = shared_context_->shared_weights.metadata;
108107
if (session_context_.so_share_ep_contexts && metadata.empty()) {
109108
// Metadata is always read from model location, this could be a source or epctx model
110-
fs::path metadata_filename;
111-
if (session_context_.so_context_file_path.empty()) {
112-
metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin";
113-
} else {
114-
metadata_filename = session_context_.so_context_file_path.parent_path() / "metadata.bin";
115-
}
109+
fs::path metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin";
116110
std::ifstream file(metadata_filename, std::ios::binary);
117111
if (file) {
118112
file >> metadata;
@@ -197,10 +191,6 @@ common::Status OpenVINOExecutionProvider::Compile(
197191
}
198192
}
199193

200-
if (session_context_.so_stop_share_ep_contexts) {
201-
shared_context_->clear();
202-
}
203-
204194
return status;
205195
}
206196

onnxruntime/core/providers/openvino/openvino_provider_factory.cc

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@ void ParseConfigOptions(ProviderInfo& pi) {
2828
pi.so_context_embed_mode = pi.config_options->GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1";
2929
pi.so_share_ep_contexts = pi.config_options->GetConfigOrDefault(kOrtSessionOptionShareEpContexts, "0") == "1";
3030
pi.so_context_file_path = pi.config_options->GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
31-
pi.so_stop_share_ep_contexts = pi.config_options->GetConfigOrDefault(kOrtSessionOptionStopShareEpContexts, "0") == "1";
3231

3332
if (pi.so_share_ep_contexts) {
3433
ov::AnyMap map;

0 commit comments

Comments
 (0)