Skip to content

Commit d53f8fa

Browse files
authored
Not hardcode llama2 model in perf test
Differential Revision: D61057535 Pull Request resolved: #4657
1 parent e800626 commit d53f8fa

File tree

4 files changed

+53
-33
lines changed

4 files changed

+53
-33
lines changed

.github/workflows/android-perf.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,6 @@ jobs:
218218
# TODO: Hard code llm_demo_bpe for now in this job.
219219
android-app-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_bpe/app-debug.apk
220220
android-test-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_bpe/app-debug-androidTest.apk
221-
# The test spec can be downloaded from https://ossci-assets.s3.amazonaws.com/android-llama2-device-farm-test-spec.yml
222-
test-spec: arn:aws:devicefarm:us-west-2:308535385114:upload:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/abd86868-fa63-467e-a5c7-218194665a77
221+
test-spec: https://ossci-assets.s3.amazonaws.com/android-llm-device-farm-test-spec.yml
223222
# Uploaded to S3 from the previous job
224223
extra-data: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/${{ matrix.model }}_${{ matrix.delegate }}/model.zip

.github/workflows/android.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,7 @@ jobs:
170170
# Uploaded to S3 from the previous job, the name of the app comes from the project itself
171171
android-app-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_${{ matrix.tokenizer }}/app-debug.apk
172172
android-test-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_${{ matrix.tokenizer }}/app-debug-androidTest.apk
173-
# The test spec can be downloaded from https://ossci-assets.s3.amazonaws.com/android-llama2-device-farm-test-spec.yml
174-
test-spec: arn:aws:devicefarm:us-west-2:308535385114:upload:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/abd86868-fa63-467e-a5c7-218194665a77
173+
test-spec: https://ossci-assets.s3.amazonaws.com/android-llm-device-farm-test-spec.yml
175174
# Among the input, this is the biggest file, so it is cached on AWS to make the test faster. Note that the file is deleted by AWS after 30
176175
# days and the job will automatically re-upload the file when that happens.
177176
extra-data: https://ossci-assets.s3.amazonaws.com/executorch-android-llama2-7b-0717.zip

examples/demo-apps/android/LlamaDemo/android-llama2-device-farm-test-spec.yml renamed to examples/demo-apps/android/LlamaDemo/android-llm-device-farm-test-spec.yml

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ phases:
1111
# Prepare the model and the tokenizer
1212
- adb -s $DEVICEFARM_DEVICE_UDID shell "ls -la /sdcard/"
1313
- adb -s $DEVICEFARM_DEVICE_UDID shell "mkdir -p /data/local/tmp/llama/"
14-
- adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/tokenizer.bin /data/local/tmp/llama/tokenizer.bin"
15-
- adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/xnnpack_llama2.pte /data/local/tmp/llama/xnnpack_llama2.pte"
16-
- adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/tokenizer.bin"
17-
- adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/xnnpack_llama2.pte"
14+
- adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/*.bin /data/local/tmp/llama/"
15+
- adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/*.pte /data/local/tmp/llama/"
16+
- adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/*.bin"
17+
- adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/*.pte"
1818
- adb -s $DEVICEFARM_DEVICE_UDID shell "ls -la /data/local/tmp/llama/"
1919

2020
test:
@@ -50,14 +50,8 @@ phases:
5050
false;
5151
elif [ $TESTS_FAILED -ne 0 ];
5252
then
53-
OBSERVED_TPS=$(grep "The observed TPS " $INSTRUMENT_LOG | tail -n 1)
54-
55-
if [ -n "${OBSERVED_TPS}" ];
56-
then
57-
echo "[PyTorch] ${OBSERVED_TPS}";
58-
else
59-
echo "[PyTorch] Marking the test suite as failed because it failed to load the model";
60-
fi
53+
echo "[PyTorch] Marking the test suite as failed because it failed to load the model";
54+
false;
6155
elif [ $TESTS_ERRORED -ne 0 ];
6256
then
6357
echo "[PyTorch] Marking the test suite as failed because $TESTS_ERRORED tests errored!";
@@ -66,6 +60,17 @@ phases:
6660
then
6761
echo "[PyTorch] Marking the test suite as failed because the app crashed due to OOM!";
6862
false;
63+
# Check for this last to make sure that there is no failure
64+
elif [ $TESTS_PASSED -ne 0 ];
65+
then
66+
OBSERVED_TPS=$(grep "INSTRUMENTATION_STATUS: TPS=" $INSTRUMENT_LOG | tail -n 1)
67+
68+
if [ -n "${OBSERVED_TPS}" ];
69+
then
70+
echo "[PyTorch] ${OBSERVED_TPS}";
71+
else
72+
echo "[PyTorch] Test passes but couldn't find the observed TPS from instrument log";
73+
fi
6974
fi;
7075
7176
post_test:

examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88

99
package com.example.executorchllamademo;
1010

11-
import static junit.framework.TestCase.assertTrue;
1211
import static org.junit.Assert.assertEquals;
1312
import static org.junit.Assert.assertFalse;
1413

14+
import android.os.Bundle;
1515
import androidx.test.ext.junit.runners.AndroidJUnit4;
16+
import androidx.test.platform.app.InstrumentationRegistry;
17+
import java.io.File;
1618
import java.util.ArrayList;
19+
import java.util.Arrays;
1720
import java.util.List;
1821
import org.junit.Test;
1922
import org.junit.runner.RunWith;
@@ -24,33 +27,35 @@
2427
public class PerfTest implements LlamaCallback {
2528

2629
private static final String RESOURCE_PATH = "/data/local/tmp/llama/";
27-
private static final String MODEL_NAME = "xnnpack_llama2.pte";
2830
private static final String TOKENIZER_BIN = "tokenizer.bin";
2931

30-
// From https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md
31-
private static final Float EXPECTED_TPS = 10.0F;
32-
3332
private final List<String> results = new ArrayList<>();
3433
private final List<Float> tokensPerSecond = new ArrayList<>();
3534

3635
@Test
3736
public void testTokensPerSecond() {
38-
String modelPath = RESOURCE_PATH + MODEL_NAME;
3937
String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN;
40-
LlamaModule mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f);
38+
// Find out the model name
39+
File directory = new File(RESOURCE_PATH);
40+
Arrays.stream(directory.listFiles())
41+
.filter(file -> file.getName().endsWith(".pte"))
42+
.forEach(
43+
model -> {
44+
LlamaModule mModule = new LlamaModule(model.getPath(), tokenizerPath, 0.8f);
45+
// Print the model name because there might be more than one of them
46+
report("ModelName", model.getName());
4147

42-
int loadResult = mModule.load();
43-
// Check that the model can be load successfully
44-
assertEquals(0, loadResult);
48+
int loadResult = mModule.load();
49+
// Check that the model can be load successfully
50+
assertEquals(0, loadResult);
4551

46-
// Run a testing prompt
47-
mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this);
48-
assertFalse(tokensPerSecond.isEmpty());
52+
// Run a testing prompt
53+
mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this);
54+
assertFalse(tokensPerSecond.isEmpty());
4955

50-
final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1);
51-
assertTrue(
52-
"The observed TPS " + tps + " is less than the expected TPS " + EXPECTED_TPS,
53-
tps >= EXPECTED_TPS);
56+
final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1);
57+
report("TPS", tps);
58+
});
5459
}
5560

5661
@Override
@@ -62,4 +67,16 @@ public void onResult(String result) {
6267
public void onStats(float tps) {
6368
tokensPerSecond.add(tps);
6469
}
70+
71+
private void report(final String metric, final Float value) {
72+
Bundle bundle = new Bundle();
73+
bundle.putFloat(metric, value);
74+
InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle);
75+
}
76+
77+
private void report(final String key, final String value) {
78+
Bundle bundle = new Bundle();
79+
bundle.putString(key, value);
80+
InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle);
81+
}
6582
}

0 commit comments

Comments
 (0)