Skip to content

Commit 002916a

Browse files
authored
Validate the context_file_path before EP compile graphs (#23611)
Validate the context_file_path before the EPs compile the graphs so that an invalid path fails fast. This avoids the possibility that an EP generates a new file (context binary file or blob file) that overwrites an existing file. An error is returned if the path points to a folder.
1 parent 0887e36 commit 002916a

File tree

3 files changed

+138
-6
lines changed

3 files changed

+138
-6
lines changed

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";
275275

276276
// Specify the file path for the Onnx model which has EP context.
277277
// Default to original_file_name_ctx.onnx if not specified
278+
// A folder path is not a valid value; the path must point to a file.
278279
static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";
279280

280281
// Flag to specify whether to dump the EP context into the Onnx model.

onnxruntime/core/framework/graph_partitioner.cc

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,29 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
643643
return Status::OK();
644644
}
645645

646+
// Validate the ep_context_path to make sure it is file path and check whether the file exist already
647+
static Status EpContextFilePathCheck(const std::string& ep_context_path,
648+
const std::filesystem::path& model_path) {
649+
std::filesystem::path context_cache_path;
650+
if (!ep_context_path.empty()) {
651+
context_cache_path = ep_context_path;
652+
if (!context_cache_path.has_filename()) {
653+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder.");
654+
}
655+
} else if (!model_path.empty()) {
656+
context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx");
657+
} else {
658+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty.");
659+
}
660+
661+
if (std::filesystem::exists(context_cache_path)) {
662+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '",
663+
context_cache_path, "' exist already.");
664+
}
665+
666+
return Status::OK();
667+
}
668+
646669
static Status CreateEpContextModel(const ExecutionProviders& execution_providers,
647670
const Graph& graph,
648671
const std::filesystem::path& ep_context_path,
@@ -678,11 +701,6 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
678701
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty");
679702
}
680703

681-
if (std::filesystem::exists(context_cache_path)) {
682-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '",
683-
context_cache_path, "' exist already.");
684-
}
685-
686704
Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(),
687705
graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path
688706
IOnnxRuntimeOpSchemaRegistryList{graph.GetSchemaRegistry()},
@@ -1007,9 +1025,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
10071025

10081026
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
10091027
#if !defined(ORT_MINIMAL_BUILD)
1028+
bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
1029+
if (ep_context_enabled) {
1030+
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
1031+
// Check before EP compile graphs
1032+
ORT_RETURN_IF_ERROR(EpContextFilePathCheck(ep_context_path, graph.ModelPath()));
1033+
}
1034+
10101035
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger));
10111036

1012-
bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
10131037
if (ep_context_enabled) {
10141038
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
10151039
std::string external_ini_file_name = config_options.GetConfigOrDefault(kOrtSessionOptionsEpContextModelExternalInitializersFileName, "");

onnxruntime/test/providers/qnn/qnn_ep_context_test.cc

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,113 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCpuNodeWithoutExternalWeights) {
252252
EpCtxCpuNodeWithExternalIniFileTestBody(false);
253253
}
254254

255+
// Set ep.context_file_path to folder path which is not a valid option, check the error message
256+
TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationFolderPathNotExpected) {
257+
ProviderOptions provider_options;
258+
#if defined(_WIN32)
259+
provider_options["backend_path"] = "QnnHtp.dll";
260+
#else
261+
provider_options["backend_path"] = "libQnnHtp.so";
262+
#endif
263+
provider_options["offload_graph_io_quantization"] = "0";
264+
265+
const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
266+
267+
auto& logging_manager = DefaultLoggingManager();
268+
logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
269+
270+
onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
271+
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
272+
logging_manager.DefaultLogger());
273+
Graph& graph = model.MainGraph();
274+
ModelTestBuilder helper(graph);
275+
bool single_ep_node = true;
276+
BuildGraphWithQAndNonQ(single_ep_node)(helper);
277+
helper.SetGraphOutputs();
278+
ASSERT_STATUS_OK(model.MainGraph().Resolve());
279+
280+
// Serialize the model to a string.
281+
std::string model_data;
282+
model.ToProto().SerializeToString(&model_data);
283+
284+
const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
285+
286+
const std::string ep_context_onnx_file = "./ep_context_folder_not_expected/";
287+
std::remove(ep_context_onnx_file.c_str());
288+
Ort::SessionOptions so;
289+
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
290+
so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str());
291+
so.AppendExecutionProvider("QNN", provider_options);
292+
293+
try {
294+
Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so);
295+
FAIL(); // Should not get here!
296+
} catch (const Ort::Exception& excpt) {
297+
ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_INVALID_ARGUMENT);
298+
ASSERT_THAT(excpt.what(), testing::HasSubstr("context_file_path should not point to a folder."));
299+
}
300+
}
301+
302+
// Create session 1 to generate context binary file
303+
// Create session 2 to do same thing, make sure session 2 failed because file exist already
304+
// Make sure no new file over write from session 2
305+
TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationNoOverWrite) {
306+
ProviderOptions provider_options;
307+
#if defined(_WIN32)
308+
provider_options["backend_path"] = "QnnHtp.dll";
309+
#else
310+
provider_options["backend_path"] = "libQnnHtp.so";
311+
#endif
312+
provider_options["offload_graph_io_quantization"] = "0";
313+
314+
const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
315+
316+
auto& logging_manager = DefaultLoggingManager();
317+
logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
318+
319+
onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
320+
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
321+
logging_manager.DefaultLogger());
322+
Graph& graph = model.MainGraph();
323+
ModelTestBuilder helper(graph);
324+
bool single_ep_node = true;
325+
BuildGraphWithQAndNonQ(single_ep_node)(helper);
326+
helper.SetGraphOutputs();
327+
ASSERT_STATUS_OK(model.MainGraph().Resolve());
328+
329+
// Serialize the model to a string.
330+
std::string model_data;
331+
model.ToProto().SerializeToString(&model_data);
332+
333+
const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
334+
335+
const std::string ep_context_onnx_file = "./ep_context_no_over_write.onnx";
336+
const std::string ep_context_binary_file = "./ep_context_no_over_write.onnx_QNNExecutionProvider_QNN_10880527342279992768_1_0.bin";
337+
338+
std::remove(ep_context_onnx_file.c_str());
339+
Ort::SessionOptions so;
340+
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
341+
so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str());
342+
so.AppendExecutionProvider("QNN", provider_options);
343+
344+
Ort::Session session1(*ort_env, model_data_span.data(), model_data_span.size(), so);
345+
346+
auto modify_time_1 = std::filesystem::last_write_time(ep_context_binary_file);
347+
348+
try {
349+
Ort::Session session2(*ort_env, model_data_span.data(), model_data_span.size(), so);
350+
FAIL(); // Should not get here!
351+
} catch (const Ort::Exception& excpt) {
352+
ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_FAIL);
353+
ASSERT_THAT(excpt.what(), testing::HasSubstr("exist already."));
354+
auto modify_time_2 = std::filesystem::last_write_time(ep_context_binary_file);
355+
ASSERT_EQ(modify_time_1, modify_time_2);
356+
}
357+
358+
ASSERT_EQ(std::remove(ep_context_onnx_file.c_str()), 0);
359+
ASSERT_EQ(std::remove(ep_context_binary_file.c_str()), 0);
360+
}
361+
255362
// Create a model with Cast + Add (quantized)
256363
// cast_input -> Cast -> Q -> DQ \
257364
// Add -> Q -> DQ -> output

0 commit comments

Comments
 (0)