[Serving] Fix support of large n under low max batch size #2136

Merged 1 commit on Apr 14, 2024
70 changes: 54 additions & 16 deletions cpp/serve/engine_actions/eagle_new_request_prefill.cc
@@ -217,17 +217,21 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
std::vector<int> sample_indices;
std::vector<RequestStateEntry> rsentries_for_sample;
std::vector<RandomGenerator*> rngs;
std::vector<bool> rsentry_activated;
sample_indices.reserve(num_rsentries);
rsentries_for_sample.reserve(num_rsentries);
rngs.reserve(num_rsentries);
rsentry_activated.reserve(num_rsentries);
request_ids.clear();
generation_cfg.clear();
for (int i = 0; i < num_rsentries; ++i) {
const RequestStateEntry& rsentry = prefill_inputs[i].rsentry;
int remaining_num_child_to_activate = prefill_inputs[i].num_child_to_activate;
for (int child_idx : rsentry->child_indices) {
// Only use base model to judge if we need to add child entries.
if (rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty() ||
fork_rsentry_child_map[i].count(child_idx)) {
if (rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending &&
(rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty() ||
fork_rsentry_child_map[i].count(child_idx))) {
// If rstates_of_entries[i]->entries[child_idx] has no committed token,
// the prefill of the current rsentry will unblock
// rstates_of_entries[i]->entries[child_idx],
@@ -239,6 +243,16 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
generation_cfg.push_back(rsentry->request->generation_cfg);
rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng);

// We only fork the first `num_child_to_activate` children.
// The children not being forked will be forked via later prefills.
// Usually `num_child_to_activate` is the same as the number of children.
// But it can be fewer subject to the KV cache max num sequence limit.
if (remaining_num_child_to_activate == 0) {
rsentry_activated.push_back(false);
continue;
}
rsentry_activated.push_back(true);
--remaining_num_child_to_activate;
if (model_id == 0) {
ICHECK(rstates_of_entries[i]->entries[child_idx]->status ==
RequestStateStatus::kPending);
@@ -261,6 +275,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
request_ids.push_back(rsentry->request->id);
generation_cfg.push_back(rsentry->request->generation_cfg);
rngs.push_back(&rsentry->rng);
rsentry_activated.push_back(true);
}
}
std::vector<NDArray> prob_dist;
@@ -275,6 +290,12 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
if (model_id == 0) {
for (int mid = 0; mid < static_cast<int>(models_.size()); ++mid) {
rsentries_for_sample[i]->mstates[mid]->CommitToken(sample_results[i]);
if (!rsentry_activated[i]) {
// When the child rsentry is not activated,
// add the sampled token as an input of the mstate for prefill.
rsentries_for_sample[i]->mstates[mid]->inputs.push_back(
TokenData(std::vector<int64_t>{sample_results[i].sampled_token_id.first}));
}
}
// Only base model trigger timing records.
if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) {
@@ -332,7 +353,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
/*! \brief The class of request state entry and its maximum allowed length for prefill. */
struct PrefillInput {
RequestStateEntry rsentry;
int max_prefill_length;
int max_prefill_length = 0;
int num_child_to_activate = 0;
};

/*!
Expand Down Expand Up @@ -376,30 +398,46 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
total_input_length += input_length;
total_required_pages += num_require_pages;
// - Attempt 1. Check if the entire request state entry can fit for prefill.
if (CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(),
total_input_length, total_required_pages, num_available_pages,
current_total_seq_len, num_running_rsentries)) {
prefill_inputs.push_back({rsentry, input_length});
num_prefill_rsentries += 1 + rsentry->child_indices.size();
bool can_prefill = false;
for (int num_child_to_activate = rsentry->child_indices.size(); num_child_to_activate >= 0;
--num_child_to_activate) {
if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate,
total_input_length, total_required_pages, num_available_pages,
current_total_seq_len, num_running_rsentries)) {
prefill_inputs.push_back({rsentry, input_length, num_child_to_activate});
num_prefill_rsentries += 1 + num_child_to_activate;
can_prefill = true;
break;
}
}
if (can_prefill) {
continue;
}
total_input_length -= input_length;
total_required_pages -= num_require_pages;

// - Attempt 2. Check if the request state entry can partially fit by input chunking.
ICHECK_LE(total_input_length, kv_cache_config_->prefill_chunk_size);
input_length =
std::min(input_length, kv_cache_config_->prefill_chunk_size - total_input_length);
if (kv_cache_config_->prefill_chunk_size - total_input_length >= input_length ||
kv_cache_config_->prefill_chunk_size == total_input_length) {
// 1. If the input length can fit the remaining prefill chunk size,
// it means the failure of attempt 1 is not because of the input
// length being too long, and thus chunking does not help.
// 2. If the total input length already reaches the prefill chunk size,
// the current request state entry will not be able to be processed.
// So we can safely return in either case.
prefill_stops = true;
break;
}
input_length = kv_cache_config_->prefill_chunk_size - total_input_length;
num_require_pages =
(input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size;
total_input_length += input_length;
total_required_pages += num_require_pages;
if (input_length > 0 &&
CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(),
total_input_length, total_required_pages, num_available_pages,
current_total_seq_len, num_running_rsentries)) {
prefill_inputs.push_back({rsentry, input_length});
num_prefill_rsentries += 1 + rsentry->child_indices.size();
if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages,
num_available_pages, current_total_seq_len, num_running_rsentries)) {
prefill_inputs.push_back({rsentry, input_length, 0});
num_prefill_rsentries += 1;
}

// - Prefill stops here.
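The sampling hunks above introduce `rsentry_activated` and the deferred activation of child request state entries. The sketch below condenses that bookkeeping with hypothetical `ChildState`/`PrefillEntry` stand-ins (not the engine's real `RequestStateEntry`/`RequestModelState` classes): every unblocked child commits the sampled token, and a child beyond the activation budget additionally keeps that token as a pending input so a later prefill round can fork and serve it.

```cpp
// Minimal sketch of the deferred child activation, assuming simplified types.
// ChildState and PrefillEntry are illustrative stand-ins; only the pattern
// (commit the sampled token, and also stash it as a pending input when the
// child is not activated in this round) mirrors the diff above.
#include <cstddef>
#include <cstdint>
#include <vector>

struct ChildState {
  bool activated = false;                 // analogous to RequestStateStatus::kAlive
  std::vector<int64_t> committed_tokens;  // tokens committed to the model state
  std::vector<int64_t> pending_inputs;    // inputs waiting for a later prefill
};

struct PrefillEntry {
  std::vector<ChildState*> children;
  int num_child_to_activate = 0;  // may be fewer than children.size()
};

// One sampled token per child: activate the first `num_child_to_activate`
// children now and defer the rest by turning their token into a pending input.
inline void CommitSampledTokens(PrefillEntry& entry,
                                const std::vector<int64_t>& sampled_tokens) {
  int remaining = entry.num_child_to_activate;
  for (std::size_t i = 0; i < entry.children.size(); ++i) {
    ChildState* child = entry.children[i];
    child->committed_tokens.push_back(sampled_tokens[i]);  // like CommitToken(...)
    if (remaining > 0) {
      --remaining;
      child->activated = true;  // forked and served in this prefill round
    } else {
      // Mirrors mstate->inputs.push_back(TokenData(...)) in the diff: a later
      // prefill forks this child and prefills exactly the token sampled here.
      child->pending_inputs.push_back(sampled_tokens[i]);
    }
  }
}
```

This deferred path is what lets a request whose `n` exceeds the allowed number of sequences keep making progress: children that cannot be forked now are prefilled later from the very token that was sampled for them.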
111 changes: 75 additions & 36 deletions cpp/serve/engine_actions/new_request_prefill.cc
@@ -167,9 +167,11 @@ class NewRequestPrefillActionObj : public EngineActionObj {
std::vector<int> sample_indices;
std::vector<RequestStateEntry> rsentries_for_sample;
std::vector<RandomGenerator*> rngs;
std::vector<bool> rsentry_activated;
sample_indices.reserve(num_rsentries);
rsentries_for_sample.reserve(num_rsentries);
rngs.reserve(num_rsentries);
rsentry_activated.reserve(num_rsentries);
request_ids.clear();
generation_cfg.clear();
for (int i = 0; i < num_rsentries; ++i) {
@@ -179,29 +181,42 @@ class NewRequestPrefillActionObj : public EngineActionObj {
continue;
}

int remaining_num_child_to_activate = prefill_inputs[i].num_child_to_activate;
for (int child_idx : rsentry->child_indices) {
if (rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty()) {
// If rstates_of_entries[i]->entries[child_idx] has no committed token,
// the prefill of the current rsentry will unblock
// rstates_of_entries[i]->entries[child_idx],
// and thus we want to sample a token for rstates_of_entries[i]->entries[child_idx].
sample_indices.push_back(i);
rsentries_for_sample.push_back(rstates_of_entries[i]->entries[child_idx]);
request_ids.push_back(rsentry->request->id);
generation_cfg.push_back(rsentry->request->generation_cfg);
rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng);

ICHECK(rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending);
rstates_of_entries[i]->entries[child_idx]->status = RequestStateStatus::kAlive;
for (int model_id = 0; model_id < static_cast<int>(models_.size()); ++model_id) {
int64_t child_internal_id =
rstates_of_entries[i]->entries[child_idx]->mstates[model_id]->internal_id;
models_[model_id]->ForkSequence(rsentry->mstates[model_id]->internal_id,
child_internal_id);
// Enable sliding window for the child sequence if the child is not a parent.
if (rstates_of_entries[i]->entries[child_idx]->child_indices.empty()) {
models_[model_id]->EnableSlidingWindowForSeq(child_internal_id);
}
// If rstates_of_entries[i]->entries[child_idx] has no committed token,
// the prefill of the current rsentry will unblock
// rstates_of_entries[i]->entries[child_idx],
// and thus we want to sample a token for rstates_of_entries[i]->entries[child_idx].
if (rstates_of_entries[i]->entries[child_idx]->status != RequestStateStatus::kPending ||
!rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty()) {
continue;
}
sample_indices.push_back(i);
rsentries_for_sample.push_back(rstates_of_entries[i]->entries[child_idx]);
request_ids.push_back(rsentry->request->id);
generation_cfg.push_back(rsentry->request->generation_cfg);
rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng);

ICHECK(rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending);
// We only fork the first `num_child_to_activate` children.
// The children not being forked will be forked via later prefills.
// Usually `num_child_to_activate` is the same as the number of children.
// But it can be fewer subject to the KV cache max num sequence limit.
if (remaining_num_child_to_activate == 0) {
rsentry_activated.push_back(false);
continue;
}
rsentry_activated.push_back(true);
--remaining_num_child_to_activate;
rstates_of_entries[i]->entries[child_idx]->status = RequestStateStatus::kAlive;
for (int model_id = 0; model_id < static_cast<int>(models_.size()); ++model_id) {
int64_t child_internal_id =
rstates_of_entries[i]->entries[child_idx]->mstates[model_id]->internal_id;
models_[model_id]->ForkSequence(rsentry->mstates[model_id]->internal_id,
child_internal_id);
// Enable sliding window for the child sequence if the child is not a parent.
if (rstates_of_entries[i]->entries[child_idx]->child_indices.empty()) {
models_[model_id]->EnableSlidingWindowForSeq(child_internal_id);
}
}
}
@@ -212,6 +227,7 @@ class NewRequestPrefillActionObj : public EngineActionObj {
request_ids.push_back(rsentry->request->id);
generation_cfg.push_back(rsentry->request->generation_cfg);
rngs.push_back(&rsentry->rng);
rsentry_activated.push_back(true);
}
}
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokens(
@@ -224,6 +240,12 @@ class NewRequestPrefillActionObj : public EngineActionObj {
for (int i = 0; i < static_cast<int>(rsentries_for_sample.size()); ++i) {
for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) {
mstate->CommitToken(sample_results[i]);
if (!rsentry_activated[i]) {
// When the child rsentry is not activated,
// add the sampled token as an input of the mstate for prefill.
mstate->inputs.push_back(
TokenData(std::vector<int64_t>{sample_results[i].sampled_token_id.first}));
}
}
if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) {
rsentries_for_sample[i]->tprefill_finish = tnow;
@@ -270,7 +292,8 @@ class NewRequestPrefillActionObj : public EngineActionObj {
/*! \brief The class of request state entry and its maximum allowed length for prefill. */
struct PrefillInput {
RequestStateEntry rsentry;
int max_prefill_length;
int max_prefill_length = 0;
int num_child_to_activate = 0;
};

/*!
Expand Down Expand Up @@ -314,30 +337,46 @@ class NewRequestPrefillActionObj : public EngineActionObj {
total_input_length += input_length;
total_required_pages += num_require_pages;
// - Attempt 1. Check if the entire request state entry can fit for prefill.
if (CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(),
total_input_length, total_required_pages, num_available_pages,
current_total_seq_len, num_running_rsentries)) {
prefill_inputs.push_back({rsentry, input_length});
num_prefill_rsentries += 1 + rsentry->child_indices.size();
bool can_prefill = false;
for (int num_child_to_activate = rsentry->child_indices.size(); num_child_to_activate >= 0;
--num_child_to_activate) {
if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate,
total_input_length, total_required_pages, num_available_pages,
current_total_seq_len, num_running_rsentries)) {
prefill_inputs.push_back({rsentry, input_length, num_child_to_activate});
num_prefill_rsentries += 1 + num_child_to_activate;
can_prefill = true;
break;
}
}
if (can_prefill) {
continue;
}
total_input_length -= input_length;
total_required_pages -= num_require_pages;

// - Attempt 2. Check if the request state entry can partially fit by input chunking.
ICHECK_LE(total_input_length, kv_cache_config_->prefill_chunk_size);
input_length =
std::min(input_length, kv_cache_config_->prefill_chunk_size - total_input_length);
if (kv_cache_config_->prefill_chunk_size - total_input_length >= input_length ||
kv_cache_config_->prefill_chunk_size == total_input_length) {
// 1. If the input length can fit the remaining prefill chunk size,
// it means the failure of attempt 1 is not because of the input
// length being too long, and thus chunking does not help.
// 2. If the total input length already reaches the prefill chunk size,
// the current request state entry will not be able to be processed.
// So we can safely return in either case.
prefill_stops = true;
break;
}
input_length = kv_cache_config_->prefill_chunk_size - total_input_length;
num_require_pages =
(input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size;
total_input_length += input_length;
total_required_pages += num_require_pages;
if (input_length > 0 &&
CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(),
total_input_length, total_required_pages, num_available_pages,
current_total_seq_len, num_running_rsentries)) {
prefill_inputs.push_back({rsentry, input_length});
num_prefill_rsentries += 1 + rsentry->child_indices.size();
if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages,
num_available_pages, current_total_seq_len, num_running_rsentries)) {
prefill_inputs.push_back({rsentry, input_length, 0});
num_prefill_rsentries += 1;
}

// - Prefill stops here.
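The same admission change appears in `GetRequestStateEntriesToPrefill` of both files: attempt 1 now counts the activatable children down from all of them to zero until `CanPrefill` accepts, and attempt 2 falls back to input chunking while activating no children. The toy program below mirrors only that control flow; `Budget`, `TryAdmit`, and the numbers in `main` are illustrative assumptions, not the engine's real API or configuration.

```cpp
// Toy model of the revised admission logic. All names are illustrative.
#include <iostream>

struct Budget {
  int max_num_sequence;    // stands in for the KV cache max-num-sequence limit
  int prefill_chunk_size;  // stands in for kv_cache_config_->prefill_chunk_size
  bool CanPrefill(int num_rsentries, int total_input_length) const {
    return num_rsentries <= max_num_sequence && total_input_length <= prefill_chunk_size;
  }
};

struct Decision {
  bool admitted = false;
  int input_length = 0;           // possibly chunked
  int num_child_to_activate = 0;  // children forked in this prefill round
};

Decision TryAdmit(const Budget& b, int num_admitted_rsentries, int used_chunk,
                  int input_length, int num_children) {
  // Attempt 1: whole input, activating as many children as the budget allows.
  for (int k = num_children; k >= 0; --k) {
    if (b.CanPrefill(num_admitted_rsentries + 1 + k, used_chunk + input_length)) {
      return {true, input_length, k};
    }
  }
  // Attempt 2: chunk the input and activate no children in this round.
  int remaining_chunk = b.prefill_chunk_size - used_chunk;
  if (remaining_chunk >= input_length || remaining_chunk == 0) {
    // Chunking cannot help: either the failure above was not caused by the
    // input length, or this step's chunk budget is already exhausted.
    return {};
  }
  if (b.CanPrefill(num_admitted_rsentries + 1, used_chunk + remaining_chunk)) {
    return {true, remaining_chunk, 0};
  }
  return {};
}

int main() {
  // A request with n = 8 (eight child entries) under max batch size 4:
  // the parent plus three children are admitted now; the other five children
  // are activated by later prefill rounds from their stashed sampled tokens.
  Budget b{/*max_num_sequence=*/4, /*prefill_chunk_size=*/2048};
  Decision d = TryAdmit(b, /*num_admitted_rsentries=*/0, /*used_chunk=*/0,
                        /*input_length=*/512, /*num_children=*/8);
  std::cout << "admitted=" << d.admitted << " input_length=" << d.input_length
            << " children_activated=" << d.num_child_to_activate << "\n";
}
```

Before this change, both attempts counted every child against the budget (`num_prefill_rsentries + 1 + rsentry->child_indices.size()` in the removed lines), so a request whose `n` exceeded the maximum number of sequences could never be admitted; counting the activation budget down to zero is the core of the fix.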