diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml index a8223eef2c..4f8b216a54 100644 --- a/.github/workflows/android-perf.yml +++ b/.github/workflows/android-perf.yml @@ -218,7 +218,6 @@ jobs: # TODO: Hard code llm_demo_bpe for now in this job. android-app-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_bpe/app-debug.apk android-test-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_bpe/app-debug-androidTest.apk - # The test spec can be downloaded from https://ossci-assets.s3.amazonaws.com/android-llama2-device-farm-test-spec.yml - test-spec: arn:aws:devicefarm:us-west-2:308535385114:upload:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/abd86868-fa63-467e-a5c7-218194665a77 + test-spec: https://ossci-assets.s3.amazonaws.com/android-llm-device-farm-test-spec.yml # Uploaded to S3 from the previous job extra-data: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/${{ matrix.model }}_${{ matrix.delegate }}/model.zip diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 7b3d8ab9a8..5af09dc490 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -170,8 +170,7 @@ jobs: # Uploaded to S3 from the previous job, the name of the app comes from the project itself android-app-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_${{ matrix.tokenizer }}/app-debug.apk android-test-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_${{ matrix.tokenizer }}/app-debug-androidTest.apk - # The test spec can be downloaded from https://ossci-assets.s3.amazonaws.com/android-llama2-device-farm-test-spec.yml - test-spec: arn:aws:devicefarm:us-west-2:308535385114:upload:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/abd86868-fa63-467e-a5c7-218194665a77 + test-spec: https://ossci-assets.s3.amazonaws.com/android-llm-device-farm-test-spec.yml # Among the input, this is the biggest file, so it is cached on AWS to make the test faster. Note that the file is deleted by AWS after 30 # days and the job will automatically re-upload the file when that happens. extra-data: https://ossci-assets.s3.amazonaws.com/executorch-android-llama2-7b-0717.zip diff --git a/examples/demo-apps/android/LlamaDemo/android-llama2-device-farm-test-spec.yml b/examples/demo-apps/android/LlamaDemo/android-llm-device-farm-test-spec.yml similarity index 81% rename from examples/demo-apps/android/LlamaDemo/android-llama2-device-farm-test-spec.yml rename to examples/demo-apps/android/LlamaDemo/android-llm-device-farm-test-spec.yml index 4df9f18cc5..cac83b8e6f 100644 --- a/examples/demo-apps/android/LlamaDemo/android-llama2-device-farm-test-spec.yml +++ b/examples/demo-apps/android/LlamaDemo/android-llm-device-farm-test-spec.yml @@ -11,10 +11,10 @@ phases: # Prepare the model and the tokenizer - adb -s $DEVICEFARM_DEVICE_UDID shell "ls -la /sdcard/" - adb -s $DEVICEFARM_DEVICE_UDID shell "mkdir -p /data/local/tmp/llama/" - - adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/tokenizer.bin /data/local/tmp/llama/tokenizer.bin" - - adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/xnnpack_llama2.pte /data/local/tmp/llama/xnnpack_llama2.pte" - - adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/tokenizer.bin" - - adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/xnnpack_llama2.pte" + - adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/*.bin /data/local/tmp/llama/" + - adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/*.pte /data/local/tmp/llama/" + - adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/*.bin" + - adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/*.pte" - adb -s $DEVICEFARM_DEVICE_UDID shell "ls -la /data/local/tmp/llama/" test: @@ -50,14 +50,8 @@ phases: false; elif [ $TESTS_FAILED -ne 0 ]; then - OBSERVED_TPS=$(grep "The observed TPS " $INSTRUMENT_LOG | tail -n 1) - - if [ -n "${OBSERVED_TPS}" ]; - then - echo "[PyTorch] ${OBSERVED_TPS}"; - else - echo "[PyTorch] Marking the test suite as failed because it failed to load the model"; - fi + echo "[PyTorch] Marking the test suite as failed because it failed to load the model"; + false; elif [ $TESTS_ERRORED -ne 0 ]; then echo "[PyTorch] Marking the test suite as failed because $TESTS_ERRORED tests errored!"; @@ -66,6 +60,17 @@ phases: then echo "[PyTorch] Marking the test suite as failed because the app crashed due to OOM!"; false; + # Check for this last to make sure that there is no failure + elif [ $TESTS_PASSED -ne 0 ]; + then + OBSERVED_TPS=$(grep "INSTRUMENTATION_STATUS: TPS=" $INSTRUMENT_LOG | tail -n 1) + + if [ -n "${OBSERVED_TPS}" ]; + then + echo "[PyTorch] ${OBSERVED_TPS}"; + else + echo "[PyTorch] Test passes but couldn't find the observed TPS from instrument log"; + fi fi; post_test: diff --git a/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java b/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java index b8988d1f4b..221a9bd741 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java @@ -8,12 +8,15 @@ package com.example.executorchllamademo; -import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import android.os.Bundle; import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.platform.app.InstrumentationRegistry; +import java.io.File; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -24,33 +27,35 @@ public class PerfTest implements LlamaCallback { private static final String RESOURCE_PATH = "/data/local/tmp/llama/"; - private static final String MODEL_NAME = "xnnpack_llama2.pte"; private static final String TOKENIZER_BIN = "tokenizer.bin"; - // From https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md - private static final Float EXPECTED_TPS = 10.0F; - private final List results = new ArrayList<>(); private final List tokensPerSecond = new ArrayList<>(); @Test public void testTokensPerSecond() { - String modelPath = RESOURCE_PATH + MODEL_NAME; String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN; - LlamaModule mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f); + // Find out the model name + File directory = new File(RESOURCE_PATH); + Arrays.stream(directory.listFiles()) + .filter(file -> file.getName().endsWith(".pte")) + .forEach( + model -> { + LlamaModule mModule = new LlamaModule(model.getPath(), tokenizerPath, 0.8f); + // Print the model name because there might be more than one of them + report("ModelName", model.getName()); - int loadResult = mModule.load(); - // Check that the model can be load successfully - assertEquals(0, loadResult); + int loadResult = mModule.load(); + // Check that the model can be load successfully + assertEquals(0, loadResult); - // Run a testing prompt - mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this); - assertFalse(tokensPerSecond.isEmpty()); + // Run a testing prompt + mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this); + assertFalse(tokensPerSecond.isEmpty()); - final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1); - assertTrue( - "The observed TPS " + tps + " is less than the expected TPS " + EXPECTED_TPS, - tps >= EXPECTED_TPS); + final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1); + report("TPS", tps); + }); } @Override @@ -62,4 +67,16 @@ public void onResult(String result) { public void onStats(float tps) { tokensPerSecond.add(tps); } + + private void report(final String metric, final Float value) { + Bundle bundle = new Bundle(); + bundle.putFloat(metric, value); + InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle); + } + + private void report(final String key, final String value) { + Bundle bundle = new Bundle(); + bundle.putString(key, value); + InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle); + } }