Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Java and Objective-C bindings for RegisterCustomOpsUsingFunction. #14256

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,28 @@ public void registerCustomOpLibrary(String path) throws OrtException {
customLibraryHandles.add(customHandle);
}

/**
* Registers custom ops for use with {@link OrtSession}s using this SessionOptions by calling
* the specified native function name. The custom ops library must either be linked against, or
* have previously been loaded by the user.
*
* <p>The registration function must have the signature:
*
* <p>&emsp;OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api);
*
* <p>See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for more
* information on custom ops. See
* https://github.com/microsoft/onnxruntime/blob/342a5bf2b756d1a1fc6fdc582cfeac15182632fe/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc#L115
* for an example of a custom op library registration function.
*
* @param registrationFuncName The name of the registration function to call.
* @throws OrtException If there was an error finding or calling the registration function.
*/
public void registerCustomOpsUsingFunction(String registrationFuncName) throws OrtException {
checkClosed();
registerCustomOpsUsingFunction(OnnxRuntime.ortApiHandle, nativeHandle, registrationFuncName);
}

/**
* Sets the value of a symbolic dimension. Fixed dimension computations may have more
* optimizations applied to them.
Expand Down Expand Up @@ -1039,6 +1061,9 @@ private native void setSessionLogVerbosityLevel(long apiHandle, long nativeHandl
private native long registerCustomOpLibrary(long apiHandle, long nativeHandle, String path)
throws OrtException;

private native void registerCustomOpsUsingFunction(
long apiHandle, long nativeHandle, String registrationFuncName) throws OrtException;

private native void closeCustomLibraries(long[] nativeHandle);

private native void closeOptions(long apiHandle, long nativeHandle);
Expand Down
20 changes: 20 additions & 0 deletions java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,26 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_regis
return (jlong) libraryHandle;
}

/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: registerCustomOpsUsingFunction
* Signature: (JJLjava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_registerCustomOpsUsingFunction
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring functionName) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;

// Extract the string chars
const char* cFuncName = (*jniEnv)->GetStringUTFChars(jniEnv, functionName, NULL);

// Register the custom ops by calling the function
checkOrtStatus(jniEnv,api,api->RegisterCustomOpsUsingFunction((OrtSessionOptions*)optionsHandle,cFuncName));

// Release the string chars
(*jniEnv)->ReleaseStringUTFChars(jniEnv,functionName,cFuncName);
}

/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: closeCustomLibraries
Expand Down
81 changes: 68 additions & 13 deletions java/src/test/java/ai/onnxruntime/InferenceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,24 @@ private static Map<String, String> getSkippedModels() {
return skipModels;
}

private static String getCustomOpLibraryName() {
String customLibraryName = "";
String osName = System.getProperty("os.name").toLowerCase();
if (osName.contains("windows")) {
// In windows we start in the wrong working directory relative to the custom_op_library.dll
// So we look it up as a classpath resource and resolve it to a real path
customLibraryName = TestHelpers.getResourcePath("/custom_op_library.dll").toString();
} else if (osName.contains("mac")) {
customLibraryName = TestHelpers.getResourcePath("/libcustom_op_library.dylib").toString();
} else if (osName.contains("linux")) {
customLibraryName = TestHelpers.getResourcePath("/libcustom_op_library.so").toString();
} else {
fail("Unknown os/platform '" + osName + "'");
}

return customLibraryName;
}

public static List<String[]> getModelsForTest() throws IOException {
File modelsDir = getTestModelsDir();
Map<String, String> skipModels = getSkippedModels();
Expand Down Expand Up @@ -1016,19 +1034,7 @@ public void testExtraSessionOptions() throws OrtException, IOException {
public void testLoadCustomLibrary() throws OrtException {
// This test is disabled on Android.
if (!OnnxRuntime.isAndroid()) {
String customLibraryName = "";
String osName = System.getProperty("os.name").toLowerCase();
if (osName.contains("windows")) {
// In windows we start in the wrong working directory relative to the custom_op_library.dll
// So we look it up as a classpath resource and resolve it to a real path
customLibraryName = TestHelpers.getResourcePath("/custom_op_library.dll").toString();
} else if (osName.contains("mac")) {
customLibraryName = TestHelpers.getResourcePath("/libcustom_op_library.dylib").toString();
} else if (osName.contains("linux")) {
customLibraryName = TestHelpers.getResourcePath("/libcustom_op_library.so").toString();
} else {
fail("Unknown os/platform '" + osName + "'");
}
String customLibraryName = getCustomOpLibraryName();
String customOpLibraryTestModel =
TestHelpers.getResourcePath("/custom_op_library/custom_op_test.onnx").toString();

Expand Down Expand Up @@ -1073,6 +1079,55 @@ public void testLoadCustomLibrary() throws OrtException {
}
}

@Test
public void testLoadCustomOpsUsingFunction() throws OrtException {
// This test is disabled on Android.
if (!OnnxRuntime.isAndroid()) {
String customLibraryName = getCustomOpLibraryName();
String customOpLibraryTestModel =
TestHelpers.getResourcePath("/custom_op_library/custom_op_test.onnx").toString();

try (SessionOptions options = new SessionOptions()) {
String osName = System.getProperty("os.name").toLowerCase();
boolean isWindows = osName.contains("windows");
boolean isMac = osName.contains("mac");

// on Windows and mac, Java.System.load will make the symbols from the loaded library
// available.
// on other platforms the dlsym uses RTLD_LOCAL so they're not. Would need to use something
// like
// https://github.com/java-native-access/jna to achieve that.
// As we have unit tests that validate the custom op registration across all platforms, we
// settle for just
// making sure the ORT API function can be called and behaves as expected.
try {
// manually load the library. typically we'd expect the user to link against the library,
// but doing that here would conflict with testLoadCustomLibrary needing to test ORT
// loading
// the library.
System.load(customLibraryName);
options.registerCustomOpsUsingFunction("RegisterCustomOps");

if (isWindows || isMac) {
if (OnnxRuntime.extractCUDA()) {
options.addCUDA();
}
try (OrtSession session = env.createSession(customOpLibraryTestModel, options)) {
// if model was loaded the op registration was successful
}
} else {
fail("Expected to throw OrtException due System.load not using RTLD_GLOBAL");
}
} catch (OrtException e) {
System.out.println(e.getMessage());
assertTrue(
!(isWindows || isMac), "Expected to not throw OrtException on Windows or macOS");
assertTrue(e.getMessage().contains("Failed to get symbol RegisterCustomOps"));
}
}
}
}

@Test
public void testModelMetadata() throws OrtException {
String modelPath =
Expand Down
30 changes: 28 additions & 2 deletions objectivec/include/ort_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,14 @@ NS_ASSUME_NONNULL_BEGIN
- (nullable instancetype)initWithError:(NSError**)error NS_SWIFT_NAME(init());

/**
* Appends an execution provider to the session options to enable the execution provider to be used when running
* the model.
*
* Available since 1.14.
* Appends an execution provider to the session configuration options.
* The execution provider list is ordered by decreasing priority
*
* The execution provider list is ordered by decreasing priority.
* i.e. the first provider registered has the highest priority.
*
* @param providerName Provider name. For example, "xnnpack".
* @param providerOptions Provider-specific options. For example, for provider "xnnpack", {"intra_op_num_threads": "2"}.
* @param error Optional error information set if an error occurs.
Expand Down Expand Up @@ -183,6 +188,27 @@ NS_ASSUME_NONNULL_BEGIN
value:(NSString*)value
error:(NSError**)error;

/**
* Registers custom ops for use with `ORTSession`s using this SessionOptions by calling the specified
* native function name. The custom ops library must either be linked against, or have previously been loaded
* by the user.
*
* Available since 1.14.
*
* The registration function must have the signature:
* OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api);
*
* See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for more information on custom ops.
* See https://github.com/microsoft/onnxruntime/blob/342a5bf2b756d1a1fc6fdc582cfeac15182632fe/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc#L115
* for an example of a custom op library registration function.
skottmckay marked this conversation as resolved.
Show resolved Hide resolved
*
* @param registration_func_name The name of the registration function to call.
* @param error Optional error information set if an error occurs.
* @return Whether the registration function was successfully called.
*/
- (BOOL)registerCustomOpsUsingFunction:(NSString*)registrationFuncName
error:(NSError**)error;

@end

/**
Expand Down
13 changes: 11 additions & 2 deletions objectivec/src/ort_session.mm
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,12 @@ - (BOOL)appendExecutionProvider:(NSString*)providerName
try {
std::unordered_map<std::string, std::string> options;
NSArray* keys = [providerOptions allKeys];

for (NSString* key in keys) {
NSString* value = [providerOptions objectForKey:key];
options.emplace(key.UTF8String, value.UTF8String);
}

_sessionOptions->AppendExecutionProvider(providerName.UTF8String, options);
return YES;
}
Expand Down Expand Up @@ -287,6 +287,15 @@ - (BOOL)addConfigEntryWithKey:(NSString*)key
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

- (BOOL)registerCustomOpsUsingFunction:(NSString*)registrationFuncName
error:(NSError**)error {
try {
_sessionOptions->RegisterCustomOpsUsingFunction(registrationFuncName.UTF8String);
return YES;
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

#pragma mark - Internal

- (Ort::SessionOptions&)CXXAPIOrtSessionOptions {
Expand Down