Do not hardcode llama2 model in perf test #4657
@@ -11,10 +11,12 @@ phases:
    # Prepare the model and the tokenizer
    - adb -s $DEVICEFARM_DEVICE_UDID shell "ls -la /sdcard/"
    - adb -s $DEVICEFARM_DEVICE_UDID shell "mkdir -p /data/local/tmp/llama/"
-   - adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/tokenizer.bin /data/local/tmp/llama/tokenizer.bin"
Review comment: Oh, I just noticed that this file actually has a local copy stored in the demo-apps dir! Curious to know how it is uploaded to S3 so that the link https://ossci-assets.s3.amazonaws.com/android-llama2-device-farm-test-spec-v2.yml could work?

Reply: Oh, the upload is still done manually. I plan to write a workflow to automatically upload it to S3 in the next PR, so that we can just update this file.
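(For context, until that workflow exists the upload presumably looks something like the sketch below — the bucket and object key are inferred from the URL above, and AWS CLI access is an assumption:)

```bash
# Hypothetical manual upload of the test spec; bucket/key inferred from the
# ossci-assets URL in the comment above, credentials are assumed.
aws s3 cp android-llama2-device-farm-test-spec-v2.yml \
  s3://ossci-assets/android-llama2-device-farm-test-spec-v2.yml
```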
-   - adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/xnnpack_llama2.pte /data/local/tmp/llama/xnnpack_llama2.pte"
-   - adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/tokenizer.bin"
-   - adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/xnnpack_llama2.pte"
+   - adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/*.bin /data/local/tmp/llama/"
+   - adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/*.pte /data/local/tmp/llama/"
+   - adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/*.pt /data/local/tmp/llama/"
+   - adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/*.bin"
+   - adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/*.pte"
+   - adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/*.pt"
Review comment: We probably won't need this line.

Review comment: Is there a bug here in
    - adb -s $DEVICEFARM_DEVICE_UDID shell "ls -la /data/local/tmp/llama/"
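A note on the glob-based moves above: if a pattern matches nothing on the device, `mv` receives the literal glob and reports an error. Not part of this PR, but a defensive variant could guard each extension — a sketch, assuming a modern adb that propagates the remote shell's exit status:

```bash
# Sketch: only move extensions that actually exist on /sdcard, so that an
# absent *.pt (for example) does not make mv complain.
for ext in bin pte pt; do
  if adb -s "$DEVICEFARM_DEVICE_UDID" shell "ls /sdcard/*.${ext}" >/dev/null 2>&1; then
    adb -s "$DEVICEFARM_DEVICE_UDID" shell "mv /sdcard/*.${ext} /data/local/tmp/llama/"
    adb -s "$DEVICEFARM_DEVICE_UDID" shell "chmod 664 /data/local/tmp/llama/*.${ext}"
  fi
done
```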
test: | ||
@@ -50,14 +52,8 @@ phases:
        false;
      elif [ $TESTS_FAILED -ne 0 ];
      then
-       OBSERVED_TPS=$(grep "The observed TPS " $INSTRUMENT_LOG | tail -n 1)
-
-       if [ -n "${OBSERVED_TPS}" ];
-       then
-         echo "[PyTorch] ${OBSERVED_TPS}";
-       else
-         echo "[PyTorch] Marking the test suite as failed because it failed to load the model";
-       fi
+       echo "[PyTorch] Marking the test suite as failed because it failed to load the model";
        false;
      elif [ $TESTS_ERRORED -ne 0 ];
      then
        echo "[PyTorch] Marking the test suite as failed because $TESTS_ERRORED tests errored!";
@@ -66,6 +62,17 @@ phases:
      then
        echo "[PyTorch] Marking the test suite as failed because the app crashed due to OOM!";
        false;
+     # Check for this last to make sure that there is no failure
+     elif [ $TESTS_PASSED -ne 0 ];
+     then
+       OBSERVED_TPS=$(grep "INSTRUMENTATION_STATUS: TPS=" $INSTRUMENT_LOG | tail -n 1)
+
+       if [ -n "${OBSERVED_TPS}" ];
+       then
+         echo "[PyTorch] ${OBSERVED_TPS}";
+       else
+         echo "[PyTorch] Test passes but couldn't find the observed TPS from instrument log";
+       fi
      fi;

  post_test:
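The `INSTRUMENTATION_STATUS: TPS=` line that the new success branch greps for is produced by the test's `sendStatus` call (see the Java changes below). As a standalone illustration of that extraction step — a minimal sketch, assuming the instrument log has been saved locally as `instrument.log` (a hypothetical path):

```bash
#!/bin/bash
INSTRUMENT_LOG="instrument.log"  # hypothetical local copy of the Device Farm instrument log

# Mirror the spec: keep only the most recent TPS status line, if any.
OBSERVED_TPS=$(grep "INSTRUMENTATION_STATUS: TPS=" "$INSTRUMENT_LOG" | tail -n 1)

if [ -n "${OBSERVED_TPS}" ]; then
  echo "[PyTorch] ${OBSERVED_TPS}"
else
  echo "[PyTorch] Test passes but couldn't find the observed TPS from instrument log"
fi
```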
PerfTest.java
@@ -8,12 +8,15 @@

  package com.example.executorchllamademo;

- import static junit.framework.TestCase.assertTrue;
  import static org.junit.Assert.assertEquals;
  import static org.junit.Assert.assertFalse;

+ import android.os.Bundle;
  import androidx.test.ext.junit.runners.AndroidJUnit4;
+ import androidx.test.platform.app.InstrumentationRegistry;
+ import java.io.File;
  import java.util.ArrayList;
+ import java.util.Arrays;
  import java.util.List;
  import org.junit.Test;
  import org.junit.runner.RunWith;
@@ -24,33 +27,35 @@
  public class PerfTest implements LlamaCallback {

    private static final String RESOURCE_PATH = "/data/local/tmp/llama/";
-   private static final String MODEL_NAME = "xnnpack_llama2.pte";
    private static final String TOKENIZER_BIN = "tokenizer.bin";

-   // From https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md
-   private static final Float EXPECTED_TPS = 10.0F;
-
    private final List<String> results = new ArrayList<>();
    private final List<Float> tokensPerSecond = new ArrayList<>();

    @Test
    public void testTokensPerSecond() {
-     String modelPath = RESOURCE_PATH + MODEL_NAME;
      String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN;
-     LlamaModule mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f);
+     // Find out the model name
+     File directory = new File(RESOURCE_PATH);
+     Arrays.stream(directory.listFiles())
+         .filter(file -> file.getName().endsWith(".pte") || file.getName().endsWith(".pt"))
Review comment: We don't need

Reply: I see, let me remove this, but it looks like something is not right where the https://gha-artifacts.s3.amazonaws.com/pytorch/executorch/10332366636/artifact/stories110M_xnnpack/model.zip contains only

Reply: yeah, the fix is in #4642

Reply: So you merge the test spec fix and ignore the "failed to load model" issue?

Reply: Oh, got it, let me push a commit to fix the other comments and land this PR
+         .forEach(
+             model -> {
+               LlamaModule mModule = new LlamaModule(model.getPath(), tokenizerPath, 0.8f);
+               // Print the model name because there might be more than one of them
+               report("ModelName", model.getName());

-     int loadResult = mModule.load();
-     // Check that the model can be load successfully
-     assertEquals(0, loadResult);
+               int loadResult = mModule.load();
+               // Check that the model can be loaded successfully
+               assertEquals(0, loadResult);
Review comment on lines +48 to +50: linter
-     // Run a testing prompt
-     mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this);
-     assertFalse(tokensPerSecond.isEmpty());
+               // Run a testing prompt
+               mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this);
+               assertFalse(tokensPerSecond.isEmpty());

-     final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1);
-     assertTrue(
-         "The observed TPS " + tps + " is less than the expected TPS " + EXPECTED_TPS,
-         tps >= EXPECTED_TPS);
+               final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1);
+               report("TPS", tps);
+             });
    }

    @Override
@@ -62,4 +67,16 @@ public void onResult(String result) {
    public void onStats(float tps) {
      tokensPerSecond.add(tps);
    }
+
+   private void report(final String metric, final Float value) {
+     Bundle bundle = new Bundle();
+     bundle.putFloat(metric, value);
+     InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle);
+   }
+
+   private void report(final String key, final String value) {
+     Bundle bundle = new Bundle();
+     bundle.putString(key, value);
+     InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle);
+   }
  }
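Each Bundle entry passed to `sendStatus` is printed by `am instrument -w` as an `INSTRUMENTATION_STATUS: key=value` line, which is exactly what the updated test spec greps for. A sketch of exercising this on a connected device — the test APK's `.test` package suffix and the AndroidJUnitRunner class are assumed defaults, not taken from this diff:

```bash
# Run the perf test and stream instrumentation status lines; the instrumentation
# component name below is an assumption based on the app package in this diff.
adb shell am instrument -w \
  -e class com.example.executorchllamademo.PerfTest \
  com.example.executorchllamademo.test/androidx.test.runner.AndroidJUnitRunner

# Illustrative output the test spec consumes (values are made up):
#   INSTRUMENTATION_STATUS: ModelName=stories110m.pte
#   INSTRUMENTATION_STATUS: TPS=12.3
```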
Review comment: Can we rename it to something more generic, e.g. android-llm-device-farm-test-spec-v2.yml?