[ExecuTorch][Llama] Change runner to enable chunked prefill #9805

Merged 1 commit on Apr 1, 2025

12 changes: 7 additions & 5 deletions examples/models/llama/runner/runner.cpp
@@ -11,6 +11,7 @@

#include <executorch/examples/models/llama/runner/runner.h>

+ #include <algorithm>
#include <ctime>

#include <executorch/extension/llm/runner/util.h>
@@ -140,7 +141,8 @@ Error Runner::load() {
text_prefiller_ = std::make_unique<llm::TextPrefiller>(
text_decoder_runner_.get(),
metadata_.at(kUseKVCache),
- metadata_.at(kEnableDynamicShape));
+ metadata_.at(kEnableDynamicShape),
+ metadata_.at(kMaxSeqLen));

text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
tokenizer_.get(),
@@ -221,11 +223,11 @@ Error Runner::generate(

ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
ET_CHECK_MSG(
- num_prompt_tokens < metadata_.at(kMaxSeqLen),
+ num_prompt_tokens < metadata_.at(kMaxContextLen),
"num_prompt_tokens %d >= max_seq_len_ %" PRId64
", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
num_prompt_tokens,
- metadata_.at(kMaxSeqLen));
+ metadata_.at(kMaxContextLen));
ET_CHECK_MSG(
num_prompt_tokens < seq_len,
"num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
@@ -242,10 +244,10 @@ Error Runner::generate(
}
int64_t pos = 0;
auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
- stats_.first_token_ms = llm::time_in_ms();
- stats_.prompt_eval_end_ms = llm::time_in_ms();
ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
uint64_t cur_token = prefill_res.get();
+ stats_.first_token_ms = llm::time_in_ms();
+ stats_.prompt_eval_end_ms = llm::time_in_ms();

// print the first token from prefill. No prev_token so use cur_token for it.
wrapped_callback(
3 changes: 2 additions & 1 deletion examples/models/llava/runner/llava_runner.cpp
@@ -55,7 +55,8 @@ Error LlavaRunner::load() {
text_prefiller_ = std::make_unique<llm::TextPrefiller>(
text_decoder_runner_.get(),
/*use_kv_cache=*/true,
- /*enable_parallel_prefill=*/true);
+ /*enable_parallel_prefill=*/true,
+ /*max_seq_len=*/128);

// Load the image prefiller
image_prefiller_ = std::make_unique<LlavaImagePrefiller>(module_.get());
47 changes: 45 additions & 2 deletions extension/llm/runner/text_prefiller.cpp
@@ -10,6 +10,7 @@
// LLM.

#include <executorch/extension/llm/runner/text_prefiller.h>
+ #include <algorithm>

namespace executorch {
namespace extension {
@@ -18,10 +19,13 @@ namespace llm {
TextPrefiller::TextPrefiller(
TextDecoderRunner* text_decoder_runner,
bool use_kv_cache,
- bool enable_parallel_prefill)
+ bool enable_parallel_prefill,
+ int64_t max_seq_len)
: text_decoder_runner_(text_decoder_runner),
use_kv_cache_(use_kv_cache),
- enable_parallel_prefill_(enable_parallel_prefill) {}
+ enable_parallel_prefill_(enable_parallel_prefill),
+ max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127) {
+ } // -1 because for some reason tracing results in this upperbound

::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
std::vector<uint64_t>& prompt_tokens,
@@ -30,6 +34,45 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
if (!text_decoder_runner_->is_method_loaded()) {
ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
}

+ // Check if we need to chunk the prompt tokens
+ int32_t num_prompt_tokens = prompt_tokens.size();
+
+ // If prompt tokens exceed max_seq_len_, we need to chunk them
+ if (num_prompt_tokens > max_seq_len_) {
+   uint64_t cur_token = 0;
+   int num_tokens_to_process = 0;
+
+   while (num_tokens_to_process < num_prompt_tokens) {
+     auto num_tokens_to_prefill_with = std::min<int>(
+         num_prompt_tokens - num_tokens_to_process, max_seq_len_);
+
+     std::vector<uint64_t> prompt_tokens_to_process(
+         num_tokens_to_prefill_with);
+     std::copy(
+         prompt_tokens.begin() + num_tokens_to_process,
+         prompt_tokens.begin() + num_tokens_to_process +
+             num_tokens_to_prefill_with,
+         prompt_tokens_to_process.begin());
+
+     // Process this chunk
+     auto chunk_result = prefillChunk(prompt_tokens_to_process, start_pos);
+     ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
+     cur_token = chunk_result.get();
+
+     num_tokens_to_process += num_tokens_to_prefill_with;
+   }
+
+   return cur_token;
+ } else {
+   // If prompt tokens don't exceed max_seq_len_, process them directly
+   return prefillChunk(prompt_tokens, start_pos);
+ }
+ }
+
+ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefillChunk(
+     std::vector<uint64_t>& prompt_tokens,
+     int64_t& start_pos) {
// enable_parallel_prefill_ maybe set even when not using kv cache
// When kv cache is not used, start pos is ignored
int32_t num_prompt_tokens = prompt_tokens.size();
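
To make the new prefill flow easier to follow: when the prompt is longer than max_seq_len_, prefill() walks the token vector in slices of at most max_seq_len_ tokens, hands each slice to prefillChunk() (which advances start_pos), and keeps only the token returned by the final chunk, which the runner then uses to seed decoding. The standalone sketch below mirrors that loop outside ExecuTorch; the prefill_chunk stub, the chunk limit of 4, and the 10-token prompt are illustrative assumptions, not part of this PR.

// Standalone sketch of the chunked-prefill loop (illustrative only: the
// prefill_chunk stub, the chunk limit, and the prompt contents are
// assumptions, not the ExecuTorch API).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for TextPrefiller::prefillChunk: pretend to prefill one chunk,
// advance the cache position, and return the last token seen.
uint64_t prefill_chunk(const std::vector<uint64_t>& chunk, int64_t& start_pos) {
  start_pos += static_cast<int64_t>(chunk.size());
  return chunk.back();
}

uint64_t prefill(const std::vector<uint64_t>& prompt_tokens,
                 int64_t& start_pos,
                 int64_t max_seq_len) {
  const int32_t num_prompt_tokens = static_cast<int32_t>(prompt_tokens.size());
  uint64_t cur_token = 0;
  int num_tokens_to_process = 0;
  while (num_tokens_to_process < num_prompt_tokens) {
    // Take at most max_seq_len tokens for this chunk.
    const int chunk_len = std::min<int>(
        num_prompt_tokens - num_tokens_to_process,
        static_cast<int>(max_seq_len));
    std::vector<uint64_t> chunk(
        prompt_tokens.begin() + num_tokens_to_process,
        prompt_tokens.begin() + num_tokens_to_process + chunk_len);
    cur_token = prefill_chunk(chunk, start_pos);
    num_tokens_to_process += chunk_len;
  }
  return cur_token;  // only the final chunk's token is kept
}

int main() {
  std::vector<uint64_t> prompt(10);
  for (uint64_t i = 0; i < prompt.size(); ++i) {
    prompt[i] = 100 + i;
  }
  int64_t pos = 0;
  const uint64_t last = prefill(prompt, pos, /*max_seq_len=*/4);
  // 10 tokens with a limit of 4 -> chunks of 4, 4, and 2; pos ends at 10.
  std::cout << "last token: " << last << ", final pos: " << pos << "\n";
  return 0;
}

Running the sketch prints last token: 109, final pos: 10, i.e. the prompt is processed as chunks of 4, 4, and 2 tokens.
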
15 changes: 14 additions & 1 deletion extension/llm/runner/text_prefiller.h
@@ -22,7 +22,8 @@ class ET_EXPERIMENTAL TextPrefiller {
TextPrefiller(
TextDecoderRunner* text_decoder_runner,
bool use_kv_cache_,
- bool enable_parallel_prefill);
+ bool enable_parallel_prefill,
+ int64_t max_seq_len = 128);
/**
* Prefill an LLM Module with the given text input.
* @param prompt_tokens The text prompt tokens to the LLM Module. Encoded by
@@ -35,10 +36,22 @@
std::vector<uint64_t>& prompt_tokens,
int64_t& start_pos);

+ /**
+  * Helper method to prefill a chunk of tokens.
+  * @param prompt_tokens The chunk of text prompt tokens to process.
+  * @param start_pos The starting position in KV cache of the input in the LLM
+  * Module.
+  * @return The next token of the LLM Module after prefilling this chunk.
+  */
+ ::executorch::runtime::Result<uint64_t> prefillChunk(
+     std::vector<uint64_t>& prompt_tokens,
+     int64_t& start_pos);
+
private:
TextDecoderRunner* text_decoder_runner_;
bool use_kv_cache_;
bool enable_parallel_prefill_;
+ int64_t max_seq_len_;
};

} // namespace llm
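
A note on the default: the declaration above defaults max_seq_len to 128, and the constructor in text_prefiller.cpp stores max_seq_len - 1 (falling back to 127 for non-positive values) because of the traced upper bound mentioned in its code comment. Below is a minimal sketch of just that clamping rule; effective_chunk_limit is a hypothetical helper name used only for illustration.

// Minimal sketch of the constructor's clamping rule (assumption: this helper
// does not exist in ExecuTorch; it only mirrors the initializer
// max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127)).
#include <cassert>
#include <cstdint>

int64_t effective_chunk_limit(int64_t max_seq_len) {
  return max_seq_len > 0 ? max_seq_len - 1 : 127;
}

int main() {
  assert(effective_chunk_limit(128) == 127);   // default from the header
  assert(effective_chunk_limit(2048) == 2047); // model exported with a larger limit
  assert(effective_chunk_limit(0) == 127);     // non-positive input falls back to 127
  return 0;
}

So with the default of 128, any prompt longer than 127 tokens is split into multiple prefill chunks.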