Skip to content

Commit 61e1169

Browse files
fergushendersoncopybara-github
authored andcommitted
Fix a bug where the Play services native API example code was missing the 'java/' directories.
PiperOrigin-RevId: 693019678
1 parent fe0c06d commit 61e1169

File tree

4 files changed

+289
-0
lines changed

4 files changed

+289
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright 2023 The TensorFlow Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.samples.gms.tflite.c.instrumentation;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
21+
import android.content.Context;
22+
import android.util.Log;
23+
import androidx.test.core.app.ApplicationProvider;
24+
import androidx.test.ext.junit.runners.AndroidJUnit4;
25+
import com.google.android.gms.tasks.Tasks;
26+
import com.google.android.gms.tflite.java.TfLiteNative;
27+
import com.google.samples.gms.tflite.c.TfLiteJni;
28+
import java.util.concurrent.ExecutionException;
29+
import org.junit.Test;
30+
import org.junit.runner.RunWith;
31+
32+
/** Instrumentation tests for the TFLite Native API. */
33+
@RunWith(AndroidJUnit4.class)
34+
public class BasicScenarioTest {
35+
private static final String TAG = "BasicScenarioTest";
36+
37+
@Test
38+
public void basicScenario() throws ExecutionException, InterruptedException {
39+
Context context = ApplicationProvider.getApplicationContext();
40+
Tasks.await(TfLiteNative.initialize(context));
41+
TfLiteJni jni = new TfLiteJni(message -> Log.e(TAG, message));
42+
43+
jni.loadModel(context.getAssets(), "add.tflite");
44+
float[] output = jni.runInference(new float[] {1.f, 3.f});
45+
jni.destroy();
46+
47+
assertThat(output).isEqualTo(new float[] {3.f, 9.f});
48+
}
49+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright 2023 The TensorFlow Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.samples.gms.tflite.c.instrumentation;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
21+
import android.content.Context;
22+
import android.util.Log;
23+
import androidx.test.core.app.ApplicationProvider;
24+
import androidx.test.ext.junit.runners.AndroidJUnit4;
25+
import com.google.android.gms.tasks.Tasks;
26+
import com.google.android.gms.tflite.client.TfLiteInitializationOptions;
27+
import com.google.android.gms.tflite.gpu.support.TfLiteGpu;
28+
import com.google.android.gms.tflite.java.TfLiteNative;
29+
import com.google.samples.gms.tflite.c.TfLiteJni;
30+
import java.util.concurrent.ExecutionException;
31+
import org.junit.Assume;
32+
import org.junit.Before;
33+
import org.junit.Test;
34+
import org.junit.runner.RunWith;
35+
36+
/** Instrumentation tests for the TFLite Native Acceleration API. */
37+
@RunWith(AndroidJUnit4.class)
38+
public class TfLiteNativeGPUAccelerationTest {
39+
private static final String TAG = "TfLiteNativeGPUAccelerationTest";
40+
41+
private static final TfLiteInitializationOptions options =
42+
TfLiteInitializationOptions.builder().setEnableGpuDelegateSupport(true).build();
43+
44+
private Context context;
45+
46+
@Before
47+
public void setUp() throws ExecutionException, InterruptedException {
48+
context = ApplicationProvider.getApplicationContext();
49+
boolean gpuAvailable = Tasks.await(TfLiteGpu.isGpuDelegateAvailable(context));
50+
Assume.assumeTrue("GPU acceleration is unavailable on this device.", gpuAvailable);
51+
Tasks.await(TfLiteNative.initialize(context, options));
52+
}
53+
54+
@Test
55+
public void doInferenceWithAcceleration() {
56+
TfLiteJni jni = new TfLiteJni(message -> Log.e(TAG, message));
57+
jni.initGpuAcceleration();
58+
jni.loadModel(context.getAssets(), "add.tflite");
59+
float[] output = jni.runInference(new float[] {1.f, 3.f});
60+
jni.destroy();
61+
assertThat(output).isEqualTo(new float[] {3.f, 9.f});
62+
}
63+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright 2023 The TensorFlow Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.samples.gms.tflite.c;
18+
19+
import android.app.Activity;
20+
import android.os.Bundle;
21+
import android.util.Log;
22+
import android.widget.TextView;
23+
import androidx.annotation.Nullable;
24+
import com.google.android.gms.tasks.Task;
25+
import com.google.android.gms.tasks.Tasks;
26+
import com.google.android.gms.tflite.client.TfLiteInitializationOptions;
27+
import com.google.android.gms.tflite.gpu.support.TfLiteGpu;
28+
import com.google.android.gms.tflite.java.TfLiteNative;
29+
import java.text.DateFormat;
30+
import java.text.SimpleDateFormat;
31+
import java.util.Arrays;
32+
import java.util.Date;
33+
34+
/** Sample activity to test the TFLite C API. */
35+
public class MainActivity extends Activity {
36+
37+
private static final String TAG = "MainActivity";
38+
39+
private volatile boolean isGpuAvailable = false;
40+
41+
private TextView logView;
42+
43+
@Override
44+
protected void onCreate(@Nullable Bundle savedInstanceState) {
45+
super.onCreate(savedInstanceState);
46+
setContentView(R.layout.activity_main);
47+
48+
logView = findViewById(R.id.log_text);
49+
50+
findViewById(android.R.id.button1).setOnClickListener(v -> runScenario());
51+
}
52+
53+
@Override
54+
protected void onStart() {
55+
super.onStart();
56+
runScenario();
57+
}
58+
59+
private void runScenario() {
60+
String currentTime = SimpleDateFormat.getTimeInstance(DateFormat.SHORT).format(new Date());
61+
logView.setText(String.format("Scenario started at %s\n", currentTime));
62+
isGpuAvailable = false;
63+
64+
logEvent("Checking GPU acceleration availability...");
65+
66+
Task<Void> tfLiteHandleTask =
67+
TfLiteGpu.isGpuDelegateAvailable(this)
68+
.onSuccessTask(
69+
gpuAvailable -> {
70+
isGpuAvailable = gpuAvailable;
71+
logEvent("GPU acceleration is " + (isGpuAvailable ? "available" : "unavailable"));
72+
TfLiteInitializationOptions options =
73+
TfLiteInitializationOptions.builder()
74+
.setEnableGpuDelegateSupport(isGpuAvailable)
75+
.build();
76+
return TfLiteNative.initialize(this, options);
77+
});
78+
79+
tfLiteHandleTask
80+
.onSuccessTask(
81+
unused -> {
82+
logEvent("Running inference on " + (isGpuAvailable ? "GPU" : "CPU"));
83+
TfLiteJni jni = new TfLiteJni(this::logEvent);
84+
if (isGpuAvailable) {
85+
jni.initGpuAcceleration();
86+
}
87+
logEvent("TfLiteJni created");
88+
jni.loadModel(getAssets(), "add.tflite");
89+
logEvent("Model loaded");
90+
float[] output = jni.runInference(new float[] {1.f, 3.f});
91+
logEvent("Ran inference, expected: [3, 9], got output: " + Arrays.toString(output));
92+
jni.destroy();
93+
logEvent("TfLiteJni destroyed");
94+
return Tasks.forResult(output);
95+
})
96+
.addOnSuccessListener(unused -> logEvent("Scenario successful!"))
97+
.addOnFailureListener(e -> logEvent("Scenario failed", e));
98+
}
99+
100+
private void logEvent(String message) {
101+
logEvent(message, null);
102+
}
103+
104+
private void logEvent(String message, @Nullable Throwable throwable) {
105+
Log.e(TAG, message, throwable);
106+
logView.append("• ");
107+
logView.append(String.valueOf(message));
108+
logView.append("\n");
109+
if (throwable != null) {
110+
logView.append(throwable.getClass().getCanonicalName() + ": " + throwable.getMessage());
111+
logView.append("\n");
112+
logView.append(Arrays.toString(throwable.getStackTrace()));
113+
logView.append("\n");
114+
}
115+
}
116+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Copyright 2023 The TensorFlow Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.samples.gms.tflite.c;
18+
19+
import android.content.res.AssetManager;
20+
21+
/** JNI bridge to forward the calls to the native code, where we can invoke the TFLite C API. */
22+
public class TfLiteJni {
23+
24+
private final LoggingCallback loggingCallback;
25+
26+
/**
27+
* This interface gets called when the JNI wants to print a message (used for debugging purposes).
28+
*/
29+
public interface LoggingCallback {
30+
void printLogMessage(String message);
31+
}
32+
33+
static {
34+
System.loadLibrary("tflite-jni");
35+
}
36+
37+
public TfLiteJni(LoggingCallback loggingCallback) {
38+
this.loggingCallback = loggingCallback;
39+
}
40+
41+
private void sendLogMessage(String message) {
42+
if (loggingCallback != null) {
43+
loggingCallback.printLogMessage(message);
44+
}
45+
}
46+
47+
/** Creates GPU delegate that will be used for the inference. */
48+
public native void initGpuAcceleration();
49+
50+
/**
51+
* Loads the model and creates the Interpreter. GPU delegate is applied if {@link
52+
* TfLiteJni#initGpuAcceleration} was previously called.
53+
*/
54+
public native void loadModel(AssetManager assetManager, String assetName);
55+
56+
/** Runs the inference using the Interpreter created by {@link TfLiteJni#loadModel}. */
57+
public native float[] runInference(float[] input);
58+
59+
/** Unloads the assets and clears all the Interpreter's resources. */
60+
public native void destroy();
61+
}

0 commit comments

Comments
 (0)