
Commit 1e9a596

cccclai authored and facebook-github-bot committed
Support multiple prompts in the runner
Summary: As preparation for multiturn conversation, the runner can now accept multiple prompts and execute them in sequence. Example command:

```
./qnn_llama3_2_runner --model_path hybrid_llama_qnn.pte --tokenizer_path tiktokenizer.bin --eval_mode 1 --prompt "Once upon a time" --prompt "girl named Lily." --prompt "her toys and her favorite toy was a big," --kv_updater "ShiftPointer" --logits_scale 0.1 --output_path output.txt --num_iters 1
```

Since it would be hard to pick a delimiter character that can never appear in prompt text, each prompt is explicitly marked with its own `--prompt` flag, and the runner collects them together.

Differential Revision: D72276104
1 parent 150cbe1 commit 1e9a596

File tree

1 file changed: 19 additions, 3 deletions

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 19 additions & 3 deletions
```diff
@@ -35,6 +35,7 @@ DEFINE_string(
     "Records inference speed. For CI purpose.");
 DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
 DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
+
 DEFINE_string(
     system_prompt,
     "",
@@ -59,9 +60,22 @@ DEFINE_string(
     "SmartMask");
 DEFINE_int32(num_iters, 1, "total num of iterations to run.");
 
+std::vector<std::string> CollectPrompts(int argc, char** argv) {
+  // Collect all prompts from command line, example usage:
+  // --prompt "prompt1" --prompt "prompt2" --prompt "prompt3"
+  std::vector<std::string> prompts;
+  for (int i = 1; i < argc; i++) {
+    if (std::string(argv[i]) == "--prompt" && i + 1 < argc) {
+      prompts.push_back(argv[i + 1]);
+      i++; // Skip the next argument
+    }
+  }
+  return prompts;
+}
+
 int main(int argc, char** argv) {
+  std::vector<std::string> prompts = CollectPrompts(argc, argv);
   gflags::ParseCommandLineFlags(&argc, &argv, true);
-
   // create llama runner
   example::Runner runner(
       {FLAGS_model_path},
@@ -83,11 +97,13 @@ int main(int argc, char** argv) {
   };
   // generate tokens & store inference output
   for (int i = 0; i < FLAGS_num_iters; i++) {
-    runner.generate(
+    for (const auto& prompt : prompts) {
+      runner.generate(
         FLAGS_seq_len,
-        FLAGS_prompt.c_str(),
+        prompt.c_str(),
         FLAGS_system_prompt.c_str(),
         callback);
+    }
   }
   fout.write(buf.data(), buf.size());
   fout.close();
```

0 commit comments
