Not hardcode llama2 model in perf test
Differential Revision: D61057535

Pull Request resolved: #4657
huydhn committed Aug 12, 2024
1 parent e800626 commit d53f8fa
Showing 4 changed files with 53 additions and 33 deletions.
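
For orientation before the per-file diff: the perf test previously loaded a hardcoded xnnpack_llama2.pte and asserted a minimum tokens-per-second value; after this change it scans the on-device resource directory for whatever .pte files the test spec pushed, runs each one, and reports the observed numbers instead. A minimal sketch of that discovery step follows. RESOURCE_PATH and the .pte filter are taken from the diff below; the class name, the null check on listFiles(), and the runModel() helper are illustrative assumptions, not part of the commit.

import java.io.File;
import java.util.Arrays;

class ModelDiscoverySketch {
  // Same on-device directory the test spec populates over adb.
  private static final String RESOURCE_PATH = "/data/local/tmp/llama/";

  void runAllModels() {
    File[] files = new File(RESOURCE_PATH).listFiles();
    if (files == null) {
      // listFiles() returns null if the directory is missing or unreadable.
      return;
    }
    Arrays.stream(files)
        .filter(file -> file.getName().endsWith(".pte"))
        // runModel() stands in for the load/generate/report sequence
        // implemented in PerfTest.java below.
        .forEach(this::runModel);
  }

  void runModel(File model) {}
}
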
3 changes: 1 addition & 2 deletions .github/workflows/android-perf.yml
@@ -218,7 +218,6 @@ jobs:
# TODO: Hard code llm_demo_bpe for now in this job.
android-app-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_bpe/app-debug.apk
android-test-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_bpe/app-debug-androidTest.apk
-# The test spec can be downloaded from https://ossci-assets.s3.amazonaws.com/android-llama2-device-farm-test-spec.yml
-test-spec: arn:aws:devicefarm:us-west-2:308535385114:upload:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/abd86868-fa63-467e-a5c7-218194665a77
+test-spec: https://ossci-assets.s3.amazonaws.com/android-llm-device-farm-test-spec.yml
# Uploaded to S3 from the previous job
extra-data: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/${{ matrix.model }}_${{ matrix.delegate }}/model.zip
3 changes: 1 addition & 2 deletions .github/workflows/android.yml
@@ -170,8 +170,7 @@ jobs:
# Uploaded to S3 from the previous job, the name of the app comes from the project itself
android-app-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_${{ matrix.tokenizer }}/app-debug.apk
android-test-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_${{ matrix.tokenizer }}/app-debug-androidTest.apk
-# The test spec can be downloaded from https://ossci-assets.s3.amazonaws.com/android-llama2-device-farm-test-spec.yml
-test-spec: arn:aws:devicefarm:us-west-2:308535385114:upload:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/abd86868-fa63-467e-a5c7-218194665a77
+test-spec: https://ossci-assets.s3.amazonaws.com/android-llm-device-farm-test-spec.yml
# Among the input, this is the biggest file, so it is cached on AWS to make the test faster. Note that the file is deleted by AWS after 30
# days and the job will automatically re-upload the file when that happens.
extra-data: https://ossci-assets.s3.amazonaws.com/executorch-android-llama2-7b-0717.zip
@@ -11,10 +11,10 @@ phases:
# Prepare the model and the tokenizer
- adb -s $DEVICEFARM_DEVICE_UDID shell "ls -la /sdcard/"
- adb -s $DEVICEFARM_DEVICE_UDID shell "mkdir -p /data/local/tmp/llama/"
- adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/tokenizer.bin /data/local/tmp/llama/tokenizer.bin"
- adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/xnnpack_llama2.pte /data/local/tmp/llama/xnnpack_llama2.pte"
- adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/tokenizer.bin"
- adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/xnnpack_llama2.pte"
- adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/*.bin /data/local/tmp/llama/"
- adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/*.pte /data/local/tmp/llama/"
- adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/*.bin"
- adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/*.pte"
- adb -s $DEVICEFARM_DEVICE_UDID shell "ls -la /data/local/tmp/llama/"

test:
@@ -50,14 +50,8 @@ phases:
false;
elif [ $TESTS_FAILED -ne 0 ];
then
OBSERVED_TPS=$(grep "The observed TPS " $INSTRUMENT_LOG | tail -n 1)
if [ -n "${OBSERVED_TPS}" ];
then
echo "[PyTorch] ${OBSERVED_TPS}";
else
echo "[PyTorch] Marking the test suite as failed because it failed to load the model";
fi
echo "[PyTorch] Marking the test suite as failed because it failed to load the model";
false;
elif [ $TESTS_ERRORED -ne 0 ];
then
echo "[PyTorch] Marking the test suite as failed because $TESTS_ERRORED tests errored!";
@@ -66,6 +60,17 @@
then
echo "[PyTorch] Marking the test suite as failed because the app crashed due to OOM!";
false;
+# Check for this last to make sure that there is no failure
+elif [ $TESTS_PASSED -ne 0 ];
+then
+OBSERVED_TPS=$(grep "INSTRUMENTATION_STATUS: TPS=" $INSTRUMENT_LOG | tail -n 1)
+if [ -n "${OBSERVED_TPS}" ];
+then
+echo "[PyTorch] ${OBSERVED_TPS}";
+else
+echo "[PyTorch] Test passes but couldn't find the observed TPS from instrument log";
+fi
fi;
post_test:
@@ -8,12 +8,15 @@

package com.example.executorchllamademo;

-import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;

+import android.os.Bundle;
import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.platform.app.InstrumentationRegistry;
+import java.io.File;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -24,33 +27,35 @@
public class PerfTest implements LlamaCallback {

private static final String RESOURCE_PATH = "/data/local/tmp/llama/";
-private static final String MODEL_NAME = "xnnpack_llama2.pte";
private static final String TOKENIZER_BIN = "tokenizer.bin";

-// From https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md
-private static final Float EXPECTED_TPS = 10.0F;

private final List<String> results = new ArrayList<>();
private final List<Float> tokensPerSecond = new ArrayList<>();

@Test
public void testTokensPerSecond() {
-String modelPath = RESOURCE_PATH + MODEL_NAME;
String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN;
-LlamaModule mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f);
+// Find out the model name
+File directory = new File(RESOURCE_PATH);
+Arrays.stream(directory.listFiles())
+.filter(file -> file.getName().endsWith(".pte"))
+.forEach(
+model -> {
+LlamaModule mModule = new LlamaModule(model.getPath(), tokenizerPath, 0.8f);
+// Print the model name because there might be more than one of them
+report("ModelName", model.getName());

-int loadResult = mModule.load();
-// Check that the model can be load successfully
-assertEquals(0, loadResult);
+int loadResult = mModule.load();
+// Check that the model can be load successfully
+assertEquals(0, loadResult);

-// Run a testing prompt
-mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this);
-assertFalse(tokensPerSecond.isEmpty());
+// Run a testing prompt
+mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this);
+assertFalse(tokensPerSecond.isEmpty());

-final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1);
-assertTrue(
-"The observed TPS " + tps + " is less than the expected TPS " + EXPECTED_TPS,
-tps >= EXPECTED_TPS);
+final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1);
+report("TPS", tps);
+});
}

@Override
@@ -62,4 +67,16 @@ public void onResult(String result) {
public void onStats(float tps) {
tokensPerSecond.add(tps);
}

+private void report(final String metric, final Float value) {
+Bundle bundle = new Bundle();
+bundle.putFloat(metric, value);
+InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle);
+}
+
+private void report(final String key, final String value) {
+Bundle bundle = new Bundle();
+bundle.putString(key, value);
+InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle);
+}
}
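
A note on how the reported values travel from the Java test to the shell checks in the device-farm test spec above: Instrumentation.sendStatus(0, bundle) makes each bundle entry appear in the am instrument output as an INSTRUMENTATION_STATUS line, which is what the spec's grep "INSTRUMENTATION_STATUS: TPS=" picks up. The sketch below mirrors PerfTest.report(String, Float); the ReportSketch class is illustrative, and the log lines quoted in the comment reflect typical am instrument -r formatting rather than something guaranteed by this commit.

import android.os.Bundle;
import androidx.test.platform.app.InstrumentationRegistry;

final class ReportSketch {
  // Each call is expected to surface in the instrumentation log roughly as:
  //   INSTRUMENTATION_STATUS: TPS=12.34
  //   INSTRUMENTATION_STATUS_CODE: 0
  // so the test spec can recover the last observed value with
  //   grep "INSTRUMENTATION_STATUS: TPS=" $INSTRUMENT_LOG | tail -n 1
  static void report(String metric, float value) {
    Bundle bundle = new Bundle();
    bundle.putFloat(metric, value);
    InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle);
  }
}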
