Skip to content

[Android] Add Java runtime API for registered ops and backends #11042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion extension/android/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/cmake/ExecuTorch)
find_package(executorch CONFIG REQUIRED)
target_link_options_shared_lib(executorch)

add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp)
add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp)

set(link_libraries)
list(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch;

import static org.junit.Assert.assertNotNull;

import androidx.test.ext.junit.runners.AndroidJUnit4;
import org.junit.runner.RunWith;
import org.junit.Test;

/** Unit tests for {@link ExecuTorchRuntime}. */
@RunWith(AndroidJUnit4.class)
public class RuntimeInstrumentationTest {

@Test
public void testRuntimeApi() {
String[] ops = ExecuTorchRuntime.getRegisteredOps();
String[] backends = ExecuTorchRuntime.getRegisteredBackends();

assertNotNull(ops);
assertNotNull(backends);

for (String op : ops) {
assertNotNull(op);
}

for (String backend : backends) {
assertNotNull(backend);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.pytorch.executorch;

import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;

Expand All @@ -30,4 +31,12 @@ private ExecuTorchRuntime() {}
public static ExecuTorchRuntime getRuntime() {
return sInstance;
}

/** Get all registered ops. */
@DoNotStrip
public static native String[] getRegisteredOps();

/** Get all registered backends. */
@DoNotStrip
public static native String[] getRegisteredBackends();
}
9 changes: 7 additions & 2 deletions extension/android/jni/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ non_fbcode_target(_kind = executorch_generated_lib,

non_fbcode_target(_kind = fb_android_cxx_library,
name = "executorch_jni",
srcs = ["jni_layer.cpp", "log.cpp"],
srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp"],
allow_jni_merging = False,
compiler_flags = ET_JNI_COMPILER_FLAGS,
soname = "libexecutorch.$(ext)",
Expand All @@ -49,7 +49,7 @@ non_fbcode_target(_kind = fb_android_cxx_library,

non_fbcode_target(_kind = fb_android_cxx_library,
name = "executorch_jni_full",
srcs = ["jni_layer.cpp", "log.cpp"],
srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp"],
allow_jni_merging = False,
compiler_flags = ET_JNI_COMPILER_FLAGS,
soname = "libexecutorch.$(ext)",
Expand All @@ -74,6 +74,7 @@ non_fbcode_target(_kind = fb_android_cxx_library,
srcs = [
"jni_layer.cpp",
"jni_layer_llama.cpp",
"jni_layer_runtime.cpp",
],
allow_jni_merging = False,
compiler_flags = ET_JNI_COMPILER_FLAGS + [
Expand Down Expand Up @@ -113,6 +114,10 @@ runtime.export_file(
name = "jni_layer.cpp",
)

runtime.export_file(
name = "jni_layer_runtime.cpp",
)

runtime.cxx_library(
name = "jni_headers",
exported_headers = [
Expand Down
2 changes: 2 additions & 0 deletions extension/android/jni/jni_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,11 @@ extern void register_natives_for_llm();
// No op if we don't build LLM
void register_natives_for_llm() {}
#endif
extern void register_natives_for_runtime();
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(vm, [] {
executorch::extension::ExecuTorchJni::registerNatives();
register_natives_for_llm();
register_natives_for_runtime();
});
}
72 changes: 72 additions & 0 deletions extension/android/jni/jni_layer_runtime.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <fbjni/fbjni.h>
#include <jni.h>

#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/kernel/operator_registry.h>

namespace executorch_jni {
namespace runtime = ::executorch::ET_RUNTIME_NAMESPACE;

class AndroidRuntimeJni : public facebook::jni::JavaClass<AndroidRuntimeJni> {
public:
constexpr static const char* kJavaDescriptor =
"Lorg/pytorch/executorch/ExecuTorchRuntime;";

static void registerNatives() {
javaClassStatic()->registerNatives({
makeNativeMethod(
"getRegisteredOps", AndroidRuntimeJni::getRegisteredOps),
makeNativeMethod(
"getRegisteredBackends", AndroidRuntimeJni::getRegisteredBackends),
});
}

// Returns a string array of all registered ops
static facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
getRegisteredOps(facebook::jni::alias_ref<jclass>) {
auto kernels = runtime::get_registered_kernels();
auto result = facebook::jni::JArrayClass<jstring>::newArray(kernels.size());

for (size_t i = 0; i < kernels.size(); ++i) {
auto op = facebook::jni::make_jstring(kernels[i].name_);
result->setElement(i, op.get());
}

return result;
}

// Returns a string array of all registered backends
static facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
getRegisteredBackends(facebook::jni::alias_ref<jclass>) {
int num_backends = runtime::get_num_registered_backends();
auto result = facebook::jni::JArrayClass<jstring>::newArray(num_backends);

for (int i = 0; i < num_backends; ++i) {
auto name_result = runtime::get_backend_name(i);
const char* name = "";

if (name_result.ok()) {
name = *name_result;
}

auto backend_str = facebook::jni::make_jstring(name);
result->setElement(i, backend_str.get());
}

return result;
}
};

} // namespace executorch_jni

void register_natives_for_runtime() {
executorch_jni::AndroidRuntimeJni::registerNatives();
}
4 changes: 3 additions & 1 deletion extension/android/jni/selective_jni.buck.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ load("@fbsource//xplat/executorch/backends/xnnpack/third-party:third_party_libs.
load("@fbsource//xplat/executorch/extension/android/jni:build_defs.bzl", "ET_JNI_COMPILER_FLAGS")

def selective_jni_target(name, deps, srcs = [], soname = "libexecutorch.$(ext)"):
non_fbcode_target(_kind = fb_android_cxx_library,
non_fbcode_target(
_kind = fb_android_cxx_library,
name = name,
srcs = [
"//xplat/executorch/extension/android/jni:jni_layer.cpp",
"//xplat/executorch/extension/android/jni:jni_layer_runtime.cpp",
] + srcs,
allow_jni_merging = False,
compiler_flags = ET_JNI_COMPILER_FLAGS,
Expand Down
Loading