Support RegisterCustomOpsLibrary via the Python API (microsoft#4764)
hariharans29 authored Aug 28, 2020
1 parent 040c5fa commit 7045910
Showing 7 changed files with 349 additions and 125 deletions.
7 changes: 5 additions & 2 deletions docs/AddingCustomOp.md
@@ -2,12 +2,15 @@ Adding a new op
===============

## A new op can be written and registered with ONNXRuntime in the following 3 ways
### 1. Using the experimental custom op API in the C API (onnxruntime_c_api.h)
Note: These APIs are experimental and will change in the next release. They're released now for feedback and experimentation.
### 1. Using the custom op API in the C/C++ APIs (onnxruntime_c_api.h)
* Create an OrtCustomOpDomain with the domain name used by the custom ops
* Create an OrtCustomOp structure for each op and add them to the OrtCustomOpDomain with OrtCustomOpDomain_Add
* Call OrtAddCustomOpDomain to add the custom domain of ops to the session options
See [this](../onnxruntime/test/shared_lib/test_inference.cc) for an example called MyCustomOp that uses the C++ helper API (onnxruntime_cxx_api.h).
You can also compile the custom ops into a shared library and use that to run a model via the C++ API. The same test file contains an example.
The source code for a sample custom op shared library containing two custom kernels is [here](../onnxruntime/test/testdata/custom_op_library/custom_op_library.cc).
See [this](../onnxruntime/test/python/onnxruntime_test_python.py) for an example called testRegisterCustomOpsLibrary that uses the Python API
to register a shared library that contains custom op kernels.
Currently, the only supported Execution Providers (EPs) for custom ops registered via this approach are the `CUDA` and the `CPU` EPs.
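
To make the Python flow concrete, here is a minimal sketch (the library and model paths are illustrative; see the test referenced above for the full version):

```python
import onnxruntime as onnxrt

so = onnxrt.SessionOptions()
# Load the shared library and register the custom op domains it creates
so.register_custom_ops_library('./libcustom_op_library.so')

# Successful session creation implies the model's custom op nodes were
# resolved against the registered domains; the session then runs as usual.
sess = onnxrt.InferenceSession('custom_op_test.onnx', so)
```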

### 2. Using RegisterCustomRegistry API
14 changes: 12 additions & 2 deletions onnxruntime/core/session/custom_ops.cc
@@ -98,8 +98,18 @@ common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_domains,
  output = std::make_shared<CustomRegistry>();

  for (auto& domain : op_domains) {
    if (domain->domain_[0])
      ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(domain->domain_, 1, 1000);
    // Domain is not empty - add it to the DomainToVersion ONNX map
    // If domain is empty, it is assumed to be part of the ONNX domain
    if (domain->domain_[0]) {
      // Add it to the DomainToVersion ONNX map if it doesn't already exist
      // For example, two sessions using the same session_options should not add the same custom op domain to the version map twice
      auto& domain_to_version_range_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance();
      const auto& domain_to_version_map = domain_to_version_range_instance.Map();

      if (domain_to_version_map.find(domain->domain_) == domain_to_version_map.end()) {
        domain_to_version_range_instance.AddDomainToVersion(domain->domain_, 1, 1000);
      }
    }

    std::vector<ONNX_NAMESPACE::OpSchema> schemas_list;

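The existence check above is what makes a single `SessionOptions` object reusable across sessions. A minimal Python sketch of the behavior this enables (paths illustrative, mirroring the test later in this commit):

```python
import onnxruntime as onnxrt

so = onnxrt.SessionOptions()
so.register_custom_ops_library('./libcustom_op_library.so')

# Both sessions are built from the same options object; the custom op
# domain is added to the DomainToVersion map only once, so the second
# construction does not fail with a duplicate-domain registration.
sess_a = onnxrt.InferenceSession('custom_op_test.onnx', so)
sess_b = onnxrt.InferenceSession('custom_op_test.onnx', so)
```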
3 changes: 1 addition & 2 deletions onnxruntime/core/session/inference_session.cc
@@ -408,8 +408,7 @@ common::Status InferenceSession::RegisterCustomRegistry(std::shared_ptr<CustomRegistry>

  // Insert session-level customized kernel registry.
  kernel_registry_manager_.RegisterKernelRegistry(custom_registry->GetKernelRegistry());
  // if (custom_schema_registries_.empty())
  //   custom_schema_registries_.push_back();

  custom_schema_registries_.push_back(custom_registry->GetOpschemaRegistry());
  return Status::OK();
}
221 changes: 157 additions & 64 deletions onnxruntime/python/onnxruntime_pybind_state.cc

Large diffs are not rendered by default.
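
(The unrendered diff is where the new `register_custom_ops_library` binding on `SessionOptions` is added; the binding name is visible in the Python test later in this commit.)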

82 changes: 72 additions & 10 deletions onnxruntime/python/onnxruntime_pybind_state_common.h
@@ -5,19 +5,82 @@
#include "core/common/logging/sinks/cerr_sink.h"
#include "core/framework/allocator.h"
#include "core/framework/session_options.h"

#include "core/session/environment.h"
#include "core/session/inference_session.h"

namespace onnxruntime {
class InferenceSession;

namespace python {

using namespace onnxruntime;
using namespace onnxruntime::logging;

inline const SessionOptions& GetDefaultCPUSessionOptions() {
  static SessionOptions so;
struct CustomOpLibrary {
  CustomOpLibrary(const char* library_path, OrtSessionOptions& ort_so);

  ~CustomOpLibrary();

  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomOpLibrary);

 private:
  void UnloadLibrary();

  std::string library_path_;
  void* library_handle_ = nullptr;
};

// Thin wrapper over internal C++ SessionOptions to accommodate custom op library management for the Python user
struct PySessionOptions : public SessionOptions {
  // `PySessionOptions` holds a vector of shared_ptrs to CustomOpLibrary so that it can be re-used for all
  // `PyInferenceSession`s created with the same `PySessionOptions`, and each `PyInferenceSession` need not construct
  // duplicate CustomOpLibrary instances.
  std::vector<std::shared_ptr<CustomOpLibrary>> custom_op_libraries_;

  // Hold raw `OrtCustomOpDomain` pointers - it is up to the shared library to release the OrtCustomOpDomains
  // that it created when the library is unloaded
  std::vector<OrtCustomOpDomain*> custom_op_domains_;
};

// Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user
struct PyInferenceSession {
  // Default ctor is present only to be invoked by the PyTrainingSession class
  PyInferenceSession() {}

  PyInferenceSession(Environment& env, const PySessionOptions& so, const std::string& arg, bool is_arg_file_name) {
    if (is_arg_file_name) {
      // Given arg is the file path. Invoke the corresponding ctor().
      sess_ = onnxruntime::make_unique<InferenceSession>(so, env, arg);
    } else {
      // Given arg is the model content as bytes. Invoke the corresponding ctor().
      std::istringstream buffer(arg);
      sess_ = onnxruntime::make_unique<InferenceSession>(so, env, buffer);
    }
  }

  void AddCustomOpLibraries(const std::vector<std::shared_ptr<CustomOpLibrary>>& custom_op_libraries) {
    if (!custom_op_libraries.empty()) {
      custom_op_libraries_.reserve(custom_op_libraries_.size() + custom_op_libraries.size());
      for (size_t i = 0; i < custom_op_libraries.size(); ++i) {
        custom_op_libraries_.push_back(custom_op_libraries[i]);
      }
    }
  }

  InferenceSession* GetSessionHandle() const { return sess_.get(); }

  virtual ~PyInferenceSession() {}

 protected:
  // Hold CustomOpLibrary resources so as to tie them to the life cycle of the InferenceSession needing them.
  // NOTE: Declare this above `sess_` so that it is destructed AFTER the InferenceSession instance -
  // this way the custom ops held by the InferenceSession get destroyed prior to the library being unloaded
  // (if the ref count of the shared_ptr reaches 0)
  std::vector<std::shared_ptr<CustomOpLibrary>> custom_op_libraries_;

  std::unique_ptr<InferenceSession> sess_;
};

inline const PySessionOptions& GetDefaultCPUSessionOptions() {
  static PySessionOptions so;
  return so;
}

@@ -28,7 +91,7 @@ inline AllocatorPtr& GetAllocator() {

class SessionObjectInitializer {
 public:
  typedef const SessionOptions& Arg1;
  typedef const PySessionOptions& Arg1;
  // typedef logging::LoggingManager* Arg2;
  static const std::string default_logger_id;
  operator Arg1() {
@@ -47,10 +110,9 @@ class SessionObjectInitializer {
  }
};

Environment& get_env();
Environment& GetEnv();

void InitializeSession(InferenceSession* sess, const std::vector<std::string>& provider_types);

}
}

} // namespace python
} // namespace onnxruntime
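
A user-visible consequence of the ownership scheme above: the loaded shared library's lifetime follows the sessions that use it, not just the options object it was registered on. A hedged Python sketch (paths illustrative, behavior per the comments above):

```python
import onnxruntime as onnxrt

so = onnxrt.SessionOptions()
so.register_custom_ops_library('./libcustom_op_library.so')
sess = onnxrt.InferenceSession('custom_op_test.onnx', so)

# The session shares ownership of the CustomOpLibrary, so dropping the
# options object does not unload the library while `sess` is alive.
del so

# The library is unloaded only once the last session referencing it is
# destroyed (when the shared_ptr ref count reaches 0).
del sess
```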
51 changes: 50 additions & 1 deletion onnxruntime/test/python/onnxruntime_test_python.py
@@ -7,9 +7,9 @@
import numpy as np
import onnxruntime as onnxrt
import threading
import sys
from helper import get_name


class TestInferenceSession(unittest.TestCase):

    def run_model(self, session_object, run_options):
@@ -650,5 +650,54 @@ def testInvalidSessionOptionsConfigEntry(self):
        self.assertTrue(
            'SessionOptions does not have configuration with key: ' + invalide_key in str(context.exception))

    def testRegisterCustomOpsLibrary(self):
        if sys.platform.startswith("win"):
            shared_library = 'custom_op_library.dll'
            if not os.path.exists(shared_library):
                raise FileNotFoundError("Unable to find '{0}'".format(shared_library))

        elif sys.platform.startswith("darwin"):
            shared_library = 'libcustom_op_library.dylib'
            if not os.path.exists(shared_library):
                raise FileNotFoundError("Unable to find '{0}'".format(shared_library))

        else:
            shared_library = './libcustom_op_library.so'
            if not os.path.exists(shared_library):
                raise FileNotFoundError("Unable to find '{0}'".format(shared_library))

        this = os.path.dirname(__file__)
        custom_op_model = os.path.join(this, "testdata", "custom_op_library", "custom_op_test.onnx")
        if not os.path.exists(custom_op_model):
            raise FileNotFoundError("Unable to find '{0}'".format(custom_op_model))

        so1 = onnxrt.SessionOptions()
        so1.register_custom_ops_library(shared_library)

        # Model loading successfully indicates that the custom op node could be resolved successfully
        sess1 = onnxrt.InferenceSession(custom_op_model, so1)
        # Run with input data
        input_name_0 = sess1.get_inputs()[0].name
        input_name_1 = sess1.get_inputs()[1].name
        output_name = sess1.get_outputs()[0].name
        input_0 = np.ones((3,5)).astype(np.float32)
        input_1 = np.zeros((3,5)).astype(np.float32)
        res = sess1.run([output_name], {input_name_0: input_0, input_name_1: input_1})
        output_expected = np.ones((3,5)).astype(np.float32)
        np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08)

        # Create an alias of SessionOptions instance
        # We will use this alias to construct another InferenceSession
        so2 = so1

        # Model loading successfully indicates that the custom op node could be resolved successfully
        sess2 = onnxrt.InferenceSession(custom_op_model, so2)

        # Create another SessionOptions instance with the same shared library referenced
        so3 = onnxrt.SessionOptions()
        so3.register_custom_ops_library(shared_library)
        sess3 = onnxrt.InferenceSession(custom_op_model, so3)

if __name__ == '__main__':
    unittest.main()
96 changes: 52 additions & 44 deletions orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -7,7 +7,6 @@
// pybind11/stl.h is needed to support std::unordered_set, etc.
#include <pybind11/stl.h>

#include "core/framework/session_options.h"
#include "core/session/environment.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/graph/optimizer_config.h"
@@ -67,7 +66,7 @@ TrainingConfigurationResult ConfigureSessionForTraining(
  auto data_group_size = parameters.world_size / parameters.horizontal_parallel_size;
  if (data_group_size != parameters.data_parallel_size) {
    LOGS(*(sess->GetLogger()), WARNING) << "data_parallel_size is not correct, tuned automatically to "
                                        << data_group_size;
    parameters.data_parallel_size = data_group_size;
  }

@@ -146,23 +145,23 @@

#if defined(USE_NCCL)
void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const logging::Logger* logger) {
  LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank();
  LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalRank(): " << MPIContext::GetInstance().GetLocalRank();
  LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldSize(): " << MPIContext::GetInstance().GetWorldSize();
  LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalSize(): " << MPIContext::GetInstance().GetLocalSize();

  parameters.local_rank = MPIContext::GetInstance().GetLocalRank();
  parameters.local_size = MPIContext::GetInstance().GetLocalSize();
  if (parameters.world_rank != MPIContext::GetInstance().GetWorldRank()) {
    if (parameters.world_rank != 0)
      LOGS(*logger, WARNING) << "TrainingParameters world_rank is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldRank();
    parameters.world_rank = MPIContext::GetInstance().GetWorldRank();
  }
  if (parameters.world_size != MPIContext::GetInstance().GetWorldSize()) {
    if (parameters.world_size != 1)
      LOGS(*logger, WARNING) << "TrainingParameters world_size is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldSize();
    parameters.world_size = MPIContext::GetInstance().GetWorldSize();
  }
}
#endif

@@ -205,14 +204,23 @@ void addObjectMethodsForTraining(py::module& m) {
        return py::none();
      });

  py::class_<onnxruntime::training::TrainingSession, InferenceSession> training_session(m, "TrainingSession");
  training_session.def(py::init([](const SessionOptions& so) {
    Environment& env = get_env();
    return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(so, env);
  }))
  // Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user
  struct PyTrainingSession : public PyInferenceSession {
    PyTrainingSession(Environment& env, const PySessionOptions& so) {
      // `sess_` is inherited from PyInferenceSession
      sess_ = onnxruntime::make_unique<onnxruntime::training::TrainingSession>(so, env);
    }
  };

  py::class_<PyTrainingSession, PyInferenceSession> training_session(m, "TrainingSession");
  training_session
      .def(py::init([](const PySessionOptions& so) {
        Environment& env = GetEnv();
        return onnxruntime::make_unique<PyTrainingSession>(env, so);
      }))
      .def(py::init([]() {
        Environment& env = get_env();
        return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(GetDefaultCPUSessionOptions(), env);
        Environment& env = GetEnv();
        return onnxruntime::make_unique<PyTrainingSession>(env, GetDefaultCPUSessionOptions());
      }))
.def("finalize", [](py::object) {
#if defined(USE_NCCL)
Expand All @@ -224,37 +232,37 @@ void addObjectMethodsForTraining(py::module& m) {
#endif
#endif
})
.def("load_model", [](onnxruntime::training::TrainingSession* sess, const std::string& path, TrainingParameters& parameters) {
OrtPybindThrowIfError(sess->Load(path));
.def("load_model", [](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters) {
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(path));

#if defined(USE_NCCL)
CopyMPIContextToTrainingParameters(parameters, sess->GetLogger());
CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
#endif
const auto config_result = ConfigureSessionForTraining(sess, parameters);
const auto config_result = ConfigureSessionForTraining(static_cast<TrainingSession*>(sess->GetSessionHandle()), parameters);

std::vector<std::string> provider_types = {};
InitializeSession(sess, provider_types);
InitializeSession(sess->GetSessionHandle(), provider_types);

return config_result;
})
.def("read_bytes", [](onnxruntime::training::TrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters) {
.def("read_bytes", [](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters) {
std::istringstream buffer(serialized_model);
OrtPybindThrowIfError(sess->Load(buffer));
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(buffer));

#if defined(USE_NCCL)
CopyMPIContextToTrainingParameters(parameters, sess->GetLogger());
CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
#endif
const auto config_result = ConfigureSessionForTraining(sess, parameters);
const auto config_result = ConfigureSessionForTraining(static_cast<TrainingSession*>(sess->GetSessionHandle()), parameters);

std::vector<std::string> provider_types = {};
InitializeSession(sess, provider_types);
InitializeSession(sess->GetSessionHandle(), provider_types);

return config_result;
})
.def("get_state", [](onnxruntime::training::TrainingSession* sess) {
.def("get_state", [](PyTrainingSession* sess) {
NameMLValMap state_tensors;
ORT_THROW_IF_ERROR(sess->GetStateTensors(state_tensors));
auto& data_transfer_manager = sess->GetDataTransferManager();
ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetStateTensors(state_tensors));
auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
//convert to numpy array
std::map<std::string, py::object> rmap;
for (auto& kv : state_tensors) {
Expand All @@ -269,11 +277,11 @@ void addObjectMethodsForTraining(py::module& m) {
}
return rmap;
})
.def("load_state", [](onnxruntime::training::TrainingSession* sess, std::unordered_map<std::string, py::object>& state, bool strict) {
.def("load_state", [](PyTrainingSession* sess, std::unordered_map<std::string, py::object>& state, bool strict) {
NameMLValMap state_tensors;
for (auto initializer : state) {
OrtValue ml_value;
auto px = sess->GetModelInputs();
auto px = sess->GetSessionHandle()->GetModelInputs();
if (!px.first.IsOK() || !px.second) {
throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
}
Expand All @@ -293,10 +301,10 @@ void addObjectMethodsForTraining(py::module& m) {
}
state_tensors.insert(std::make_pair(initializer.first, ml_value));
}
ORT_THROW_IF_ERROR(sess->SetStateTensors(state_tensors, strict));
ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->SetStateTensors(state_tensors, strict));
})
.def("is_output_fp32_node", [](onnxruntime::training::TrainingSession* sess, const std::string& output_name) {
return sess->IsGraphOutputFp32Node(output_name);
.def("is_output_fp32_node", [](PyTrainingSession* sess, const std::string& output_name) {
return static_cast<TrainingSession*>(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name);
});
}
} // namespace python
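For orientation, a hedged sketch of how the rebound `TrainingSession` methods surface in Python — the import path and `TrainingParameters` usage are assumptions (the class was reached through onnxruntime's internal pybind module in this era); the method names come from the bindings above:

```python
# Hypothetical usage sketch; module path and parameter fields are assumptions.
from onnxruntime.capi._pybind_state import TrainingSession, TrainingParameters

params = TrainingParameters()   # populate world_size, data_parallel_size, etc. as needed

sess = TrainingSession()        # default ctor uses the default CPU session options
config_result = sess.load_model('training_model.onnx', params)  # load + configure for training

state = sess.get_state()        # state tensors as a dict of numpy arrays
sess.load_state(state, True)    # second argument is `strict`
sess.finalize()
```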
