Skip to content

Commit e871bcf

Browse files
committed
[Eagle] Avoid worker - engine transfer for hidden states
1 parent ca7cdcc commit e871bcf

File tree

12 files changed

+188
-450
lines changed

12 files changed

+188
-450
lines changed

cpp/serve/draft_token_workspace_manager.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ void DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace,
4545
NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_);
4646
if (require_hidden_states) {
4747
workspace->draft_hidden_states_storage =
48-
NDArray::Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_);
48+
ft_.Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_);
4949
}
5050
}
5151

cpp/serve/engine_actions/eagle_batch_draft.cc

Lines changed: 8 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -83,19 +83,15 @@ class EagleBatchDraftActionObj : public EngineActionObj {
8383
mstates.push_back(rsentry->mstates[model_id]);
8484
}
8585
// draft_length_ rounds of draft proposal.
86-
ObjectRef last_hidden_states{nullptr};
87-
NDArray hidden_states = Downcast<NDArray>(model_workspaces_[model_id].hidden_states);
86+
ObjectRef hidden_states = model_workspaces_[model_id].hidden_states;
8887
// Concat last hidden_states
8988
draft_token_slots_.clear();
9089
if (draft_length_ > 1) {
9190
for (int i = 0; i < num_rsentries; ++i) {
9291
draft_token_slots_.push_back(mstates[i]->draft_token_slots.back());
9392
}
94-
hidden_states = Downcast<NDArray>(models_[model_id]->GatherHiddenStates(
95-
model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states));
96-
ICHECK(hidden_states->ndim == 2);
97-
last_hidden_states = hidden_states.CreateView(
98-
{hidden_states->shape[0], 1, hidden_states->shape[1]}, hidden_states->dtype);
93+
hidden_states = models_[model_id]->GatherHiddenStates(
94+
model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states);
9995
}
10096
// The first draft token has been generated in prefill/verify stage
10197
for (int draft_id = 1; draft_id < draft_length_; ++draft_id) {
@@ -114,11 +110,10 @@ class EagleBatchDraftActionObj : public EngineActionObj {
114110

115111
// - Invoke model decode.
116112
RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode");
117-
ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden(
118-
embeddings, last_hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1);
119-
hidden_states =
120-
models_[model_id]->BatchDecodeToLastHidden(fused_hidden_states, request_internal_ids);
121-
last_hidden_states = hidden_states;
113+
ObjectRef fused_embedding_hidden_states = models_[model_id]->FuseEmbedHidden(
114+
embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1);
115+
hidden_states = models_[model_id]->BatchDecodeToLastHidden(fused_embedding_hidden_states,
116+
request_internal_ids);
122117
NDArray logits;
123118
if (models_[model_id]->CanGetLogits()) {
124119
logits = models_[model_id]->GetLogits(hidden_states, /*batch_size*/ num_rsentries,
@@ -145,11 +140,10 @@ class EagleBatchDraftActionObj : public EngineActionObj {
145140
// Fill range [0, num_rsentries) into `sample_indices`.
146141
std::vector<int> sample_indices(num_rsentries);
147142
std::iota(sample_indices.begin(), sample_indices.end(), 0);
148-
std::vector<NDArray> prob_dist;
149143
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
150144
probs_on_device, sample_indices, request_ids, generation_cfg);
151145
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
152-
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist);
146+
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
153147
ICHECK_EQ(sample_results.size(), num_rsentries);
154148

155149
// - Add draft token to the state.

cpp/serve/engine_actions/eagle_batch_verify.cc

Lines changed: 10 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
6565
Array<GenerationConfig> generation_cfg;
6666
std::vector<RandomGenerator*> rngs;
6767
std::vector<std::vector<SampleResult>> draft_output_tokens;
68-
std::vector<std::vector<NDArray>> draft_output_prob_dist;
6968
request_internal_ids.reserve(num_rsentries);
7069
all_tokens_to_verify.reserve(total_draft_length);
7170
verify_request_mstates.reserve(num_rsentries);
@@ -113,12 +112,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
113112
RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding");
114113

115114
RECORD_EVENT(trace_recorder_, request_ids, "start verify");
116-
ObjectRef fused_hidden_states = models_[verify_model_id_]->FuseEmbedHidden(
117-
embeddings, NDArray(), 1, cum_verify_lengths[num_rsentries]);
118-
NDArray hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden(
119-
fused_hidden_states, request_internal_ids, verify_lengths);
120-
ICHECK_EQ(hidden_states->ndim, 3);
121-
ICHECK_EQ(hidden_states->shape[0], 1);
115+
ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden(
116+
embeddings, request_internal_ids, verify_lengths);
122117
NDArray logits =
123118
models_[verify_model_id_]->GetLogits(hidden_states, 1, cum_verify_lengths[num_rsentries]);
124119
RECORD_EVENT(trace_recorder_, request_ids, "finish verify");
@@ -179,16 +174,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
179174

180175
{
181176
// One step draft for the following steps
182-
NDArray last_hidden_states_nd = hidden_states.CreateView(
183-
{hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]},
184-
hidden_states->dtype);
185177

186-
hidden_states = Downcast<NDArray>(models_[draft_model_id_]->GatherHiddenStates(
187-
last_hidden_states_nd, last_accepted_hidden_positions,
188-
&model_workspaces_[draft_model_id_].hidden_states));
189-
ICHECK(hidden_states->ndim == 2);
190-
hidden_states = hidden_states.CreateView(
191-
{hidden_states->shape[0], 1, hidden_states->shape[1]}, hidden_states->dtype);
178+
// Gather hidden states for the last accepted tokens.
179+
hidden_states = models_[draft_model_id_]->GatherHiddenStates(
180+
hidden_states, last_accepted_hidden_positions,
181+
&model_workspaces_[draft_model_id_].hidden_states);
192182

193183
std::vector<int> input_tokens;
194184
Array<RequestModelState> mstates;
@@ -210,10 +200,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
210200

211201
// - Invoke model decode.
212202
RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode");
213-
ObjectRef fused_hidden_states = models_[draft_model_id_]->FuseEmbedHidden(
203+
ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden(
214204
embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1);
215-
hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden(fused_hidden_states,
216-
request_internal_ids);
205+
hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden(
206+
fused_embedding_hidden_states, request_internal_ids);
217207

218208
if (models_[draft_model_id_]->CanGetLogits()) {
219209
logits = models_[draft_model_id_]->GetLogits(hidden_states, /*batch_size*/ num_rsentries,
@@ -239,22 +229,17 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
239229
// Fill range [0, num_rsentries) into `sample_indices`.
240230
std::vector<int> sample_indices(num_rsentries);
241231
std::iota(sample_indices.begin(), sample_indices.end(), 0);
242-
std::vector<NDArray> prob_dist;
243232
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
244233
probs_on_device, sample_indices, request_ids, generation_cfg);
245234
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
246-
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist);
235+
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
247236
ICHECK_EQ(sample_results.size(), num_rsentries);
248237

249238
// - Slice and save hidden_states_for_sample
250239
draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_);
251240
models_[draft_model_id_]->ScatterDraftProbs(
252241
renormalized_probs, draft_token_slots_,
253242
&model_workspaces_[verify_model_id_].draft_probs_storage);
254-
ICHECK(hidden_states->ndim == 3);
255-
hidden_states = hidden_states.CreateView(
256-
{hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]},
257-
hidden_states->dtype);
258243
models_[draft_model_id_]->ScatterHiddenStates(
259244
hidden_states, draft_token_slots_,
260245
&model_workspaces_[verify_model_id_].draft_hidden_states_storage);
@@ -326,26 +311,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
326311
return num_required_pages <= num_available_pages;
327312
}
328313

329-
/*!
330-
* \brief Get one item from a hidden_states array, which corresponds to the last token.
331-
* \param hidden_states The hidden_states of all the tokens.
332-
* \param token_pos The desired token position in the sequence.
333-
* \return The desired token's hidden_states
334-
*/
335-
NDArray GetTokenHidden(NDArray hidden_states, int token_pos) {
336-
ICHECK_EQ(hidden_states->ndim, 3);
337-
NDArray last_hidden_on_device =
338-
NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device);
339-
340-
int64_t ndata = hidden_states->shape[2];
341-
const int16_t* __restrict p_hidden =
342-
static_cast<int16_t*>(__builtin_assume_aligned(hidden_states->data, 2)) +
343-
(token_pos * ndata);
344-
345-
last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t));
346-
return last_hidden_on_device;
347-
}
348-
349314
/*!
350315
* \brief The model to run decode in. When there are multiple
351316
* models, the `Step` function of the created action will not take effect.

cpp/serve/engine_actions/eagle_new_request_prefill.cc

Lines changed: 29 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -83,8 +83,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
8383
// - Get embedding and run prefill for each model.
8484
std::vector<int> prefill_lengths;
8585
prefill_lengths.resize(/*size=*/num_rsentries, /*value=*/-1);
86-
NDArray hidden_states_for_input{nullptr};
87-
NDArray hidden_states_for_sample{nullptr};
86+
ObjectRef hidden_states_for_input{nullptr};
87+
ObjectRef hidden_states_for_sample{nullptr};
8888
NDArray logits_for_sample{nullptr};
8989
// A map used to record the entry and child_idx pair needed to fork sequence.
9090
// The base model (id 0) should record all the pairs and all the small models
@@ -167,14 +167,17 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
167167
}
168168

169169
RECORD_EVENT(trace_recorder_, request_ids, "start prefill");
170-
ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden(
171-
embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length);
172-
NDArray hidden_states = models_[model_id]->BatchPrefillToLastHidden(
173-
fused_hidden_states, request_internal_ids, prefill_lengths);
170+
ObjectRef embedding_or_hidden_states{nullptr};
171+
if (model_id == 0) {
172+
embedding_or_hidden_states = embeddings;
173+
} else {
174+
embedding_or_hidden_states = models_[model_id]->FuseEmbedHidden(
175+
embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length);
176+
}
177+
// hidden_states: (b * s, h)
178+
ObjectRef hidden_states = models_[model_id]->BatchPrefillToLastHidden(
179+
embedding_or_hidden_states, request_internal_ids, prefill_lengths);
174180
RECORD_EVENT(trace_recorder_, request_ids, "finish prefill");
175-
ICHECK_EQ(hidden_states->ndim, 3);
176-
ICHECK_EQ(hidden_states->shape[0], 1);
177-
ICHECK_EQ(hidden_states->shape[1], cum_prefill_length);
178181

179182
if (model_id == 0) {
180183
// We only need to sample for model 0 in prefill.
@@ -183,14 +186,23 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
183186

184187
// Whether to use base model to get logits.
185188
int sample_model_id = !models_[model_id]->CanGetLogits() ? 0 : model_id;
186-
hidden_states_for_sample = models_[sample_model_id]->BatchSelectLastHidden(
187-
hidden_states, request_internal_ids, prefill_lengths);
189+
190+
std::vector<int> logit_positions;
191+
{
192+
// Prepare the logit positions
193+
logit_positions.reserve(prefill_lengths.size());
194+
int total_len = 0;
195+
for (int i = 0; i < prefill_lengths.size(); ++i) {
196+
total_len += prefill_lengths[i];
197+
logit_positions.push_back(total_len - 1);
198+
}
199+
}
200+
// hidden_states_for_sample: (b * s, h)
201+
hidden_states_for_sample = models_[sample_model_id]->GatherHiddenStates(
202+
hidden_states, logit_positions, &model_workspaces_[model_id].hidden_states);
203+
// logits_for_sample: (b * s, v)
188204
logits_for_sample =
189205
models_[sample_model_id]->GetLogits(hidden_states_for_sample, 1, num_rsentries);
190-
ICHECK_EQ(hidden_states_for_sample->ndim, 3);
191-
ICHECK_EQ(hidden_states_for_sample->shape[0], 1);
192-
ICHECK_EQ(hidden_states_for_sample->shape[1], num_rsentries);
193-
194206
// - Update logits.
195207
ICHECK(logits_for_sample.defined());
196208
Array<GenerationConfig> generation_cfg;
@@ -278,11 +290,11 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
278290
rsentry_activated.push_back(true);
279291
}
280292
}
281-
std::vector<NDArray> prob_dist;
293+
282294
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
283295
probs_on_device, sample_indices, request_ids, generation_cfg);
284296
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
285-
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist);
297+
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
286298
ICHECK_EQ(sample_results.size(), rsentries_for_sample.size());
287299

288300
// - Update the committed tokens of states.
@@ -311,10 +323,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
311323
models_[model_id]->ScatterDraftProbs(renormalized_probs, draft_token_slots_,
312324
&model_workspaces_[0].draft_probs_storage);
313325
if (engine_config_->spec_draft_length > 1) {
314-
hidden_states_for_sample = hidden_states_for_sample.CreateView(
315-
{hidden_states_for_sample->shape[0] * hidden_states_for_sample->shape[1],
316-
hidden_states_for_sample->shape[2]},
317-
hidden_states_for_sample->dtype);
318326
models_[model_id]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_,
319327
&model_workspaces_[0].draft_hidden_states_storage);
320328
}
@@ -567,26 +575,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
567575
ICHECK(false) << "Cannot reach here";
568576
}
569577

570-
/*!
571-
* \brief Get one item from a hidden_states array, which corresponds to the last token.
572-
* \param hidden_states The hidden_states of all the tokens.
573-
* \param token_pos The desired token position in the sequence.
574-
* \return The desired token's hidden_states
575-
*/
576-
NDArray GetTokenHidden(NDArray hidden_states, int token_pos) {
577-
ICHECK_EQ(hidden_states->ndim, 3);
578-
NDArray last_hidden_on_device =
579-
NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device);
580-
581-
int64_t ndata = hidden_states->shape[2];
582-
const int16_t* __restrict p_hidden =
583-
static_cast<int16_t*>(__builtin_assume_aligned(hidden_states->data, 2)) +
584-
(token_pos * ndata);
585-
586-
last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t));
587-
return last_hidden_on_device;
588-
}
589-
590578
/*! \brief The models to run prefill in. */
591579
Array<Model> models_;
592580
/*! \brief The logit processor. */

cpp/serve/function_table.cc

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -218,7 +218,7 @@ void FunctionTable::_InitFunctions() {
218218
Module mod = this->use_disco ? this->disco_mod->DebugGetFromRemote(0) : this->local_vm;
219219
this->get_logits_func_ = mod_get_func("get_logits");
220220
this->batch_get_logits_func_ = mod_get_func("batch_get_logits");
221-
this->batch_select_last_hidden_func_ = mod->GetFunction("batch_select_last_hidden_states", true);
221+
this->batch_select_last_hidden_func_ = mod_get_func("batch_select_last_hidden_states");
222222
this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true);
223223
this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true);
224224
this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true);
@@ -259,11 +259,12 @@ void FunctionTable::_InitFunctions() {
259259
this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of");
260260
this->nd_copy_embedding_to_offset_func_ = get_global_func("mlc.copy_embedding_to_offset");
261261
support_backtracking_kv_ = true;
262+
this->tuple_getitem_func_ = get_global_func("vm.builtin.tuple_getitem");
262263

263264
this->gather_probs_func_ = mod->GetFunction("gather_probs", true);
264265
this->scatter_probs_func_ = mod->GetFunction("scatter_probs", true);
265-
this->gather_hidden_states_func_ = mod->GetFunction("gather_hidden_states", true);
266-
this->scatter_hidden_states_func_ = mod->GetFunction("scatter_hidden_states", true);
266+
this->gather_hidden_states_func_ = mod_get_func("gather_hidden_states");
267+
this->scatter_hidden_states_func_ = mod_get_func("scatter_hidden_states");
267268
}
268269

269270
ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) const {

cpp/serve/function_table.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -120,6 +120,7 @@ struct FunctionTable {
120120
PackedFunc nd_view_func_;
121121
PackedFunc nd_get_shape_func_;
122122
PackedFunc nd_copy_embedding_to_offset_func_;
123+
PackedFunc tuple_getitem_func_;
123124
// Auxiliary functions for speculative decoding.
124125
PackedFunc gather_probs_func_;
125126
PackedFunc scatter_probs_func_;

0 commit comments

Comments
 (0)