From 7045910d1057a20d54de82fea31bf45d12818f15 Mon Sep 17 00:00:00 2001
From: Hariharan Seshadri
Date: Fri, 28 Aug 2020 13:24:29 -0700
Subject: [PATCH] Support RegisterCustomOpsLibrary via the Python API (#4764)

---
 docs/AddingCustomOp.md                        |   7 +-
 onnxruntime/core/session/custom_ops.cc        |  14 +-
 onnxruntime/core/session/inference_session.cc |   3 +-
 .../python/onnxruntime_pybind_state.cc        | 221 +++++++++++++-----
 .../python/onnxruntime_pybind_state_common.h  |  82 ++++++-
 .../test/python/onnxruntime_test_python.py    |  51 +++-
 .../python/orttraining_pybind_state.cc        |  96 ++++----
 7 files changed, 349 insertions(+), 125 deletions(-)

diff --git a/docs/AddingCustomOp.md b/docs/AddingCustomOp.md
index 5db820f170d98..396d8caf8266b 100644
--- a/docs/AddingCustomOp.md
+++ b/docs/AddingCustomOp.md
@@ -2,12 +2,15 @@ Adding a new op
 ===============
 ## A new op can be written and registered with ONNXRuntime in the following 3 ways
-### 1. Using the experimental custom op API in the C API (onnxruntime_c_api.h)
-Note: These APIs are experimental and will change in the next release. They're released now for feedback and experimentation.
+### 1. Using the custom op API in the C/C++ APIs (onnxruntime_c_api.h)
 * Create an OrtCustomOpDomain with the domain name used by the custom ops
 * Create an OrtCustomOp structure for each op and add them to the OrtCustomOpDomain with OrtCustomOpDomain_Add
 * Call OrtAddCustomOpDomain to add the custom domain of ops to the session options
 See [this](../onnxruntime/test/shared_lib/test_inference.cc) for an example called MyCustomOp that uses the C++ helper API (onnxruntime_cxx_api.h).
+You can also compile the custom ops into a shared library and use that to run a model via the C++ API. The same test file contains an example.
+The source code for a sample custom op shared library containing two custom kernels is [here](../onnxruntime/test/testdata/custom_op_library/custom_op_library.cc).
+See [this](../onnxruntime/test/python/onnxruntime_test_python.py) for an example called testRegisterCustomOpsLibrary that uses the Python API
+to register a shared library that contains custom op kernels. Currently, the only supported Execution Providers (EPs) for custom ops registered via this approach are the `CPU` and `CUDA` EPs.
 ### 2. Using RegisterCustomRegistry API
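To make the documented Python flow concrete, here is a minimal sketch of the new API; the library and model paths are placeholders for artifacts you build yourself (see the referenced test and sample library):

```python
import numpy as np
import onnxruntime as onnxrt

so = onnxrt.SessionOptions()
# Load the shared library and register the custom op schemas/kernels it exports.
so.register_custom_ops_library('./libcustom_op_library.so')

# Session creation resolves custom op nodes against the registered domains.
sess = onnxrt.InferenceSession('custom_op_test.onnx', so)
feed = {inp.name: np.ones((3, 5), dtype=np.float32) for inp in sess.get_inputs()}
print(sess.run(None, feed))
```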
diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc
index 7389c79c18e8e..2b5650769d3cb 100644
--- a/onnxruntime/core/session/custom_ops.cc
+++ b/onnxruntime/core/session/custom_ops.cc
@@ -98,8 +98,18 @@ common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_do
   output = std::make_shared<CustomRegistry>();
 
   for (auto& domain : op_domains) {
-    if (domain->domain_[0])
-      ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(domain->domain_, 1, 1000);
+    // Domain is not empty - add it to the DomainToVersion ONNX map.
+    // If domain is empty, it is assumed to be part of the ONNX domain.
+    if (domain->domain_[0]) {
+      // Add it to the DomainToVersion ONNX map only if it doesn't already exist.
+      // For example, two sessions using the same session_options should not add the same custom op domain to the version map twice.
+      auto& domain_to_version_range_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance();
+      const auto& domain_to_version_map = domain_to_version_range_instance.Map();
+
+      if (domain_to_version_map.find(domain->domain_) == domain_to_version_map.end()) {
+        domain_to_version_range_instance.AddDomainToVersion(domain->domain_, 1, 1000);
+      }
+    }
 
     std::vector<ONNX_NAMESPACE::OpSchema> schemas_list;

diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 4ce318c62b5c0..dd814a91bbb37 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -408,8 +408,7 @@ common::Status InferenceSession::RegisterCustomRegistry(std::shared_ptr<CustomRegistry>
   kernel_registry_manager_.RegisterKernelRegistry(custom_registry->GetKernelRegistry());
 
-  // if (custom_schema_registries_.empty())
-  //   custom_schema_registries_.push_back();
+  custom_schema_registries_.push_back(custom_registry->GetOpschemaRegistry());
 
   return Status::OK();
 }
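The deduplication guard above matters because one options object can back several sessions, and each session construction re-runs CreateCustomRegistry. A short sketch of that reuse from Python (paths are again placeholders):

```python
import onnxruntime as onnxrt

so = onnxrt.SessionOptions()
so.register_custom_ops_library('./libcustom_op_library.so')

# Both sessions share the same options object; without the map check above,
# the second construction would try to add the same custom op domain twice.
sess_a = onnxrt.InferenceSession('custom_op_test.onnx', so)
sess_b = onnxrt.InferenceSession('custom_op_test.onnx', so)
```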
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 9d3838cbb5996..3a3baa9000eff 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -16,11 +16,17 @@
 #include "core/framework/data_types_internal.h"
 #include "core/framework/kernel_registry.h"
 #include "core/framework/random_seed.h"
-#include "core/framework/session_options.h"
 #include "core/framework/tensorprotoutils.h"
 #include "core/framework/TensorSeq.h"
 #include "core/graph/graph_viewer.h"
 #include "core/session/IOBinding.h"
+#include "core/session/abi_session_options_impl.h"
+#include "core/platform/env.h"
+
+struct OrtStatus {
+  OrtErrorCode code;
+  char msg[1];  // a null-terminated string
+};
 
 #if USE_CUDA
 #define BACKEND_PROC "GPU"
@@ -189,6 +195,61 @@ namespace py = pybind11;
 using namespace onnxruntime;
 using namespace onnxruntime::logging;
 
+// Custom op section starts
+static Env& platform_env = Env::Default();
+
+CustomOpLibrary::CustomOpLibrary(const char* library_path, OrtSessionOptions& ort_so) {
+  {
+    OrtPybindThrowIfError(platform_env.LoadDynamicLibrary(library_path, &library_handle_));
+
+    if (!library_handle_)
+      throw std::runtime_error("RegisterCustomOpsLibrary: Failed to load library");
+
+    OrtStatus*(ORT_API_CALL * RegisterCustomOps)(OrtSessionOptions * options, const OrtApiBase* api);
+
+    OrtPybindThrowIfError(platform_env.GetSymbolFromLibrary(library_handle_, "RegisterCustomOps", (void**)&RegisterCustomOps));
+
+    if (!RegisterCustomOps)
+      throw std::runtime_error("RegisterCustomOpsLibrary: Entry point RegisterCustomOps not found in library");
+
+    auto* status_raw = RegisterCustomOps(&ort_so, OrtGetApiBase());
+    // Manage the raw Status pointer using a smart pointer
+    auto status = std::unique_ptr<OrtStatus>(status_raw);
+
+    // A non-null status indicates an error
+    if (status) {
+      // TODO: How to handle unload failure?
+      // Currently we ignore the returned status, assuming it is successful
+      platform_env.UnloadDynamicLibrary(library_handle_);
+
+      // Construct the error message string
+      std::string error_string = status->msg;
+
+      // Throw
+      throw std::runtime_error(error_string);
+    }
+
+    library_path_ = std::string(library_path);
+  }
+}
+
+// Unload the library when the destructor is triggered
+CustomOpLibrary::~CustomOpLibrary() {
+  UnloadLibrary();
+}
+
+// Logic to unload the library
+void CustomOpLibrary::UnloadLibrary() {
+  auto status = platform_env.UnloadDynamicLibrary(library_handle_);
+
+  if (!status.IsOK()) {
+    const logging::Logger& default_logger = logging::LoggingManager::DefaultLogger();
+    LOGS(default_logger, WARNING) << "Unable to unload the custom op shared library: " << library_path_;
+  }
+}
+
+// Custom op section ends
+
 template <typename T>
 void AddNonTensor(const OrtValue& val, std::vector<py::object>& pyobjs, const DataTransferManager* /*data_transfer_manager*/) {
   pyobjs.push_back(py::cast(val.Get<T>()));
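Note the failure path in the constructor: on a bad path or a missing RegisterCustomOps symbol, the library is unloaded and a std::runtime_error is thrown, which pybind11 surfaces as a Python exception. A hedged sketch of what that looks like from the user's side (the path is deliberately bogus):

```python
import onnxruntime as onnxrt

so = onnxrt.SessionOptions()
try:
    # Fails either at load time or because the entry point is absent.
    so.register_custom_ops_library('./no_such_library.so')
except Exception as exc:
    print('registration failed:', exc)
```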
Default is false.)pbdoc") - .def_readwrite("optimized_model_filepath", &SessionOptions::optimized_model_filepath, + .def_readwrite("optimized_model_filepath", &PySessionOptions::optimized_model_filepath, R"pbdoc(File path to serialize optimized model. By default, optimized model is not serialized if optimized_model_filepath is not provided.)pbdoc") - .def_readwrite("enable_mem_pattern", &SessionOptions::enable_mem_pattern, + .def_readwrite("enable_mem_pattern", &PySessionOptions::enable_mem_pattern, R"pbdoc(Enable the memory pattern optimization. Default is true.)pbdoc") - .def_readwrite("logid", &SessionOptions::session_logid, + .def_readwrite("logid", &PySessionOptions::session_logid, R"pbdoc(Logger id to use for session output.)pbdoc") - .def_readwrite("log_severity_level", &SessionOptions::session_log_severity_level, + .def_readwrite("log_severity_level", &PySessionOptions::session_log_severity_level, R"pbdoc(Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.)pbdoc") - .def_readwrite("log_verbosity_level", &SessionOptions::session_log_verbosity_level, + .def_readwrite("log_verbosity_level", &PySessionOptions::session_log_verbosity_level, R"pbdoc(VLOG level if DEBUG build and session_log_verbosity_level is 0. Applies to session load, initialization, etc. Default is 0.)pbdoc") .def_property( - "intra_op_num_threads", [](const SessionOptions* options) -> int { return options->intra_op_param.thread_pool_size; }, [](SessionOptions* options, int value) -> void { options->intra_op_param.thread_pool_size = value; }, R"pbdoc(Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose.)pbdoc") + "intra_op_num_threads", [](const PySessionOptions* options) -> int { return options->intra_op_param.thread_pool_size; }, [](PySessionOptions* options, int value) -> void { options->intra_op_param.thread_pool_size = value; }, R"pbdoc(Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose.)pbdoc") .def_property( - "inter_op_num_threads", [](const SessionOptions* options) -> int { return options->inter_op_param.thread_pool_size; }, [](SessionOptions* options, int value) -> void { options->inter_op_param.thread_pool_size = value; }, R"pbdoc(Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose.)pbdoc") - .def_readwrite("execution_mode", &SessionOptions::execution_mode, + "inter_op_num_threads", [](const PySessionOptions* options) -> int { return options->inter_op_param.thread_pool_size; }, [](PySessionOptions* options, int value) -> void { options->inter_op_param.thread_pool_size = value; }, R"pbdoc(Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose.)pbdoc") + .def_readwrite("execution_mode", &PySessionOptions::execution_mode, R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc") .def_property( "graph_optimization_level", - [](const SessionOptions* options) -> GraphOptimizationLevel { + [](const PySessionOptions* options) -> GraphOptimizationLevel { GraphOptimizationLevel retval = ORT_ENABLE_ALL; switch (options->graph_optimization_level) { case onnxruntime::TransformerLevel::Default: @@ -1005,7 +1084,7 @@ Applies to session load, initialization, etc. 
Default is 0.)pbdoc") return retval; }, - [](SessionOptions* options, GraphOptimizationLevel level) -> void { + [](PySessionOptions* options, GraphOptimizationLevel level) -> void { switch (level) { case ORT_DISABLE_ALL: options->graph_optimization_level = onnxruntime::TransformerLevel::Default; @@ -1022,11 +1101,11 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc") } }, R"pbdoc(Graph optimization level for this session.)pbdoc") - .def_readwrite("use_deterministic_compute", &SessionOptions::use_deterministic_compute, + .def_readwrite("use_deterministic_compute", &PySessionOptions::use_deterministic_compute, R"pbdoc(Whether to use deterministic compute. Default is false.)pbdoc") .def( "add_free_dimension_override_by_denotation", - [](SessionOptions* options, const char* dim_name, int64_t dim_value) + [](PySessionOptions* options, const char* dim_name, int64_t dim_value) -> void { options->free_dimension_overrides.push_back( onnxruntime::FreeDimensionOverride{ dim_name, @@ -1035,7 +1114,7 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc") "Rpbdoc(Specify the dimension size for each denotation associated with an input's free dimension.)pbdoc") .def( "add_free_dimension_override_by_name", - [](SessionOptions* options, const char* dim_name, int64_t dim_value) + [](PySessionOptions* options, const char* dim_name, int64_t dim_value) -> void { options->free_dimension_overrides.push_back( onnxruntime::FreeDimensionOverride{ dim_name, @@ -1044,7 +1123,7 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc") "Rpbdoc(Specify values of named dimensions within model inputs.)pbdoc") .def( "add_session_config_entry", - [](SessionOptions* options, const char* config_key, const char* config_value) -> void { + [](PySessionOptions* options, const char* config_key, const char* config_value) -> void { const Status status = AddSessionConfigEntryImpl(*options, config_key, config_value); if (!status.IsOK()) throw std::runtime_error(status.ErrorMessage()); @@ -1052,14 +1131,32 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc") "Rpbdoc(Set a single session configuration entry as a pair of strings.)pbdoc") .def( "get_session_config_entry", - [](SessionOptions* options, const char* config_key) -> std::string { + [](PySessionOptions* options, const char* config_key) -> std::string { const std::string key(config_key); if (!HasSessionConfigEntry(*options, key)) throw std::runtime_error("SessionOptions does not have configuration with key: " + key); return options->session_configurations.at(key); }, - "Rpbdoc(Get a single session configuration value using the given configuration key.)pbdoc"); + "Rpbdoc(Get a single session configuration value using the given configuration key.)pbdoc") + .def( + "register_custom_ops_library", + [](PySessionOptions* options, const std::string& library_path) + -> void { + // We need to pass in an `OrtSessionOptions` instance because the exported method in the shared library expects that + // Once we have access to the `OrtCustomOpDomains` within the passed in `OrtSessionOptions` instance, we place it + // into the container we are maintaining for that very purpose and the `ortSessionoptions` instance can go out of scope. 
@@ -1052,14 +1131,32 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc")
          R"pbdoc(Set a single session configuration entry as a pair of strings.)pbdoc")
       .def(
           "get_session_config_entry",
-          [](SessionOptions* options, const char* config_key) -> std::string {
+          [](PySessionOptions* options, const char* config_key) -> std::string {
            const std::string key(config_key);
            if (!HasSessionConfigEntry(*options, key))
              throw std::runtime_error("SessionOptions does not have configuration with key: " + key);
 
            return options->session_configurations.at(key);
          },
-          R"pbdoc(Get a single session configuration value using the given configuration key.)pbdoc");
+          R"pbdoc(Get a single session configuration value using the given configuration key.)pbdoc")
+      .def(
+          "register_custom_ops_library",
+          [](PySessionOptions* options, const std::string& library_path)
+              -> void {
+            // We need to pass in an `OrtSessionOptions` instance because the exported method in the shared library expects that.
+            // Once we have access to the `OrtCustomOpDomain`s within the passed-in `OrtSessionOptions` instance, we place them
+            // into the container we are maintaining for that very purpose, and the `OrtSessionOptions` instance can go out of scope.
+            OrtSessionOptions s;
+
+            options->custom_op_libraries_.emplace_back(std::make_shared<CustomOpLibrary>(library_path.c_str(), s));
+
+            // Reserve enough memory to hold the current contents and the new incoming contents
+            options->custom_op_domains_.reserve(options->custom_op_domains_.size() + s.custom_op_domains_.size());
+            for (size_t i = 0; i < s.custom_op_domains_.size(); ++i) {
+              options->custom_op_domains_.emplace_back(s.custom_op_domains_[i]);
+            }
+          },
+          R"pbdoc(Specify the path to the shared library containing the custom op kernels required to run a model.)pbdoc");
 
   py::class_<RunOptions>(m, "RunOptions", R"pbdoc(Configuration information for a single Run.)pbdoc")
       .def(py::init())
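The config-entry pair bound above round-trips plain strings; the key below is only an illustration (valid keys are defined by the runtime's session config conventions):

```python
import onnxruntime as onnxrt

so = onnxrt.SessionOptions()
# Keys and values are plain strings; reading back an unset key raises.
so.add_session_config_entry('session.load_model_format', 'ONNX')
print(so.get_session_config_entry('session.load_model_format'))
```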
@@ -1151,37 +1248,32 @@ including arg name, arg type (contains both type and shape).)pbdoc")
            "node shape (assuming the node holds a tensor)");
 
   py::class_<SessionObjectInitializer>(m, "SessionObjectInitializer");
-  py::class_<InferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
+  py::class_<PyInferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
       // In Python3, a Python bytes object will be passed to C++ functions that accept std::string or char*
       // without any conversion. So this init method can be used for model file path (string)
       // and model content (bytes)
-      .def(py::init([&env](const SessionOptions& so, const std::string& arg, bool is_arg_file_name) {
-        // Given arg is the file path. Invoke the corresponding ctor().
-        if (is_arg_file_name) {
-          return onnxruntime::make_unique<InferenceSession>(so, env, arg);
-        }
-
-        // Given arg is the model content as bytes. Invoke the corresponding ctor().
-        std::istringstream buffer(arg);
-        return onnxruntime::make_unique<InferenceSession>(so, env, buffer);
+      .def(py::init([&env](const PySessionOptions& so, const std::string& arg, bool is_arg_file_name) {
+        auto sess = onnxruntime::make_unique<PyInferenceSession>(env, so, arg, is_arg_file_name);
+        RegisterCustomOpDomainsAndLibraries(sess.get(), so);
+        return sess;
       }))
       .def(
-          "load_model", [](InferenceSession* sess, std::vector<std::string>& provider_types) {
-            OrtPybindThrowIfError(sess->Load());
-            InitializeSession(sess, provider_types);
+          "load_model", [](PyInferenceSession* sess, std::vector<std::string>& provider_types) {
+            OrtPybindThrowIfError(sess->GetSessionHandle()->Load());
+            InitializeSession(sess->GetSessionHandle(), provider_types);
          },
          R"pbdoc(Load a model saved in ONNX format.)pbdoc")
       .def(
-          "load_model", [](InferenceSession* sess, std::vector<std::string>& provider_types, ProviderOptionsVector& provider_options) {
-            OrtPybindThrowIfError(sess->Load());
-            InitializeSession(sess, provider_types, provider_options);
+          "load_model", [](PyInferenceSession* sess, std::vector<std::string>& provider_types, ProviderOptionsVector& provider_options) {
+            OrtPybindThrowIfError(sess->GetSessionHandle()->Load());
+            InitializeSession(sess->GetSessionHandle(), provider_types, provider_options);
          },
          R"pbdoc(Load a model saved in ONNX format.)pbdoc")
-      .def("run", [](InferenceSession* sess, std::vector<std::string> output_names, std::map<std::string, py::object> pyfeeds, RunOptions* run_options = nullptr) -> std::vector<py::object> {
+      .def("run", [](PyInferenceSession* sess, std::vector<std::string> output_names, std::map<std::string, py::object> pyfeeds, RunOptions* run_options = nullptr) -> std::vector<py::object> {
         NameMLValMap feeds;
         for (auto _ : pyfeeds) {
           OrtValue ml_value;
-          auto px = sess->GetModelInputs();
+          auto px = sess->GetSessionHandle()->GetModelInputs();
          if (!px.first.IsOK() || !px.second) {
            throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
          }
@@ -1209,9 +1301,9 @@ including arg name, arg type (contains both type and shape).)pbdoc")
         // release GIL to allow multiple python threads to invoke Run() in parallel.
         py::gil_scoped_release release;
         if (run_options != nullptr) {
-          OrtPybindThrowIfError(sess->Run(*run_options, feeds, output_names, &fetches));
+          OrtPybindThrowIfError(sess->GetSessionHandle()->Run(*run_options, feeds, output_names, &fetches));
         } else {
-          OrtPybindThrowIfError(sess->Run(feeds, output_names, &fetches));
+          OrtPybindThrowIfError(sess->GetSessionHandle()->Run(feeds, output_names, &fetches));
         }
       }
 
@@ -1226,44 +1318,45 @@ including arg name, arg type (contains both type and shape).)pbdoc")
        }
        return rfetch;
      })
-      .def("end_profiling", [](InferenceSession* sess) -> std::string {
-        return sess->EndProfiling();
+      .def("end_profiling", [](PyInferenceSession* sess) -> std::string {
+        return sess->GetSessionHandle()->EndProfiling();
       })
-      .def("get_providers", [](InferenceSession* sess) -> const std::vector<std::string>& {
-        return sess->GetRegisteredProviderTypes();
+      .def("get_providers", [](PyInferenceSession* sess) -> const std::vector<std::string>& {
+        return sess->GetSessionHandle()->GetRegisteredProviderTypes();
       })
-      .def("get_provider_options", [](const InferenceSession* sess) -> const ProviderOptionsMap& {
-        return sess->GetAllProviderOptions();
+      .def("get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& {
+        return sess->GetSessionHandle()->GetAllProviderOptions();
       })
-      .def_property_readonly("session_options", [](InferenceSession* sess) -> const SessionOptions& {
-        return sess->GetSessionOptions();
+      .def_property_readonly("session_options", [](PyInferenceSession* sess) -> const PySessionOptions& {
+        const auto& session_options = sess->GetSessionHandle()->GetSessionOptions();
+        return static_cast<const PySessionOptions&>(session_options);
       })
-      .def_property_readonly("inputs_meta", [](const InferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
-        auto res = sess->GetModelInputs();
+      .def_property_readonly("inputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
+        auto res = sess->GetSessionHandle()->GetModelInputs();
        OrtPybindThrowIfError(res.first);
        return *(res.second);
      })
-      .def_property_readonly("outputs_meta", [](const InferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
-        auto res = sess->GetModelOutputs();
+      .def_property_readonly("outputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
+        auto res = sess->GetSessionHandle()->GetModelOutputs();
        OrtPybindThrowIfError(res.first);
        return *(res.second);
      })
-      .def_property_readonly("overridable_initializers", [](const InferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
-        auto res = sess->GetOverridableInitializers();
+      .def_property_readonly("overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
+        auto res = sess->GetSessionHandle()->GetOverridableInitializers();
        OrtPybindThrowIfError(res.first);
        return *(res.second);
      })
-      .def_property_readonly("model_meta", [](const InferenceSession* sess) -> const onnxruntime::ModelMetadata& {
-        auto res = sess->GetModelMetadata();
+      .def_property_readonly("model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& {
+        auto res = sess->GetSessionHandle()->GetModelMetadata();
        OrtPybindThrowIfError(res.first);
        return *(res.second);
      })
-      .def("run_with_iobinding", [](InferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {
+      .def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {
        Status status;
        if (!run_options)
-          status = sess->Run(*io_binding.Get());
+          status = sess->GetSessionHandle()->Run(*io_binding.Get());
        else
-          status = sess->Run(*run_options, *io_binding.Get());
+          status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get());
        if (!status.IsOK())
          throw std::runtime_error("Error in execution: " + status.ErrorMessage());
      });
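run_with_iobinding is reached through the IOBinding helper in the public Python API; a hedged sketch of typical usage (names follow the public wrapper, 'X'/'Y' are placeholder tensor names for a hypothetical model):

```python
import numpy as np
import onnxruntime as onnxrt

sess = onnxrt.InferenceSession('model.onnx')
io_binding = sess.io_binding()
io_binding.bind_cpu_input('X', np.ones((3, 5), dtype=np.float32))
io_binding.bind_output('Y')
sess.run_with_iobinding(io_binding)
print(io_binding.copy_outputs_to_cpu()[0])
```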
@@ -1340,7 +1433,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
         import_array1();
       })();
 
-  Environment& env = get_env();
+  Environment& env = GetEnv();
   addGlobalMethods(m, env);
   addObjectMethods(m, env);
 
@@ -1358,7 +1451,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
 // static variable used to create inference session and training session.
 static std::unique_ptr<Environment> session_env;
 
-void initialize_env() {
+void InitializeEnv() {
   auto initialize = [&]() {
     // Initialization of the module
     ([]() -> void {
@@ -1381,9 +1474,9 @@ void initialize_env() {
   initialize();
 }
 
-onnxruntime::Environment& get_env() {
+onnxruntime::Environment& GetEnv() {
   if (!session_env) {
-    initialize_env();
+    InitializeEnv();
   }
   return *session_env;
 }
diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h
index b1b7637f2bc7b..dbc96096cb8f6 100644
--- a/onnxruntime/python/onnxruntime_pybind_state_common.h
+++ b/onnxruntime/python/onnxruntime_pybind_state_common.h
@@ -5,19 +5,82 @@
 #include "core/common/logging/sinks/cerr_sink.h"
 #include "core/framework/allocator.h"
 #include "core/framework/session_options.h"
-
 #include "core/session/environment.h"
+#include "core/session/inference_session.h"
 
 namespace onnxruntime {
-class InferenceSession;
-
 namespace python {
 
 using namespace onnxruntime;
 using namespace onnxruntime::logging;
 
-inline const SessionOptions& GetDefaultCPUSessionOptions() {
-  static SessionOptions so;
+struct CustomOpLibrary {
+  CustomOpLibrary(const char* library_path, OrtSessionOptions& ort_so);
+
+  ~CustomOpLibrary();
+
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomOpLibrary);
+
+ private:
+  void UnloadLibrary();
+
+  std::string library_path_;
+  void* library_handle_ = nullptr;
+};
+
+// Thin wrapper over internal C++ SessionOptions to accommodate custom op library management for the Python user
+struct PySessionOptions : public SessionOptions {
+  // `PySessionOptions` has a vector of shared_ptrs to CustomOpLibrary so that it can be re-used for all
+  // `PyInferenceSession`s using the same `PySessionOptions`, and so that each `PyInferenceSession` need not construct
+  // duplicate CustomOpLibrary instances.
+  std::vector<std::shared_ptr<CustomOpLibrary>> custom_op_libraries_;
+
+  // Hold raw `OrtCustomOpDomain` pointers - it is up to the shared library to release the OrtCustomOpDomains
+  // that were created when the library is unloaded
+  std::vector<OrtCustomOpDomain*> custom_op_domains_;
+};
+
+// Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user
+struct PyInferenceSession {
+  // Default ctor is present only to be invoked by the PyTrainingSession class
+  PyInferenceSession() {}
+
+  PyInferenceSession(Environment& env, const PySessionOptions& so, const std::string& arg, bool is_arg_file_name) {
+    if (is_arg_file_name) {
+      // Given arg is the file path. Invoke the corresponding ctor().
+      sess_ = onnxruntime::make_unique<InferenceSession>(so, env, arg);
+    } else {
+      // Given arg is the model content as bytes. Invoke the corresponding ctor().
+      std::istringstream buffer(arg);
+      sess_ = onnxruntime::make_unique<InferenceSession>(so, env, buffer);
+    }
+  }
+
+  void AddCustomOpLibraries(const std::vector<std::shared_ptr<CustomOpLibrary>>& custom_op_libraries) {
+    if (!custom_op_libraries.empty()) {
+      custom_op_libraries_.reserve(custom_op_libraries_.size() + custom_op_libraries.size());
+      for (size_t i = 0; i < custom_op_libraries.size(); ++i) {
+        custom_op_libraries_.push_back(custom_op_libraries[i]);
+      }
+    }
+  }
+
+  InferenceSession* GetSessionHandle() const { return sess_.get(); }
+
+  virtual ~PyInferenceSession() {}
+
+ protected:
+  // Hold CustomOpLibrary resources so as to tie them to the life cycle of the InferenceSession needing them.
+  // NOTE: Declare this above `sess_` so that this is destructed AFTER the InferenceSession instance -
+  // this is so that the custom ops held by the InferenceSession get destroyed prior to the library getting unloaded
+  // (if the ref count of the shared_ptr reaches 0)
+  std::vector<std::shared_ptr<CustomOpLibrary>> custom_op_libraries_;
+
+  std::unique_ptr<InferenceSession> sess_;
+};
+
+inline const PySessionOptions& GetDefaultCPUSessionOptions() {
+  static PySessionOptions so;
   return so;
 }
 
@@ -28,7 +91,7 @@ inline AllocatorPtr& GetAllocator() {
 class SessionObjectInitializer {
  public:
-  typedef const SessionOptions& Arg1;
+  typedef const PySessionOptions& Arg1;
   // typedef logging::LoggingManager* Arg2;
   static const std::string default_logger_id;
   operator Arg1() {
@@ -47,10 +110,9 @@ class SessionObjectInitializer {
   }
 };
 
-Environment& get_env();
+Environment& GetEnv();
 void InitializeSession(InferenceSession* sess, const std::vector<std::string>& provider_types);
 
-}
-}
-
+}  // namespace python
+}  // namespace onnxruntime
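The member-declaration order noted above has a visible effect in Python: each session holds its own shared_ptr to the library, so dropping the options object does not unload it while the session lives. A small illustration under that design (paths are placeholders):

```python
import onnxruntime as onnxrt

so = onnxrt.SessionOptions()
so.register_custom_ops_library('./libcustom_op_library.so')
sess = onnxrt.InferenceSession('custom_op_test.onnx', so)

# The session co-owns the CustomOpLibrary, so deleting the options
# object does not unload the shared library prematurely.
del so
# ... sess.run(...) still works here.
```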
diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py
index 375115e5ee311..c80b2df0fbcff 100644
--- a/onnxruntime/test/python/onnxruntime_test_python.py
+++ b/onnxruntime/test/python/onnxruntime_test_python.py
@@ -7,9 +7,9 @@
 import numpy as np
 import onnxruntime as onnxrt
 import threading
+import sys
 from helper import get_name
 
-
 class TestInferenceSession(unittest.TestCase):
 
     def run_model(self, session_object, run_options):
@@ -650,5 +650,54 @@ def testInvalidSessionOptionsConfigEntry(self):
         self.assertTrue(
             'SessionOptions does not have configuration with key: ' + invalide_key in str(context.exception))
 
+    def testRegisterCustomOpsLibrary(self):
+        if sys.platform.startswith("win"):
+            shared_library = 'custom_op_library.dll'
+            if not os.path.exists(shared_library):
+                raise FileNotFoundError("Unable to find '{0}'".format(shared_library))
+
+        elif sys.platform.startswith("darwin"):
+            shared_library = 'libcustom_op_library.dylib'
+            if not os.path.exists(shared_library):
+                raise FileNotFoundError("Unable to find '{0}'".format(shared_library))
+
+        else:
+            shared_library = './libcustom_op_library.so'
+            if not os.path.exists(shared_library):
+                raise FileNotFoundError("Unable to find '{0}'".format(shared_library))
+
+        this = os.path.dirname(__file__)
+        custom_op_model = os.path.join(this, "testdata", "custom_op_library", "custom_op_test.onnx")
+        if not os.path.exists(custom_op_model):
+            raise FileNotFoundError("Unable to find '{0}'".format(custom_op_model))
+
+        so1 = onnxrt.SessionOptions()
+        so1.register_custom_ops_library(shared_library)
+
+        # Model loading successfully indicates that the custom op node could be resolved successfully
+        sess1 = onnxrt.InferenceSession(custom_op_model, so1)
+        # Run with input data
+        input_name_0 = sess1.get_inputs()[0].name
+        input_name_1 = sess1.get_inputs()[1].name
+        output_name = sess1.get_outputs()[0].name
+        input_0 = np.ones((3, 5)).astype(np.float32)
+        input_1 = np.zeros((3, 5)).astype(np.float32)
+        res = sess1.run([output_name], {input_name_0: input_0, input_name_1: input_1})
+        output_expected = np.ones((3, 5)).astype(np.float32)
+        np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08)
+
+        # Create an alias of the SessionOptions instance
+        # We will use this alias to construct another InferenceSession
+        so2 = so1
+
+        # Model loading successfully indicates that the custom op node could be resolved successfully
+        sess2 = onnxrt.InferenceSession(custom_op_model, so2)
+
+        # Create another SessionOptions instance with the same shared library referenced
+        so3 = onnxrt.SessionOptions()
+        so3.register_custom_ops_library(shared_library)
+        sess3 = onnxrt.InferenceSession(custom_op_model, so3)
+
 
 if __name__ == '__main__':
     unittest.main()
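To exercise just this test, the shared library and model must sit in the paths probed above; one way to run it in isolation, assuming the working directory contains the test file and its helper module:

```python
import unittest
from onnxruntime_test_python import TestInferenceSession

# Build a suite containing only the new custom ops registration test.
suite = unittest.TestSuite()
suite.addTest(TestInferenceSession('testRegisterCustomOpsLibrary'))
unittest.TextTestRunner(verbosity=2).run(suite)
```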
diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc
index 92cedaaead950..e110e6b950821 100644
--- a/orttraining/orttraining/python/orttraining_pybind_state.cc
+++ b/orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -7,7 +7,6 @@
 // pybind11/stl.h is needed to support std::unordered_set, etc.
 #include <pybind11/stl.h>
 
-#include "core/framework/session_options.h"
 #include "core/session/environment.h"
 #include "orttraining/core/session/training_session.h"
 #include "orttraining/core/graph/optimizer_config.h"
@@ -67,7 +66,7 @@ TrainingConfigurationResult ConfigureSessionForTraining(
   auto data_group_size = parameters.world_size / parameters.horizontal_parallel_size;
   if (data_group_size != parameters.data_parallel_size) {
     LOGS(*(sess->GetLogger()), WARNING) << "data_parallel_size is not correct, tuned automatically to "
-                                         << data_group_size;
+                                        << data_group_size;
     parameters.data_parallel_size = data_group_size;
   }
@@ -146,23 +145,23 @@
 #if defined(USE_NCCL)
 void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const logging::Logger* logger) {
-   LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank();
-   LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalRank(): " << MPIContext::GetInstance().GetLocalRank();
-   LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldSize(): " << MPIContext::GetInstance().GetWorldSize();
-   LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalSize(): " << MPIContext::GetInstance().GetLocalSize();
-
-   parameters.local_rank = MPIContext::GetInstance().GetLocalRank();
-   parameters.local_size = MPIContext::GetInstance().GetLocalSize();
-   if (parameters.world_rank != MPIContext::GetInstance().GetWorldRank()) {
-     if (parameters.world_rank != 0)
-       LOGS(*logger, WARNING) << "TrainingParameters world_rank is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldRank();
-     parameters.world_rank = MPIContext::GetInstance().GetWorldRank();
-   }
-   if (parameters.world_size != MPIContext::GetInstance().GetWorldSize()) {
-     if (parameters.world_size != 1)
-       LOGS(*logger, WARNING) << "TrainingParameters world_size is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldSize();
-     parameters.world_size = MPIContext::GetInstance().GetWorldSize();
-   }
+  LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank();
+  LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalRank(): " << MPIContext::GetInstance().GetLocalRank();
+  LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldSize(): " << MPIContext::GetInstance().GetWorldSize();
+  LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalSize(): " << MPIContext::GetInstance().GetLocalSize();
+
+  parameters.local_rank = MPIContext::GetInstance().GetLocalRank();
+  parameters.local_size = MPIContext::GetInstance().GetLocalSize();
+  if (parameters.world_rank != MPIContext::GetInstance().GetWorldRank()) {
+    if (parameters.world_rank != 0)
+      LOGS(*logger, WARNING) << "TrainingParameters world_rank is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldRank();
+    parameters.world_rank = MPIContext::GetInstance().GetWorldRank();
+  }
+  if (parameters.world_size != MPIContext::GetInstance().GetWorldSize()) {
+    if (parameters.world_size != 1)
+      LOGS(*logger, WARNING) << "TrainingParameters world_size is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldSize();
+    parameters.world_size = MPIContext::GetInstance().GetWorldSize();
+  }
 }
 #endif
@@ -205,14 +204,23 @@ void addObjectMethodsForTraining(py::module& m) {
         return py::none();
       });
 
-  py::class_<onnxruntime::training::TrainingSession> training_session(m, "TrainingSession");
-  training_session.def(py::init([](const SessionOptions& so) {
-    Environment& env = get_env();
-    return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(so, env);
-  }))
+  // Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user
+  struct PyTrainingSession : public PyInferenceSession {
+    PyTrainingSession(Environment& env, const PySessionOptions& so) {
+      // `sess_` is inherited from PyInferenceSession
+      sess_ = onnxruntime::make_unique<onnxruntime::training::TrainingSession>(so, env);
+    }
+  };
+
+  py::class_<PyTrainingSession> training_session(m, "TrainingSession");
+  training_session
+      .def(py::init([](const PySessionOptions& so) {
+        Environment& env = GetEnv();
+        return onnxruntime::make_unique<PyTrainingSession>(env, so);
+      }))
       .def(py::init([]() {
-        Environment& env = get_env();
-        return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(GetDefaultCPUSessionOptions(), env);
+        Environment& env = GetEnv();
+        return onnxruntime::make_unique<PyTrainingSession>(env, GetDefaultCPUSessionOptions());
       }))
       .def("finalize", [](py::object) {
 #if defined(USE_NCCL)
@@ -224,37 +232,37 @@ void addObjectMethodsForTraining(py::module& m) {
 #endif
 #endif
       })
-      .def("load_model", [](onnxruntime::training::TrainingSession* sess, const std::string& path, TrainingParameters& parameters) {
-        OrtPybindThrowIfError(sess->Load(path));
+      .def("load_model", [](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters) {
+        OrtPybindThrowIfError(sess->GetSessionHandle()->Load(path));
 
 #if defined(USE_NCCL)
-        CopyMPIContextToTrainingParameters(parameters, sess->GetLogger());
+        CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
 #endif
-        const auto config_result = ConfigureSessionForTraining(sess, parameters);
+        const auto config_result = ConfigureSessionForTraining(static_cast<TrainingSession*>(sess->GetSessionHandle()), parameters);
 
         std::vector<std::string> provider_types = {};
-        InitializeSession(sess, provider_types);
+        InitializeSession(sess->GetSessionHandle(), provider_types);
 
         return config_result;
       })
-      .def("read_bytes", [](onnxruntime::training::TrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters) {
+      .def("read_bytes", [](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters) {
        std::istringstream buffer(serialized_model);
-        OrtPybindThrowIfError(sess->Load(buffer));
+        OrtPybindThrowIfError(sess->GetSessionHandle()->Load(buffer));
 
 #if defined(USE_NCCL)
-        CopyMPIContextToTrainingParameters(parameters, sess->GetLogger());
+        CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
 #endif
-        const auto config_result = ConfigureSessionForTraining(sess, parameters);
+        const auto config_result = ConfigureSessionForTraining(static_cast<TrainingSession*>(sess->GetSessionHandle()), parameters);
 
         std::vector<std::string> provider_types = {};
-        InitializeSession(sess, provider_types);
+        InitializeSession(sess->GetSessionHandle(), provider_types);
 
         return config_result;
       })
-      .def("get_state", [](onnxruntime::training::TrainingSession* sess) {
+      .def("get_state", [](PyTrainingSession* sess) {
        NameMLValMap state_tensors;
-        ORT_THROW_IF_ERROR(sess->GetStateTensors(state_tensors));
-        auto& data_transfer_manager = sess->GetDataTransferManager();
+        ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetStateTensors(state_tensors));
+        auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
        // convert to numpy array
        std::map<std::string, py::object> rmap;
        for (auto& kv : state_tensors) {
@@ -269,11 +277,11 @@ void addObjectMethodsForTraining(py::module& m) {
        }
        return rmap;
      })
-      .def("load_state", [](onnxruntime::training::TrainingSession* sess, std::unordered_map<std::string, py::object>& state, bool strict) {
+      .def("load_state", [](PyTrainingSession* sess, std::unordered_map<std::string, py::object>& state, bool strict) {
        NameMLValMap state_tensors;
        for (auto initializer : state) {
          OrtValue ml_value;
-          auto px = sess->GetModelInputs();
+          auto px = sess->GetSessionHandle()->GetModelInputs();
          if (!px.first.IsOK() || !px.second) {
            throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
          }
@@ -293,10 +301,10 @@ void addObjectMethodsForTraining(py::module& m) {
          }
          state_tensors.insert(std::make_pair(initializer.first, ml_value));
        }
-        ORT_THROW_IF_ERROR(sess->SetStateTensors(state_tensors, strict));
+        ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->SetStateTensors(state_tensors, strict));
      })
-      .def("is_output_fp32_node", [](onnxruntime::training::TrainingSession* sess, const std::string& output_name) {
-        return sess->IsGraphOutputFp32Node(output_name);
+      .def("is_output_fp32_node", [](PyTrainingSession* sess, const std::string& output_name) {
+        return static_cast<TrainingSession*>(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name);
      });
 }
 }  // namespace python