@@ -65,7 +65,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
65
65
Array<GenerationConfig> generation_cfg;
66
66
std::vector<RandomGenerator*> rngs;
67
67
std::vector<std::vector<SampleResult>> draft_output_tokens;
68
- std::vector<std::vector<NDArray>> draft_output_prob_dist;
69
68
request_internal_ids.reserve (num_rsentries);
70
69
all_tokens_to_verify.reserve (total_draft_length);
71
70
verify_request_mstates.reserve (num_rsentries);
@@ -113,12 +112,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
113
112
RECORD_EVENT (trace_recorder_, request_ids, " finish verify embedding" );
114
113
115
114
RECORD_EVENT (trace_recorder_, request_ids, " start verify" );
116
- ObjectRef fused_hidden_states = models_[verify_model_id_]->FuseEmbedHidden (
117
- embeddings, NDArray (), 1 , cum_verify_lengths[num_rsentries]);
118
- NDArray hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden (
119
- fused_hidden_states, request_internal_ids, verify_lengths);
120
- ICHECK_EQ (hidden_states->ndim , 3 );
121
- ICHECK_EQ (hidden_states->shape [0 ], 1 );
115
+ ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden (
116
+ embeddings, request_internal_ids, verify_lengths);
122
117
NDArray logits =
123
118
models_[verify_model_id_]->GetLogits (hidden_states, 1 , cum_verify_lengths[num_rsentries]);
124
119
RECORD_EVENT (trace_recorder_, request_ids, " finish verify" );
@@ -179,16 +174,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
179
174
180
175
{
181
176
// One step draft for the following steps
182
- NDArray last_hidden_states_nd = hidden_states.CreateView (
183
- {hidden_states->shape [0 ] * hidden_states->shape [1 ], hidden_states->shape [2 ]},
184
- hidden_states->dtype );
185
177
186
- hidden_states = Downcast<NDArray>(models_[draft_model_id_]->GatherHiddenStates (
187
- last_hidden_states_nd, last_accepted_hidden_positions,
188
- &model_workspaces_[draft_model_id_].hidden_states ));
189
- ICHECK (hidden_states->ndim == 2 );
190
- hidden_states = hidden_states.CreateView (
191
- {hidden_states->shape [0 ], 1 , hidden_states->shape [1 ]}, hidden_states->dtype );
178
+ // Gather hidden states for the last accepted tokens.
179
+ hidden_states = models_[draft_model_id_]->GatherHiddenStates (
180
+ hidden_states, last_accepted_hidden_positions,
181
+ &model_workspaces_[draft_model_id_].hidden_states );
192
182
193
183
std::vector<int > input_tokens;
194
184
Array<RequestModelState> mstates;
@@ -210,10 +200,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
210
200
211
201
// - Invoke model decode.
212
202
RECORD_EVENT (trace_recorder_, request_ids, " start proposal decode" );
213
- ObjectRef fused_hidden_states = models_[draft_model_id_]->FuseEmbedHidden (
203
+ ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden (
214
204
embeddings, hidden_states, /* batch_size*/ num_rsentries, /* seq_len*/ 1 );
215
- hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden (fused_hidden_states,
216
- request_internal_ids);
205
+ hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden (
206
+ fused_embedding_hidden_states, request_internal_ids);
217
207
218
208
if (models_[draft_model_id_]->CanGetLogits ()) {
219
209
logits = models_[draft_model_id_]->GetLogits (hidden_states, /* batch_size*/ num_rsentries,
@@ -239,22 +229,17 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
239
229
// Fill range [0, num_rsentries) into `sample_indices`.
240
230
std::vector<int > sample_indices (num_rsentries);
241
231
std::iota (sample_indices.begin (), sample_indices.end (), 0 );
242
- std::vector<NDArray> prob_dist;
243
232
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP (
244
233
probs_on_device, sample_indices, request_ids, generation_cfg);
245
234
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP (
246
- renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist );
235
+ renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
247
236
ICHECK_EQ (sample_results.size (), num_rsentries);
248
237
249
238
// - Slice and save hidden_states_for_sample
250
239
draft_token_workspace_manager_->AllocSlots (num_rsentries, &draft_token_slots_);
251
240
models_[draft_model_id_]->ScatterDraftProbs (
252
241
renormalized_probs, draft_token_slots_,
253
242
&model_workspaces_[verify_model_id_].draft_probs_storage );
254
- ICHECK (hidden_states->ndim == 3 );
255
- hidden_states = hidden_states.CreateView (
256
- {hidden_states->shape [0 ] * hidden_states->shape [1 ], hidden_states->shape [2 ]},
257
- hidden_states->dtype );
258
243
models_[draft_model_id_]->ScatterHiddenStates (
259
244
hidden_states, draft_token_slots_,
260
245
&model_workspaces_[verify_model_id_].draft_hidden_states_storage );
@@ -326,26 +311,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
326
311
return num_required_pages <= num_available_pages;
327
312
}
328
313
329
- /* !
330
- * \brief Get one item from a hidden_states array, which corresponds to the last token.
331
- * \param hidden_states The hidden_states of all the tokens.
332
- * \param token_pos The desired token position in the sequence.
333
- * \return The desired token's hidden_states
334
- */
335
- NDArray GetTokenHidden (NDArray hidden_states, int token_pos) {
336
- ICHECK_EQ (hidden_states->ndim , 3 );
337
- NDArray last_hidden_on_device =
338
- NDArray::Empty ({hidden_states->shape [2 ]}, hidden_states->dtype , hidden_states->device );
339
-
340
- int64_t ndata = hidden_states->shape [2 ];
341
- const int16_t * __restrict p_hidden =
342
- static_cast <int16_t *>(__builtin_assume_aligned (hidden_states->data , 2 )) +
343
- (token_pos * ndata);
344
-
345
- last_hidden_on_device.CopyFromBytes (p_hidden, ndata * sizeof (int16_t ));
346
- return last_hidden_on_device;
347
- }
348
-
349
314
/* !
350
315
* \brief The model to run decode in. When there are multiple
351
316
* models, the `Step` function of the created action will not take effect.
0 commit comments