Skip to content

Commit

Permalink
[Eagle] Avoid worker - engine transfer for hidden states
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Apr 30, 2024
1 parent ca7cdcc commit e871bcf
Show file tree
Hide file tree
Showing 12 changed files with 188 additions and 450 deletions.
2 changes: 1 addition & 1 deletion cpp/serve/draft_token_workspace_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace,
NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_);
if (require_hidden_states) {
workspace->draft_hidden_states_storage =
NDArray::Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_);
ft_.Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_);
}
}

Expand Down
22 changes: 8 additions & 14 deletions cpp/serve/engine_actions/eagle_batch_draft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,15 @@ class EagleBatchDraftActionObj : public EngineActionObj {
mstates.push_back(rsentry->mstates[model_id]);
}
// draft_length_ rounds of draft proposal.
ObjectRef last_hidden_states{nullptr};
NDArray hidden_states = Downcast<NDArray>(model_workspaces_[model_id].hidden_states);
ObjectRef hidden_states = model_workspaces_[model_id].hidden_states;
// Concat last hidden_states
draft_token_slots_.clear();
if (draft_length_ > 1) {
for (int i = 0; i < num_rsentries; ++i) {
draft_token_slots_.push_back(mstates[i]->draft_token_slots.back());
}
hidden_states = Downcast<NDArray>(models_[model_id]->GatherHiddenStates(
model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states));
ICHECK(hidden_states->ndim == 2);
last_hidden_states = hidden_states.CreateView(
{hidden_states->shape[0], 1, hidden_states->shape[1]}, hidden_states->dtype);
hidden_states = models_[model_id]->GatherHiddenStates(
model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states);
}
// The first draft token has been generated in prefill/verify stage
for (int draft_id = 1; draft_id < draft_length_; ++draft_id) {
Expand All @@ -114,11 +110,10 @@ class EagleBatchDraftActionObj : public EngineActionObj {

// - Invoke model decode.
RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode");
ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden(
embeddings, last_hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1);
hidden_states =
models_[model_id]->BatchDecodeToLastHidden(fused_hidden_states, request_internal_ids);
last_hidden_states = hidden_states;
ObjectRef fused_embedding_hidden_states = models_[model_id]->FuseEmbedHidden(
embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1);
hidden_states = models_[model_id]->BatchDecodeToLastHidden(fused_embedding_hidden_states,
request_internal_ids);
NDArray logits;
if (models_[model_id]->CanGetLogits()) {
logits = models_[model_id]->GetLogits(hidden_states, /*batch_size*/ num_rsentries,
Expand All @@ -145,11 +140,10 @@ class EagleBatchDraftActionObj : public EngineActionObj {
// Fill range [0, num_rsentries) into `sample_indices`.
std::vector<int> sample_indices(num_rsentries);
std::iota(sample_indices.begin(), sample_indices.end(), 0);
std::vector<NDArray> prob_dist;
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
probs_on_device, sample_indices, request_ids, generation_cfg);
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist);
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
ICHECK_EQ(sample_results.size(), num_rsentries);

// - Add draft token to the state.
Expand Down
55 changes: 10 additions & 45 deletions cpp/serve/engine_actions/eagle_batch_verify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
Array<GenerationConfig> generation_cfg;
std::vector<RandomGenerator*> rngs;
std::vector<std::vector<SampleResult>> draft_output_tokens;
std::vector<std::vector<NDArray>> draft_output_prob_dist;
request_internal_ids.reserve(num_rsentries);
all_tokens_to_verify.reserve(total_draft_length);
verify_request_mstates.reserve(num_rsentries);
Expand Down Expand Up @@ -113,12 +112,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding");

RECORD_EVENT(trace_recorder_, request_ids, "start verify");
ObjectRef fused_hidden_states = models_[verify_model_id_]->FuseEmbedHidden(
embeddings, NDArray(), 1, cum_verify_lengths[num_rsentries]);
NDArray hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden(
fused_hidden_states, request_internal_ids, verify_lengths);
ICHECK_EQ(hidden_states->ndim, 3);
ICHECK_EQ(hidden_states->shape[0], 1);
ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden(
embeddings, request_internal_ids, verify_lengths);
NDArray logits =
models_[verify_model_id_]->GetLogits(hidden_states, 1, cum_verify_lengths[num_rsentries]);
RECORD_EVENT(trace_recorder_, request_ids, "finish verify");
Expand Down Expand Up @@ -179,16 +174,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj {

{
// One step draft for the following steps
NDArray last_hidden_states_nd = hidden_states.CreateView(
{hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]},
hidden_states->dtype);

hidden_states = Downcast<NDArray>(models_[draft_model_id_]->GatherHiddenStates(
last_hidden_states_nd, last_accepted_hidden_positions,
&model_workspaces_[draft_model_id_].hidden_states));
ICHECK(hidden_states->ndim == 2);
hidden_states = hidden_states.CreateView(
{hidden_states->shape[0], 1, hidden_states->shape[1]}, hidden_states->dtype);
// Gather hidden states for the last accepted tokens.
hidden_states = models_[draft_model_id_]->GatherHiddenStates(
hidden_states, last_accepted_hidden_positions,
&model_workspaces_[draft_model_id_].hidden_states);

std::vector<int> input_tokens;
Array<RequestModelState> mstates;
Expand All @@ -210,10 +200,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj {

// - Invoke model decode.
RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode");
ObjectRef fused_hidden_states = models_[draft_model_id_]->FuseEmbedHidden(
ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden(
embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1);
hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden(fused_hidden_states,
request_internal_ids);
hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden(
fused_embedding_hidden_states, request_internal_ids);

if (models_[draft_model_id_]->CanGetLogits()) {
logits = models_[draft_model_id_]->GetLogits(hidden_states, /*batch_size*/ num_rsentries,
Expand All @@ -239,22 +229,17 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
// Fill range [0, num_rsentries) into `sample_indices`.
std::vector<int> sample_indices(num_rsentries);
std::iota(sample_indices.begin(), sample_indices.end(), 0);
std::vector<NDArray> prob_dist;
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
probs_on_device, sample_indices, request_ids, generation_cfg);
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist);
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
ICHECK_EQ(sample_results.size(), num_rsentries);

// - Slice and save hidden_states_for_sample
draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_);
models_[draft_model_id_]->ScatterDraftProbs(
renormalized_probs, draft_token_slots_,
&model_workspaces_[verify_model_id_].draft_probs_storage);
ICHECK(hidden_states->ndim == 3);
hidden_states = hidden_states.CreateView(
{hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]},
hidden_states->dtype);
models_[draft_model_id_]->ScatterHiddenStates(
hidden_states, draft_token_slots_,
&model_workspaces_[verify_model_id_].draft_hidden_states_storage);
Expand Down Expand Up @@ -326,26 +311,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
return num_required_pages <= num_available_pages;
}

/*!
 * \brief Copy the hidden-state vector of a single token out of a 3-D hidden_states array.
 *
 * The array is indexed as if its two leading dimensions were flattened into one
 * token axis: the row starting at element offset `token_pos * shape[2]` is copied.
 *
 * \param hidden_states The hidden_states of all the tokens. Must be 3-dimensional.
 * \param token_pos The desired token position along the flattened token axis.
 * \return A new 1-D array of length `shape[2]`, created with the same dtype and
 * device as `hidden_states`, holding the selected token's hidden_states.
 */
NDArray GetTokenHidden(NDArray hidden_states, int token_pos) {
  ICHECK_EQ(hidden_states->ndim, 3);
  const int64_t hidden_size = hidden_states->shape[2];

  // Derive the per-element byte width from the array's own dtype instead of
  // hard-coding a 2-byte element, so the row offset and copy length stay
  // consistent with the destination array allocated below for any dtype.
  const int64_t elem_bytes =
      (static_cast<int64_t>(hidden_states->dtype.bits) * hidden_states->dtype.lanes + 7) / 8;
  const int64_t row_bytes = hidden_size * elem_bytes;

  NDArray token_hidden =
      NDArray::Empty({hidden_size}, hidden_states->dtype, hidden_states->device);

  // NOTE(review): the source pointer is read directly from hidden_states->data;
  // this presumably requires that memory to be host-accessible -- confirm for
  // non-CPU devices.
  const char* row_begin = static_cast<const char*>(hidden_states->data) +
                          static_cast<int64_t>(token_pos) * row_bytes;
  token_hidden.CopyFromBytes(row_begin, row_bytes);
  return token_hidden;
}

/*!
* \brief The model to run decode in. When there are multiple
* models, the `Step` function of the created action will not take effect.
Expand Down
70 changes: 29 additions & 41 deletions cpp/serve/engine_actions/eagle_new_request_prefill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
// - Get embedding and run prefill for each model.
std::vector<int> prefill_lengths;
prefill_lengths.resize(/*size=*/num_rsentries, /*value=*/-1);
NDArray hidden_states_for_input{nullptr};
NDArray hidden_states_for_sample{nullptr};
ObjectRef hidden_states_for_input{nullptr};
ObjectRef hidden_states_for_sample{nullptr};
NDArray logits_for_sample{nullptr};
// A map used to record the entry and child_idx pair needed to fork sequence.
// The base model (id 0) should record all the pairs and all the small models
Expand Down Expand Up @@ -167,14 +167,17 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
}

RECORD_EVENT(trace_recorder_, request_ids, "start prefill");
ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden(
embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length);
NDArray hidden_states = models_[model_id]->BatchPrefillToLastHidden(
fused_hidden_states, request_internal_ids, prefill_lengths);
ObjectRef embedding_or_hidden_states{nullptr};
if (model_id == 0) {
embedding_or_hidden_states = embeddings;
} else {
embedding_or_hidden_states = models_[model_id]->FuseEmbedHidden(
embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length);
}
// hidden_states: (b * s, h)
ObjectRef hidden_states = models_[model_id]->BatchPrefillToLastHidden(
embedding_or_hidden_states, request_internal_ids, prefill_lengths);
RECORD_EVENT(trace_recorder_, request_ids, "finish prefill");
ICHECK_EQ(hidden_states->ndim, 3);
ICHECK_EQ(hidden_states->shape[0], 1);
ICHECK_EQ(hidden_states->shape[1], cum_prefill_length);

if (model_id == 0) {
// We only need to sample for model 0 in prefill.
Expand All @@ -183,14 +186,23 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {

// Whether to use base model to get logits.
int sample_model_id = !models_[model_id]->CanGetLogits() ? 0 : model_id;
hidden_states_for_sample = models_[sample_model_id]->BatchSelectLastHidden(
hidden_states, request_internal_ids, prefill_lengths);

std::vector<int> logit_positions;
{
// Prepare the logit positions
logit_positions.reserve(prefill_lengths.size());
int total_len = 0;
for (int i = 0; i < prefill_lengths.size(); ++i) {
total_len += prefill_lengths[i];
logit_positions.push_back(total_len - 1);
}
}
// hidden_states_for_sample: (b * s, h)
hidden_states_for_sample = models_[sample_model_id]->GatherHiddenStates(
hidden_states, logit_positions, &model_workspaces_[model_id].hidden_states);
// logits_for_sample: (b * s, v)
logits_for_sample =
models_[sample_model_id]->GetLogits(hidden_states_for_sample, 1, num_rsentries);
ICHECK_EQ(hidden_states_for_sample->ndim, 3);
ICHECK_EQ(hidden_states_for_sample->shape[0], 1);
ICHECK_EQ(hidden_states_for_sample->shape[1], num_rsentries);

// - Update logits.
ICHECK(logits_for_sample.defined());
Array<GenerationConfig> generation_cfg;
Expand Down Expand Up @@ -278,11 +290,11 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
rsentry_activated.push_back(true);
}
}
std::vector<NDArray> prob_dist;

NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
probs_on_device, sample_indices, request_ids, generation_cfg);
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist);
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
ICHECK_EQ(sample_results.size(), rsentries_for_sample.size());

// - Update the committed tokens of states.
Expand Down Expand Up @@ -311,10 +323,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
models_[model_id]->ScatterDraftProbs(renormalized_probs, draft_token_slots_,
&model_workspaces_[0].draft_probs_storage);
if (engine_config_->spec_draft_length > 1) {
hidden_states_for_sample = hidden_states_for_sample.CreateView(
{hidden_states_for_sample->shape[0] * hidden_states_for_sample->shape[1],
hidden_states_for_sample->shape[2]},
hidden_states_for_sample->dtype);
models_[model_id]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_,
&model_workspaces_[0].draft_hidden_states_storage);
}
Expand Down Expand Up @@ -567,26 +575,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
ICHECK(false) << "Cannot reach here";
}

/*!
 * \brief Copy the hidden-state vector of a single token out of a 3-D hidden_states array.
 *
 * The array is indexed as if its two leading dimensions were flattened into one
 * token axis: the row starting at element offset `token_pos * shape[2]` is copied.
 *
 * \param hidden_states The hidden_states of all the tokens. Must be 3-dimensional.
 * \param token_pos The desired token position along the flattened token axis.
 * \return A new 1-D array of length `shape[2]`, created with the same dtype and
 * device as `hidden_states`, holding the selected token's hidden_states.
 */
NDArray GetTokenHidden(NDArray hidden_states, int token_pos) {
  ICHECK_EQ(hidden_states->ndim, 3);
  const int64_t hidden_size = hidden_states->shape[2];

  // Derive the per-element byte width from the array's own dtype instead of
  // hard-coding a 2-byte element, so the row offset and copy length stay
  // consistent with the destination array allocated below for any dtype.
  const int64_t elem_bytes =
      (static_cast<int64_t>(hidden_states->dtype.bits) * hidden_states->dtype.lanes + 7) / 8;
  const int64_t row_bytes = hidden_size * elem_bytes;

  NDArray token_hidden =
      NDArray::Empty({hidden_size}, hidden_states->dtype, hidden_states->device);

  // NOTE(review): the source pointer is read directly from hidden_states->data;
  // this presumably requires that memory to be host-accessible -- confirm for
  // non-CPU devices.
  const char* row_begin = static_cast<const char*>(hidden_states->data) +
                          static_cast<int64_t>(token_pos) * row_bytes;
  token_hidden.CopyFromBytes(row_begin, row_bytes);
  return token_hidden;
}

/*! \brief The models to run prefill in. */
Array<Model> models_;
/*! \brief The logit processor. */
Expand Down
7 changes: 4 additions & 3 deletions cpp/serve/function_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ void FunctionTable::_InitFunctions() {
Module mod = this->use_disco ? this->disco_mod->DebugGetFromRemote(0) : this->local_vm;
this->get_logits_func_ = mod_get_func("get_logits");
this->batch_get_logits_func_ = mod_get_func("batch_get_logits");
this->batch_select_last_hidden_func_ = mod->GetFunction("batch_select_last_hidden_states", true);
this->batch_select_last_hidden_func_ = mod_get_func("batch_select_last_hidden_states");
this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true);
this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true);
this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true);
Expand Down Expand Up @@ -259,11 +259,12 @@ void FunctionTable::_InitFunctions() {
this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of");
this->nd_copy_embedding_to_offset_func_ = get_global_func("mlc.copy_embedding_to_offset");
support_backtracking_kv_ = true;
this->tuple_getitem_func_ = get_global_func("vm.builtin.tuple_getitem");

this->gather_probs_func_ = mod->GetFunction("gather_probs", true);
this->scatter_probs_func_ = mod->GetFunction("scatter_probs", true);
this->gather_hidden_states_func_ = mod->GetFunction("gather_hidden_states", true);
this->scatter_hidden_states_func_ = mod->GetFunction("scatter_hidden_states", true);
this->gather_hidden_states_func_ = mod_get_func("gather_hidden_states");
this->scatter_hidden_states_func_ = mod_get_func("scatter_hidden_states");
}

ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) const {
Expand Down
1 change: 1 addition & 0 deletions cpp/serve/function_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ struct FunctionTable {
PackedFunc nd_view_func_;
PackedFunc nd_get_shape_func_;
PackedFunc nd_copy_embedding_to_offset_func_;
PackedFunc tuple_getitem_func_;
// Auxiliary functions for speculative decoding.
PackedFunc gather_probs_func_;
PackedFunc scatter_probs_func_;
Expand Down
Loading

0 comments on commit e871bcf

Please sign in to comment.