Skip to content

Commit ddf0d9e

Browse files
authored
support qnn runner multi iter run
Differential Revision: D70842764 Pull Request resolved: #9071
1 parent 366ad75 commit ddf0d9e

File tree

5 files changed

+42
-13
lines changed

5 files changed

+42
-13
lines changed

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ DEFINE_string(
2525
model_path,
2626
"kv_llama_qnn.pte",
2727
"Model serialized in flatbuffer format.");
28-
2928
DEFINE_string(
3029
output_path,
3130
"outputs.txt",
@@ -48,7 +47,6 @@ DEFINE_int32(
4847
seq_len,
4948
128,
5049
"Total number of tokens to generate (prompt + output).");
51-
5250
DEFINE_int32(
5351
eval_mode,
5452
1,
@@ -59,6 +57,7 @@ DEFINE_string(
5957
kv_updater,
6058
"How to update kv cache. Choose between SmartMask and ShiftPointer",
6159
"SmartMask");
60+
DEFINE_int32(num_iters, 1, "total num of iterations to run.");
6261

6362
int main(int argc, char** argv) {
6463
gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -72,7 +71,8 @@ int main(int argc, char** argv) {
7271
FLAGS_logits_offset,
7372
FLAGS_temperature,
7473
FLAGS_eval_mode,
75-
FLAGS_kv_updater);
74+
FLAGS_kv_updater,
75+
FLAGS_num_iters);
7676
std::vector<char> buf;
7777
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
7878
std::ofstream fout(FLAGS_output_path.c_str());
@@ -82,11 +82,13 @@ int main(int argc, char** argv) {
8282
}
8383
};
8484
// generate tokens & store inference output
85-
runner.generate(
86-
FLAGS_seq_len,
87-
FLAGS_prompt.c_str(),
88-
FLAGS_system_prompt.c_str(),
89-
callback);
85+
for (int i = 0; i < FLAGS_num_iters; i++) {
86+
runner.generate(
87+
FLAGS_seq_len,
88+
FLAGS_prompt.c_str(),
89+
FLAGS_system_prompt.c_str(),
90+
callback);
91+
}
9092
fout.write(buf.data(), buf.size());
9193
fout.close();
9294
return 0;

examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,14 @@ void ShiftPointerIoMgr::init_io() {
168168
}
169169
}
170170

171+
void ShiftPointerIoMgr::reset_io() {
172+
IO* ptr = static_cast<IO*>(data_ptr_.get());
173+
std::fill(
174+
ptr->prefill_attention_mask.begin(),
175+
ptr->prefill_attention_mask.end(),
176+
0);
177+
std::fill(ptr->kv_attention_mask.begin(), ptr->kv_attention_mask.end(), 0);
178+
}
171179
void ShiftPointerIoMgr::prepare_kv_io(
172180
const std::vector<Result<MethodMeta>>& methods_meta) {
173181
for (int i = 0; i < modules_.size(); ++i) {
@@ -885,6 +893,17 @@ void SmartMaskIoMgr::init_io() {
885893
ptr->init_io_ptrs(shared_ptr, io_bytes_map);
886894
}
887895

896+
void SmartMaskIoMgr::reset_io() {
897+
IO* ptr = static_cast<IO*>(data_ptr_.get());
898+
int32_t prefill_attn_size = prefill_ar_len_ * context_len_;
899+
int32_t kv_attn_size = kv_ar_len_ * context_len_;
900+
std::fill(
901+
ptr->prefill_attention_mask,
902+
ptr->prefill_attention_mask + prefill_attn_size,
903+
0);
904+
std::fill(ptr->kv_attention_mask, ptr->kv_attention_mask + kv_attn_size, 0);
905+
}
906+
888907
void SmartMaskIoMgr::prepare_kv_io(
889908
const std::vector<Result<MethodMeta>>& methods_meta) {
890909
for (int i = 0; i < modules_.size(); ++i) {

examples/qualcomm/oss_scripts/llama/runner/io_manager.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class IoMgrBase {
3333
std::vector<std::shared_ptr<executorch::extension::Module>>& modules);
3434
virtual ~IoMgrBase();
3535
virtual void init_io() = 0;
36+
virtual void reset_io() = 0;
3637
virtual void prepare_prefill_io(
3738
const std::vector<
3839
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
@@ -97,6 +98,7 @@ class ShiftPointerIoMgr : public IoMgrBase {
9798
const bool use_int64_token);
9899

99100
void init_io() override;
101+
void reset_io() override;
100102
void prepare_prefill_io(
101103
const std::vector<
102104
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
@@ -199,6 +201,7 @@ class SmartMaskIoMgr : public IoMgrBase {
199201
const bool use_int64_token);
200202

201203
void init_io() override;
204+
void reset_io() override;
202205
void prepare_prefill_io(
203206
const std::vector<
204207
executorch::runtime::Result<executorch::runtime::MethodMeta>>&

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ Runner::Runner(
4848
const int32_t logits_offset,
4949
const float temperature,
5050
const int eval_mode,
51-
const std::string& kv_updater)
51+
const std::string& kv_updater,
52+
const int num_iters)
5253
: n_bos_(1),
5354
n_eos_(1),
5455
tokenizer_path_(tokenizer_path),
@@ -57,7 +58,8 @@ Runner::Runner(
5758
logits_offset_(logits_offset),
5859
temperature_(temperature),
5960
eval_mode_(static_cast<EvalMode>(eval_mode)),
60-
kv_updater_(kv_updater) {
61+
kv_updater_(kv_updater),
62+
num_iters_(num_iters) {
6163
for (size_t i = 0; i < models_path.size(); ++i) {
6264
modules_.push_back(std::make_shared<Module>(
6365
models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors));
@@ -280,7 +282,7 @@ Error Runner::generate(
280282
std::unordered_map<std::string, std::vector<std::vector<Tensor>>>
281283
input_tensors, output_tensors;
282284
std::unordered_map<std::string, std::vector<std::vector<EValue>>> inputs;
283-
if (!is_loaded()) {
285+
if (!is_loaded() || (num_iters_ > 1)) {
284286
stats_.model_load_start_ms = time_in_ms();
285287
ET_CHECK_OK_OR_RETURN_ERROR(load());
286288
for (auto method_name : method_names_) {
@@ -445,7 +447,8 @@ Error Runner::generate(
445447
if (stats_callback) {
446448
stats_callback(stats_);
447449
}
448-
450+
io_mgr_->reset_io();
451+
prompt_.clear();
449452
return Error::Ok;
450453
}
451454

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class Runner {
3434
const int32_t logits_offset,
3535
const float temperature,
3636
const int eval_mode,
37-
const std::string& kv_updater);
37+
const std::string& kv_updater,
38+
const int num_iters);
3839

3940
struct Stats {
4041
// Scaling factor for timestamps - in this case, we use ms.
@@ -117,6 +118,7 @@ class Runner {
117118
std::vector<std::string> method_names_;
118119
LlamaVersion llama_version_;
119120
std::string kv_updater_;
121+
int num_iters_;
120122
};
121123

122124
} // namespace example

0 commit comments

Comments
 (0)