-
Notifications
You must be signed in to change notification settings - Fork 295
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Not hardcode llama2 model in perf test #4657
Merged
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
d87872b
Not hardcode llama2 model in perf test
huydhn 820b50b
Also include the new spec
huydhn b64f631
Debug
huydhn 6d41809
Fix compilation error
huydhn b92e471
Debug print
huydhn 55024d8
Merge branch 'main' into hack-the-tps-value
huydhn 3f69c55
Attempt to use sendStatus
huydhn 3c084ec
Report TPS via instrument status
huydhn 9372c8f
Just use the spec from S3
huydhn 8a8aefe
It's working
huydhn 56508c4
Also print the model name
huydhn 51dfaf3
Fix lint
huydhn 2eff98c
Minor tweak
huydhn 1c752fb
Fix internal lint
huydhn 47e413b
Update spec one more time
huydhn c8b5c94
Check for passing test last
huydhn 2d3c2c0
Lint yet
huydhn d08a5eb
Address review comments
huydhn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,12 +8,15 @@ | |
|
||
package com.example.executorchllamademo; | ||
|
||
import static junit.framework.TestCase.assertTrue; | ||
import static org.junit.Assert.assertEquals; | ||
import static org.junit.Assert.assertFalse; | ||
|
||
import android.os.Bundle; | ||
import androidx.test.ext.junit.runners.AndroidJUnit4; | ||
import androidx.test.platform.app.InstrumentationRegistry; | ||
import java.io.File; | ||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.List; | ||
import org.junit.Test; | ||
import org.junit.runner.RunWith; | ||
|
@@ -24,33 +27,35 @@ | |
public class PerfTest implements LlamaCallback { | ||
|
||
private static final String RESOURCE_PATH = "/data/local/tmp/llama/"; | ||
private static final String MODEL_NAME = "xnnpack_llama2.pte"; | ||
private static final String TOKENIZER_BIN = "tokenizer.bin"; | ||
|
||
// From https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md | ||
private static final Float EXPECTED_TPS = 10.0F; | ||
|
||
private final List<String> results = new ArrayList<>(); | ||
private final List<Float> tokensPerSecond = new ArrayList<>(); | ||
|
||
@Test | ||
public void testTokensPerSecond() { | ||
String modelPath = RESOURCE_PATH + MODEL_NAME; | ||
String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN; | ||
LlamaModule mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f); | ||
// Find out the model name | ||
File directory = new File(RESOURCE_PATH); | ||
Arrays.stream(directory.listFiles()) | ||
.filter(file -> file.getName().endsWith(".pte")) | ||
.forEach( | ||
model -> { | ||
LlamaModule mModule = new LlamaModule(model.getPath(), tokenizerPath, 0.8f); | ||
// Print the model name because there might be more than one of them | ||
report("ModelName", model.getName()); | ||
|
||
int loadResult = mModule.load(); | ||
// Check that the model can be load successfully | ||
assertEquals(0, loadResult); | ||
int loadResult = mModule.load(); | ||
// Check that the model can be load successfully | ||
assertEquals(0, loadResult); | ||
Comment on lines
+48
to
+50
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. linter |
||
|
||
// Run a testing prompt | ||
mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this); | ||
assertFalse(tokensPerSecond.isEmpty()); | ||
// Run a testing prompt | ||
mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this); | ||
assertFalse(tokensPerSecond.isEmpty()); | ||
|
||
final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1); | ||
assertTrue( | ||
"The observed TPS " + tps + " is less than the expected TPS " + EXPECTED_TPS, | ||
tps >= EXPECTED_TPS); | ||
final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1); | ||
report("TPS", tps); | ||
}); | ||
} | ||
|
||
@Override | ||
|
@@ -62,4 +67,16 @@ public void onResult(String result) { | |
public void onStats(float tps) { | ||
tokensPerSecond.add(tps); | ||
} | ||
|
||
private void report(final String metric, final Float value) { | ||
Bundle bundle = new Bundle(); | ||
bundle.putFloat(metric, value); | ||
InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle); | ||
} | ||
|
||
private void report(final String key, final String value) { | ||
Bundle bundle = new Bundle(); | ||
bundle.putString(key, value); | ||
InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle); | ||
} | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replace "llama" with "llm" for those paths as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I remember this path
/data/local/tmp/llama/
was hard-coded in the app before. If that's not the case anymore, I could update it, maybe after #4676 lands to avoid the need to upload the spec manually