@@ -8,12 +8,15 @@
 
 package com.example.executorchllamademo;
 
-import static junit.framework.TestCase.assertTrue;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 
+import android.os.Bundle;
 import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.platform.app.InstrumentationRegistry;
+import java.io.File;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -24,33 +27,35 @@
 public class PerfTest implements LlamaCallback {
 
   private static final String RESOURCE_PATH = "/data/local/tmp/llama/";
-  private static final String MODEL_NAME = "xnnpack_llama2.pte";
   private static final String TOKENIZER_BIN = "tokenizer.bin";
 
-  // From https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md
-  private static final Float EXPECTED_TPS = 10.0F;
-
   private final List<String> results = new ArrayList<>();
   private final List<Float> tokensPerSecond = new ArrayList<>();
 
   @Test
   public void testTokensPerSecond() {
-    String modelPath = RESOURCE_PATH + MODEL_NAME;
     String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN;
-    LlamaModule mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f);
+    // Find out the model name
+    File directory = new File(RESOURCE_PATH);
+    Arrays.stream(directory.listFiles())
+        .filter(file -> file.getName().endsWith(".pte"))
+        .forEach(
+            model -> {
+              LlamaModule mModule = new LlamaModule(model.getPath(), tokenizerPath, 0.8f);
+              // Report the model name because there might be more than one model under test
+              report("ModelName", model.getName());
 
-    int loadResult = mModule.load();
-    // Check that the model can be load successfully
-    assertEquals(0, loadResult);
+              int loadResult = mModule.load();
+              // Check that the model can be loaded successfully
+              assertEquals(0, loadResult);
 
-    // Run a testing prompt
-    mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this);
-    assertFalse(tokensPerSecond.isEmpty());
+              // Run a test prompt
+              mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this);
+              assertFalse(tokensPerSecond.isEmpty());
 
-    final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1);
-    assertTrue(
-        "The observed TPS " + tps + " is less than the expected TPS " + EXPECTED_TPS,
-        tps >= EXPECTED_TPS);
+              final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1);
+              report("TPS", tps);
+            });
   }
 
   @Override
@@ -62,4 +67,16 @@ public void onResult(String result) {
   public void onStats(float tps) {
     tokensPerSecond.add(tps);
   }
+
+  private void report(final String metric, final Float value) {
+    Bundle bundle = new Bundle();
+    bundle.putFloat(metric, value);
+    InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle);
+  }
+
+  private void report(final String key, final String value) {
+    Bundle bundle = new Bundle();
+    bundle.putString(key, value);
+    InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle);
+  }
 }
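
For readers without the rest of the sources: the callback contract this test implements can be inferred from the diff itself (onStats is visible above, and onResult appears in the last hunk's context line). A sketch of the assumed shape follows; the authoritative definition lives in the ExecuTorch Android bindings.

  // Assumed shape, reconstructed from this diff only; see the ExecuTorch
  // Android bindings for the real interface.
  public interface LlamaCallback {
    // Invoked with each newly generated piece of output.
    void onResult(String result);
    // Invoked with the measured generation speed in tokens per second.
    void onStats(float tps);
  }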
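
One hardening note on the new directory scan: File.listFiles() returns null when RESOURCE_PATH does not exist or is unreadable, which would surface here as a NullPointerException inside Arrays.stream rather than a clear test failure. A minimal defensive sketch (not part of this change), assuming org.junit.Assert.assertNotNull is also statically imported:

    File directory = new File(RESOURCE_PATH);
    File[] candidates = directory.listFiles();
    // listFiles() is null if the directory is missing; fail with a readable message.
    assertNotNull("Resource directory not found on device: " + RESOURCE_PATH, candidates);
    Arrays.stream(candidates)
        .filter(file -> file.getName().endsWith(".pte"))
        .forEach(model -> { /* same per-model body as in the diff above */ });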
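
Design note on the report(...) helpers: Instrumentation.sendStatus(0, bundle) streams the bundle's key/value pairs to the test runner while the test is still executing (they appear as INSTRUMENTATION_STATUS lines when the runner is invoked with raw output enabled), so each metric currently arrives as its own status message. If one message per model is preferred, a batched variant could look like the following; reportAll is hypothetical and not part of this change.

  // Hypothetical helper: emit one status Bundle per model instead of one per metric.
  private void reportAll(final String modelName, final float tps) {
    Bundle bundle = new Bundle();
    bundle.putString("ModelName", modelName);
    bundle.putFloat("TPS", tps);
    InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle);
  }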