Allow sharing of initializers between sessions. (microsoft#5092)
* Allow sharing of initializers between sessions.

* Allow sharing of initializers between sessions (2).

* Add test for C#

* Add test for C#; address PR comments

* Address PR comments
Moved AddInitializer logic to internal session options
Added tests for owned buffer
Clarified documentation
Fixed a bug where the memory info, rather than the device, was being compared

* Fix test

* Fix training build

* Add ver 5 end marker and ver 6 starter, add scenario and usage examples.
pranavsharma authored Sep 21, 2020
1 parent e0719a1 commit 974b9bf
Showing 23 changed files with 403 additions and 47 deletions.
7 changes: 6 additions & 1 deletion csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
@@ -177,6 +177,7 @@ public struct OrtApi
public IntPtr TensorAt;
public IntPtr CreateAndRegisterAllocator;
public IntPtr SetLanguageProjection;
public IntPtr AddInitializer;
}

internal static class NativeMethods
@@ -238,7 +239,8 @@ static NativeMethods()
OrtSetSessionGraphOptimizationLevel = (DOrtSetSessionGraphOptimizationLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionGraphOptimizationLevel, typeof(DOrtSetSessionGraphOptimizationLevel));
OrtRegisterCustomOpsLibrary = (DOrtRegisterCustomOpsLibrary)Marshal.GetDelegateForFunctionPointer(api_.RegisterCustomOpsLibrary, typeof(DOrtRegisterCustomOpsLibrary));
OrtAddSessionConfigEntry = (DOrtAddSessionConfigEntry)Marshal.GetDelegateForFunctionPointer(api_.AddSessionConfigEntry, typeof(DOrtAddSessionConfigEntry));

OrtAddInitializer = (DOrtAddInitializer)Marshal.GetDelegateForFunctionPointer(api_.AddInitializer, typeof(DOrtAddInitializer));

OrtCreateRunOptions = (DOrtCreateRunOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateRunOptions, typeof(DOrtCreateRunOptions));
OrtReleaseRunOptions = (DOrtReleaseRunOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseRunOptions, typeof(DOrtReleaseRunOptions));
OrtRunOptionsSetRunLogVerbosityLevel = (DOrtRunOptionsSetRunLogVerbosityLevel)Marshal.GetDelegateForFunctionPointer(api_.RunOptionsSetRunLogVerbosityLevel, typeof(DOrtRunOptionsSetRunLogVerbosityLevel));
@@ -549,6 +551,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
public delegate IntPtr /*(OrtStatus*)*/DOrtRegisterCustomOpsLibrary(IntPtr /*(OrtSessionOptions*) */ options, string /*(const char*)*/ library_path, out IntPtr /* (void**) */ library_handle);
public static DOrtRegisterCustomOpsLibrary OrtRegisterCustomOpsLibrary;

public delegate IntPtr /*(OrtStatus*)*/DOrtAddInitializer(IntPtr /*(OrtSessionOptions*) */ options, string /*(const char*)*/ name, IntPtr /* OrtValue* */ ort_value);
public static DOrtAddInitializer OrtAddInitializer;

#endregion

#region RunOptions API
17 changes: 16 additions & 1 deletion csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
@@ -44,7 +44,7 @@ public class SessionOptions : SafeHandle
/// Constructs an empty SessionOptions
/// </summary>
public SessionOptions()
:base(IntPtr.Zero, true)
: base(IntPtr.Zero, true)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSessionOptions(out handle));
}
@@ -175,6 +175,21 @@ public void RegisterCustomOpLibrary(string libraryPath)
NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, libraryPath, out libraryHandle));
}

/// <summary>
/// Add a pre-allocated initializer to a session. If a model contains an initializer with a name
/// that is the same as the name passed to this API call, ORT will use this initializer instance
/// instead of deserializing one from the model file. This is useful when you want to share
/// the same initializer across sessions.
/// </summary>
/// <param name="name">Name of the initializer.</param>
/// <param name="ort_value">OrtValue containing the initializer. The lifetime of this value and of the underlying
/// initializer buffer must be managed by the user (create it using the CreateTensorWithDataAsOrtValue API), and it
/// must outlive the session object to which it is added.</param>
public void AddInitializer(string name, OrtValue ort_value)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddInitializer(handle, name, ort_value.Handle));
}

public void AddSessionConfigEntry(string configKey, string configValue)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddSessionConfigEntry(handle, configKey, configValue));
@@ -7,6 +7,7 @@
<OutputPath>bin\$(Configuration)\</OutputPath>
<PackageName Condition="'$(PackageName)' == ''">Microsoft.ML.OnnxRuntime</PackageName>
<IsLinuxBuild Condition="'$(IsLinuxBuild)' == ''">false</IsLinuxBuild>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(IsLinuxBuild)'=='true'">
86 changes: 77 additions & 9 deletions csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -268,6 +268,9 @@ private void CanRunInferenceOnAModel(GraphOptimizationLevel graphOptimizationLev
}
}


float[] expectedOutput = LoadTensorFromFile(@"bench.expected_out");
int[] expectedDimensions = { 1, 1000, 1, 1 }; // hardcoded for now for the test data
// Run inference with named inputs and named outputs
{
// correct pre-allocated outputs
@@ -276,7 +279,7 @@ private void CanRunInferenceOnAModel(GraphOptimizationLevel graphOptimizationLev
NamedOnnxValue.CreateFromTensor("softmaxout_1", new DenseTensor<float>(expectedOutputDimensions))
};
session.Run(container, expectedOutputValues);
validateRunResultData(expectedOutputValues[0].AsTensor<float>());
validateRunResultData(expectedOutputValues[0].AsTensor<float>(), expectedOutput, expectedDimensions);
}

// Run inference with pinned inputs and named outputs
@@ -291,7 +294,7 @@ private void CanRunInferenceOnAModel(GraphOptimizationLevel graphOptimizationLev
NamedOnnxValue.CreateFromTensor("softmaxout_1", new DenseTensor<float>(expectedOutputDimensions))
};
session.Run(inputNames, pinnedInputs, expectedOutputValues);
validateRunResultData(expectedOutputValues[0].AsTensor<float>());
validateRunResultData(expectedOutputValues[0].AsTensor<float>(), expectedOutput, expectedDimensions);
}

// Run inference with named inputs and pinned outputs
@@ -302,7 +305,7 @@ private void CanRunInferenceOnAModel(GraphOptimizationLevel graphOptimizationLev
var outputTensor = new DenseTensor<float>(expectedOutputDimensions);
pinnedOutputs.Add(FixedBufferOnnxValue.CreateFromTensor(outputTensor));
session.Run(container, expectedOutputNames, pinnedOutputs);
validateRunResultData(outputTensor);
validateRunResultData(outputTensor, expectedOutput, expectedDimensions);
}
}

@@ -317,7 +320,7 @@ private void CanRunInferenceOnAModel(GraphOptimizationLevel graphOptimizationLev
pinnedOutputs.Add(FixedBufferOnnxValue.CreateFromTensor(outputTensor));

session.Run(inputNames, pinnedInputs, expectedOutputNames, pinnedOutputs);
validateRunResultData(outputTensor);
validateRunResultData(outputTensor, expectedOutput, expectedDimensions);
}
}
}
@@ -371,15 +374,14 @@ private void validateRunResults(IReadOnlyCollection<NamedOnnxValue> results)
Assert.Equal(1, results.Count);
Assert.Equal("softmaxout_1", r.Name);

validateRunResultData(r.AsTensor<float>());
float[] expectedOutput = LoadTensorFromFile(@"bench.expected_out");
int[] expectedDimensions = { 1, 1000, 1, 1 }; // hardcoded for now for the test data
validateRunResultData(r.AsTensor<float>(), expectedOutput, expectedDimensions);
}
}

private void validateRunResultData(Tensor<float> resultTensor)
private void validateRunResultData(Tensor<float> resultTensor, float[] expectedOutput, int[] expectedDimensions)
{
float[] expectedOutput = LoadTensorFromFile(@"bench.expected_out");

int[] expectedDimensions = { 1, 1000, 1, 1 }; // hardcoded for now for the test data
Assert.Equal(expectedDimensions.Length, resultTensor.Rank);

var resultDimensions = resultTensor.Dimensions;
@@ -1837,6 +1839,72 @@ private void TestIOBinding()
}
}

[Fact]
private void TestWeightSharingBetweenSessions()
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "mul_1.onnx");

// create initializer to share
var ortCpuMemInfo = OrtMemoryInfo.DefaultInstance;
var dims = new long[] { 3, 2 };
var dataBuffer = new float[] { 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F };
var allocator = OrtAllocator.DefaultInstance;
var ortAllocationInput = allocator.Allocate((uint)dataBuffer.Length * sizeof(float));
unsafe
{
float* p = (float*)ortAllocationInput.DangerousGetHandle();
for (int i = 0; i < dataBuffer.Length; ++i)
{
*p++ = dataBuffer[i];
}
}
var dataBufferNumBytes = (uint)dataBuffer.Length * sizeof(float);
var sharedInitializer = OrtValue.CreateTensorValueWithData(ortCpuMemInfo, Tensors.TensorElementType.Float,
dims, ortAllocationInput.DangerousGetHandle(), dataBufferNumBytes);

SessionOptions options = new SessionOptions();
options.AddInitializer("W", sharedInitializer);

float[] expectedOutput = { 1.0F, 4.0F, 9.0F, 16.0F, 25.0F, 36.0F };
int[] expectedDimensions = { 3, 2 };

using (var session = new InferenceSession(modelPath, options))
using (var session2 = new InferenceSession(modelPath, options))
{
var inputMeta = session.InputMetadata;
var container = new List<NamedOnnxValue>();

foreach (var name in inputMeta.Keys)
{
Assert.Equal(typeof(float), inputMeta[name].ElementType);
Assert.True(inputMeta[name].IsTensor);
var tensor = new DenseTensor<float>(dataBuffer, inputMeta[name].Dimensions);
container.Add(NamedOnnxValue.CreateFromTensor<float>(name, tensor));
}

ReadOnlySpan<int> expectedOutputDimensions = new int[] { 1, 1000, 1, 1 };
string[] expectedOutputNames = new string[] { "Y" };

// Run inference with named inputs and outputs created within Run()
using (var results = session.Run(container)) // results is an IReadOnlyList<NamedOnnxValue> container
{
foreach (var r in results)
{
validateRunResultData(r.AsTensor<float>(), expectedOutput, expectedDimensions);
}
}

// Run inference on the second session, with named inputs and outputs created within Run()
using (var results2 = session2.Run(container)) // results2 is an IReadOnlyList<NamedOnnxValue> container
{
foreach (var r in results2)
{
validateRunResultData(r.AsTensor<float>(), expectedOutput, expectedDimensions);
}
}
}
}

[DllImport("kernel32", SetLastError = true)]
static extern IntPtr LoadLibrary(string lpFileName);

Binary file added csharp/testdata/mul_1.onnx
5 changes: 5 additions & 0 deletions docs/C_API.md
@@ -28,6 +28,11 @@ chooses to override this by setting ```session_state.use_env_allocators``` to "0
* Set ```session.use_env_allocators``` to "1" for each session that wants to use the env registered allocators.
* See test ```TestSharedAllocatorUsingCreateAndRegisterAllocator``` in
onnxruntime/test/shared_lib/test_inference.cc for an example.
* **Share initializer(s) between sessions:**
* *Description*: This feature allows a user to share the same instance of an initializer across
multiple sessions.
* *Scenario*: You have several models that differ only in their last few layers but otherwise use the same set of initializers, and you load these models in the same process. If every model (session) deserializes its own instance of each initializer, memory usage becomes excessive and wasteful, since the data is identical. You want to optimize memory usage while retaining the flexibility to allocate the initializers yourself (possibly even store them in shared memory).
* *Example Usage*: Use the ```AddInitializer``` API to add a pre-allocated initializer to session options before calling ```CreateSession```. Use the same instance of session options to create several sessions, allowing the initializer(s) to be shared between the sessions. See [C API sample usage (TestSharingOfInitializer)](../onnxruntime/test/shared_lib/test_inference.cc) and [C# API sample usage (TestWeightSharingBetweenSessions)](../csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs); a brief sketch also follows this file's diff.

## Usage Overview

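As an editorial illustration of the usage described above, here is a minimal C++ sketch written against the wrapper added in this commit. It assumes the test model `mul_1.onnx` added here, whose float initializer "W" has shape {3, 2} (mirroring the C# test); treat it as a sketch under those assumptions, not the commit's own sample code.

```cpp
// Sketch: share one user-owned initializer across two sessions.
// Assumes mul_1.onnx has a float initializer "W" of shape {3, 2}.
#include <array>
#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "shared_init");

  // Caller-owned buffer; it must outlive both sessions. AddInitializer
  // rejects OrtValues that own their buffer, so create the tensor over
  // pre-allocated memory.
  std::array<float, 6> w_data = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  std::array<int64_t, 2> w_shape = {3, 2};
  auto mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value w = Ort::Value::CreateTensor<float>(
      mem_info, w_data.data(), w_data.size(), w_shape.data(), w_shape.size());

  Ort::SessionOptions options;
  options.AddInitializer("W", w);  // must happen before session creation

  // Both sessions reuse the same "W" instead of each deserializing a copy.
  Ort::Session session1(env, ORT_TSTR("mul_1.onnx"), options);
  Ort::Session session2(env, ORT_TSTR("mul_1.onnx"), options);
  return 0;
}
```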
4 changes: 4 additions & 0 deletions include/onnxruntime/core/framework/tensor.h
@@ -178,6 +178,10 @@ class Tensor final {
return static_cast<char*>(p_data_) + byte_offset_;
}

bool OwnsBuffer() const noexcept {
return buffer_deleter_ != nullptr;
}

/**
* Resizes the tensor without touching underlying storage.
* This requires the total size of the tensor to remain constant.
13 changes: 13 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -1051,6 +1051,19 @@ struct OrtApi {
* Prefer a value of 0 if your CPU usage is very high.
*/
ORT_API2_STATUS(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning);

/**
* Add a pre-allocated initializer to a session. If a model contains an initializer with a name
 * that is the same as the name passed to this API call, ORT will use this initializer instance
* instead of deserializing one from the model file. This is useful when you want to share
* the same initializer across sessions.
* \param name name of the initializer
 * \param val OrtValue containing the initializer. The lifetime of 'val' and of the underlying initializer buffer
 * must be managed by the user ('val' must be created using the CreateTensorWithDataAsOrtValue API), and both must
 * outlive the session object to which it is added.
*/
ORT_API2_STATUS(AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ const char* name,
_In_ const OrtValue* val);
};

/*
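At the C API layer the flow is the same; the hedged sketch below (valid C++, error handling elided) highlights the required creation path through `CreateTensorWithDataAsOrtValue`, which guarantees ORT never owns the buffer. The initializer name "W" and the shape mirror this commit's tests and are illustrative only.

```cpp
// Sketch (error handling elided): add a user-owned initializer via the raw
// C API. Both the data buffer and the OrtValue must outlive every session
// created from these options.
#include "onnxruntime_c_api.h"

void AddSharedInitializer(OrtSessionOptions* options) {
  const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  static float data[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};  // user-owned
  static const int64_t shape[] = {3, 2};

  OrtMemoryInfo* mem_info = nullptr;
  ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info);

  OrtValue* w = nullptr;
  ort->CreateTensorWithDataAsOrtValue(mem_info, data, sizeof(data), shape, 2,
                                      ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &w);

  // Call before CreateSession; the same options object can then create
  // several sessions that all share "W".
  ort->AddInitializer(options, "W", w);
}
```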
3 changes: 2 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -248,6 +248,8 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions& DisablePerSessionThreads();

SessionOptions& AddConfigEntry(const char* config_key, const char* config_value);

SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
};

struct ModelMetadata : Base<OrtModelMetadata> {
@@ -330,7 +332,6 @@ struct TypeInfo : Base<OrtTypeInfo> {
Unowned<SequenceTypeInfo> GetSequenceTypeInfo() const;
Unowned<MapTypeInfo> GetMapTypeInfo() const;


ONNXType GetONNXType() const;
};

8 changes: 8 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -450,6 +450,11 @@ inline SessionOptions& SessionOptions::AddConfigEntry(const char* config_key, co
return *this;
}

inline SessionOptions& SessionOptions::AddInitializer(const char* name, const OrtValue* ort_val) {
ThrowOnError(GetApi().AddInitializer(p_, name, ort_val));
return *this;
}

inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
ThrowOnError(GetApi().CreateSession(env, model_path, options, &p_));
}
@@ -927,4 +932,7 @@ inline std::vector<std::string> GetAvailableProviders() {
ThrowOnError(api.ReleaseAvailableProviders(providers, len));
return available_providers;
}

SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);

} // namespace Ort
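One small design note: like `AddConfigEntry` above, the new wrapper returns `SessionOptions&`, so it composes with the other fluent setters. A hedged sketch, assuming `w` is an `Ort::Value` built over a user-owned buffer as in the earlier examples:

```cpp
// Sketch: AddInitializer chains with the other fluent SessionOptions setters.
Ort::SessionOptions MakeSharedOptions(const Ort::Value& w) {
  Ort::SessionOptions options;
  options.SetGraphOptimizationLevel(ORT_ENABLE_ALL)
         .AddInitializer("W", w);
  return options;
}
```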
28 changes: 28 additions & 0 deletions onnxruntime/core/framework/session_options.cc
@@ -3,6 +3,7 @@

#include "core/framework/session_options.h"
#include "core/common/logging/logging.h"
#include "core/framework/ml_value.h"

namespace onnxruntime {

@@ -45,4 +46,31 @@ Status SessionOptions::AddConfigEntry(const char* config_key, const char* config

return Status::OK();
}

Status SessionOptions::AddInitializer(const char* name, const OrtValue* val) noexcept {
// input validation
if (name == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Received nullptr for name.");
}

if (val == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Received nullptr for OrtValue.");
}

if (!val->IsTensor()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Received OrtValue is not a tensor. Only tensors are supported.");
}

if (val->Get<Tensor>().OwnsBuffer()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer containing the initializer must be owned by the user.");
}

// now do the actual work
auto rc = initializers_to_share_map.insert({name, val});
if (!rc.second) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "An OrtValue for this name has already been added.");
}

return Status::OK();
}
} // namespace onnxruntime
4 changes: 4 additions & 0 deletions onnxruntime/core/framework/session_options.h
@@ -93,6 +93,10 @@ struct SessionOptions {
// The configuration keys and value formats are defined in
// /include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
std::unordered_map<std::string, std::string> session_configurations;
std::unordered_map<std::string, const OrtValue*> initializers_to_share_map;

// See onnxruntime_c_api.h for detailed documentation.
Status AddInitializer(const char* name, const OrtValue* val) noexcept;

// Check if the given SessionOptions has a config using the given config_key.
// Returns true if found and copies the value into config_value.
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/session_state.cc
@@ -852,7 +852,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
[this](int idx, const OrtValue& value, const OrtCallback& d, bool constant) -> Status {
return AddInitializedTensor(idx, value, &d, constant);
},
logger_, data_transfer_mgr_));
logger_, data_transfer_mgr_, *p_seq_exec_plan_.get(), session_options));

// remove weights from the graph now to save memory but in many cases it won't save memory, if the tensor was
// preallocated with the some other tensors in a single 'allocate' call, which is very common.