Skip to content

Commit

Permalink
Add a mock perf test for llama2 on Android (#2963)
Browse files Browse the repository at this point in the history
Summary:
I'm trying to set up a simple perf test for running llama2 on Android. It naively sends a prompt and records the TPS. Open for comments about the test here before setting this up on CI.

### Testing

Copy the exported model and the tokenizer as usual, then cd to the app and run `./gradlew :app:connectAndroidTest`.  The test will fail if the model fails to load or if the TPS is lower than 7, as measured by https://github.com/pytorch/executorch/tree/main/examples/models/llama2

Pull Request resolved: #2963

Reviewed By: kirklandsign

Differential Revision: D55951637

Pulled By: huydhn

fbshipit-source-id: 34c189aefd7e31514fcf49103352ef3cf8e5b2c9
  • Loading branch information
huydhn authored and facebook-github-bot committed Apr 11, 2024
1 parent 2fc99b0 commit d761f99
Showing 1 changed file with 65 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package com.example.executorchllamademo;

import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;

import androidx.test.ext.junit.runners.AndroidJUnit4;
import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

@RunWith(AndroidJUnit4.class)
public class PerfTest implements LlamaCallback {

  /** Directory on the device where the model and tokenizer are pushed before the test runs. */
  private static final String RESOURCE_PATH = "/data/local/tmp/llama/";

  private static final String MODEL_NAME = "xnnpack_llama2.pte";
  private static final String TOKENIZER_BIN = "tokenizer.bin";

  // Minimum acceptable tokens-per-second. Baseline taken from
  // https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md
  // Use the primitive type to avoid needless autoboxing in the comparison below.
  private static final float EXPECTED_TPS = 7.0f;

  /** Generated text fragments collected via {@link #onResult(String)}. */
  private final List<String> results = new ArrayList<>();

  /** TPS measurements collected via {@link #onStats(float)}; the last entry is the final value. */
  private final List<Float> tokensPerSecond = new ArrayList<>();

  /**
   * Loads the model, runs one prompt, and asserts that the observed tokens-per-second
   * meets or exceeds {@link #EXPECTED_TPS}.
   */
  @Test
  public void testTokensPerSecond() {
    String modelPath = RESOURCE_PATH + MODEL_NAME;
    String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN;
    // 0.8f is the sampling temperature.
    LlamaModule module = new LlamaModule(modelPath, tokenizerPath, 0.8f);

    // Check that the model can be loaded successfully (0 == Error::Ok).
    int loadResult = module.load();
    assertEquals(0, loadResult);

    // Run a testing prompt; callbacks on this instance record results and stats.
    module.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this);
    assertFalse(tokensPerSecond.isEmpty());

    // Only the last reported TPS reflects the full generation run.
    final float tps = tokensPerSecond.get(tokensPerSecond.size() - 1);
    assertTrue(
        "The observed TPS " + tps + " is less than the expected TPS " + EXPECTED_TPS,
        tps >= EXPECTED_TPS);
  }

  /** Receives each generated token/string from the native runner. */
  @Override
  public void onResult(String result) {
    results.add(result);
  }

  /** Receives the tokens-per-second measurement reported by the native runner. */
  @Override
  public void onStats(float tps) {
    tokensPerSecond.add(tps);
  }
}

0 comments on commit d761f99

Please sign in to comment.