Skip to content

Commit 8a2704f

Browse files
committed
input and output tag API
1 parent 5268b24 commit 8a2704f

File tree

5 files changed

+79
-7
lines changed

5 files changed

+79
-7
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ class ModuleE2ETest {
6868
inputStream.close()
6969

7070
val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte"))
71-
val expectedBackends = arrayOf("XnnpackBackend")
72-
Assert.assertArrayEquals(expectedBackends, module.getMethodMetadata("forward").getBackends())
71+
72+
Assert.assertArrayEquals(arrayOf("XnnpackBackend"), module.getMethodMetadata("forward").getBackends())
7373
}
7474

7575
@Test

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ class ModuleInstrumentationTest {
6262

6363
Assert.assertArrayEquals(arrayOf("forward"), module.getMethods())
6464
Assert.assertTrue(module.getMethodMetadata("forward").backends.isEmpty())
65+
Assert.assertArrayEquals(intArrayOf(1, 1, 3), module.getMethodMetadata("forward").inputTags);
66+
Assert.assertArrayEquals(intArrayOf(1), module.getMethodMetadata("forward").outputTags);
6567
}
6668

6769
@Test

extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
/** Helper class to access the metadata for a method from a Module */
1212
public class MethodMetadata {
1313
private String mName;
14-
1514
private String[] mBackends;
15+
private int[] mInputTags;
16+
private int[] mOutputTags;
1617

1718
MethodMetadata setName(String name) {
1819
mName = name;
@@ -37,4 +38,28 @@ MethodMetadata setBackends(String[] backends) {
3738
public String[] getBackends() {
3839
return mBackends;
3940
}
41+
42+
/**
43+
* @return Output tags
44+
*/
45+
public int[] getOutputTags() {
46+
return mOutputTags;
47+
}
48+
49+
MethodMetadata setOutputTags(int[] outputTags) {
50+
mOutputTags = outputTags;
51+
return this;
52+
}
53+
54+
/**
55+
* @return Input tags
56+
*/
57+
public int[] getInputTags() {
58+
return mInputTags;
59+
}
60+
61+
MethodMetadata setInputTags(int[] inputTags) {
62+
mInputTags = inputTags;
63+
return this;
64+
}
4065
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,13 @@ Map<String, MethodMetadata> populateMethodMeta() {
6767
Map<String, MethodMetadata> metadata = new HashMap<String, MethodMetadata>();
6868
for (int i = 0; i < methods.length; i++) {
6969
String name = methods[i];
70-
metadata.put(name, new MethodMetadata().setName(name).setBackends(getUsedBackends(name)));
70+
metadata.put(
71+
name,
72+
new MethodMetadata()
73+
.setName(name)
74+
.setBackends(getUsedBackends(name))
75+
.setInputTags(getInputTags(name))
76+
.setOutputTags(getOutputTags(name)));
7177
}
7278

7379
return metadata;
@@ -212,6 +218,12 @@ public String[] readLogBuffer() {
212218
@DoNotStrip
213219
private native String[] readLogBufferNative();
214220

221+
@DoNotStrip
222+
private native int[] getInputTags(String method);
223+
224+
@DoNotStrip
225+
private native int[] getOutputTags(String method);
226+
215227
/**
216228
* Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump.
217229
*

extension/android/jni/jni_layer.cpp

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,10 +453,11 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
453453

454454
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getUsedBackends(
455455
facebook::jni::alias_ref<jstring> methodName) {
456-
auto methodMeta = module_->method_meta(methodName->toStdString()).get();
456+
auto method_meta =
457+
module_->method_meta(methodName->toStdString()).get();
457458
std::unordered_set<std::string> backends;
458-
for (auto i = 0; i < methodMeta.num_backends(); i++) {
459-
backends.insert(methodMeta.get_backend_name(i).get());
459+
for (auto i = 0; i < method_meta.num_backends(); i++) {
460+
backends.insert(method_meta.get_backend_name(i).get());
460461
}
461462

462463
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret =
@@ -471,6 +472,36 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
471472
return ret;
472473
}
473474

475+
facebook::jni::local_ref<facebook::jni::JArrayInt> getInputTags(
476+
facebook::jni::alias_ref<jstring> methodName) {
477+
auto method_meta =
478+
module_->method_meta(methodName->toStdString()).get();
479+
auto num_inputs = method_meta.num_inputs();
480+
facebook::jni::local_ref<facebook::jni::JArrayInt> ret =
481+
facebook::jni::JArrayInt::newArray(num_inputs);
482+
483+
int i = 0;
484+
for (int i = 0; i < num_inputs; i++) {
485+
ret->pin()[i] = static_cast<uint32_t>(method_meta.input_tag(i).get());
486+
}
487+
return ret;
488+
}
489+
490+
facebook::jni::local_ref<facebook::jni::JArrayInt> getOutputTags(
491+
facebook::jni::alias_ref<jstring> methodName) {
492+
auto method_meta =
493+
module_->method_meta(methodName->toStdString()).get();
494+
auto num_outputs = method_meta.num_outputs();
495+
facebook::jni::local_ref<facebook::jni::JArrayInt> ret =
496+
facebook::jni::JArrayInt::newArray(num_outputs);
497+
498+
int i = 0;
499+
for (int i = 0; i < num_outputs; i++) {
500+
ret->pin()[i] = static_cast<uint32_t>(method_meta.output_tag(i).get());
501+
}
502+
return ret;
503+
}
504+
474505
static void registerNatives() {
475506
registerHybrid({
476507
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
@@ -480,6 +511,8 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
480511
makeNativeMethod("etdump", ExecuTorchJni::etdump),
481512
makeNativeMethod("getMethods", ExecuTorchJni::getMethods),
482513
makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
514+
makeNativeMethod("getInputTags", ExecuTorchJni::getInputTags),
515+
makeNativeMethod("getOutputTags", ExecuTorchJni::getOutputTags),
483516
});
484517
}
485518
};

0 commit comments

Comments
 (0)