From 66f0a41972fd8a926d59bb58f28851f71c0291d5 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 23 Apr 2024 17:22:45 -0700 Subject: [PATCH] [Eagle] Avoid worker - engine transfer for hidden states --- cpp/serve/draft_token_workspace_manager.cc | 2 +- cpp/serve/engine_actions/eagle_batch_draft.cc | 22 +- .../engine_actions/eagle_batch_verify.cc | 55 +-- .../eagle_new_request_prefill.cc | 70 ++- cpp/serve/function_table.cc | 7 +- cpp/serve/function_table.h | 1 + cpp/serve/model.cc | 407 +++++------------- cpp/serve/model.h | 46 +- cpp/serve/sampler/gpu_sampler.cc | 2 - python/mlc_llm/interface/compile.py | 7 +- python/mlc_llm/model/eagle/eagle_model.py | 4 +- python/mlc_llm/model/llama/llama_model.py | 15 +- 12 files changed, 188 insertions(+), 450 deletions(-) diff --git a/cpp/serve/draft_token_workspace_manager.cc b/cpp/serve/draft_token_workspace_manager.cc index 185b899e14..d004e91ee5 100644 --- a/cpp/serve/draft_token_workspace_manager.cc +++ b/cpp/serve/draft_token_workspace_manager.cc @@ -45,7 +45,7 @@ void DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace, NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_); if (require_hidden_states) { workspace->draft_hidden_states_storage = - NDArray::Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_); + ft_.Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_); } } diff --git a/cpp/serve/engine_actions/eagle_batch_draft.cc b/cpp/serve/engine_actions/eagle_batch_draft.cc index 7ad66a045c..b4e7ec4c39 100644 --- a/cpp/serve/engine_actions/eagle_batch_draft.cc +++ b/cpp/serve/engine_actions/eagle_batch_draft.cc @@ -83,19 +83,15 @@ class EagleBatchDraftActionObj : public EngineActionObj { mstates.push_back(rsentry->mstates[model_id]); } // draft_length_ rounds of draft proposal. - ObjectRef last_hidden_states{nullptr}; - NDArray hidden_states = Downcast(model_workspaces_[model_id].hidden_states); + ObjectRef hidden_states = model_workspaces_[model_id].hidden_states; // Concat last hidden_states draft_token_slots_.clear(); if (draft_length_ > 1) { for (int i = 0; i < num_rsentries; ++i) { draft_token_slots_.push_back(mstates[i]->draft_token_slots.back()); } - hidden_states = Downcast(models_[model_id]->GatherHiddenStates( - model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states)); - ICHECK(hidden_states->ndim == 2); - last_hidden_states = hidden_states.CreateView( - {hidden_states->shape[0], 1, hidden_states->shape[1]}, hidden_states->dtype); + hidden_states = models_[model_id]->GatherHiddenStates( + model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states); } // The first draft token has been generated in prefill/verify stage for (int draft_id = 1; draft_id < draft_length_; ++draft_id) { @@ -114,11 +110,10 @@ class EagleBatchDraftActionObj : public EngineActionObj { // - Invoke model decode. RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); - ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden( - embeddings, last_hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); - hidden_states = - models_[model_id]->BatchDecodeToLastHidden(fused_hidden_states, request_internal_ids); - last_hidden_states = hidden_states; + ObjectRef fused_embedding_hidden_states = models_[model_id]->FuseEmbedHidden( + embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + hidden_states = models_[model_id]->BatchDecodeToLastHidden(fused_embedding_hidden_states, + request_internal_ids); NDArray logits; if (models_[model_id]->CanGetLogits()) { logits = models_[model_id]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, @@ -145,11 +140,10 @@ class EagleBatchDraftActionObj : public EngineActionObj { // Fill range [0, num_rsentries) into `sample_indices`. std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); - std::vector prob_dist; NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( probs_on_device, sample_indices, request_ids, generation_cfg); std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index d52f60d5c7..f7c858192d 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -65,7 +65,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj { Array generation_cfg; std::vector rngs; std::vector> draft_output_tokens; - std::vector> draft_output_prob_dist; request_internal_ids.reserve(num_rsentries); all_tokens_to_verify.reserve(total_draft_length); verify_request_mstates.reserve(num_rsentries); @@ -113,12 +112,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding"); RECORD_EVENT(trace_recorder_, request_ids, "start verify"); - ObjectRef fused_hidden_states = models_[verify_model_id_]->FuseEmbedHidden( - embeddings, NDArray(), 1, cum_verify_lengths[num_rsentries]); - NDArray hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden( - fused_hidden_states, request_internal_ids, verify_lengths); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], 1); + ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden( + embeddings, request_internal_ids, verify_lengths); NDArray logits = models_[verify_model_id_]->GetLogits(hidden_states, 1, cum_verify_lengths[num_rsentries]); RECORD_EVENT(trace_recorder_, request_ids, "finish verify"); @@ -179,16 +174,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj { { // One step draft for the following steps - NDArray last_hidden_states_nd = hidden_states.CreateView( - {hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]}, - hidden_states->dtype); - hidden_states = Downcast(models_[draft_model_id_]->GatherHiddenStates( - last_hidden_states_nd, last_accepted_hidden_positions, - &model_workspaces_[draft_model_id_].hidden_states)); - ICHECK(hidden_states->ndim == 2); - hidden_states = hidden_states.CreateView( - {hidden_states->shape[0], 1, hidden_states->shape[1]}, hidden_states->dtype); + // Gather hidden states for the last accepted tokens. + hidden_states = models_[draft_model_id_]->GatherHiddenStates( + hidden_states, last_accepted_hidden_positions, + &model_workspaces_[draft_model_id_].hidden_states); std::vector input_tokens; Array mstates; @@ -210,10 +200,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // - Invoke model decode. RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); - ObjectRef fused_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( + ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); - hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden(fused_hidden_states, - request_internal_ids); + hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden( + fused_embedding_hidden_states, request_internal_ids); if (models_[draft_model_id_]->CanGetLogits()) { logits = models_[draft_model_id_]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, @@ -239,11 +229,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // Fill range [0, num_rsentries) into `sample_indices`. std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); - std::vector prob_dist; NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( probs_on_device, sample_indices, request_ids, generation_cfg); std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), num_rsentries); // - Slice and save hidden_states_for_sample @@ -251,10 +240,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj { models_[draft_model_id_]->ScatterDraftProbs( renormalized_probs, draft_token_slots_, &model_workspaces_[verify_model_id_].draft_probs_storage); - ICHECK(hidden_states->ndim == 3); - hidden_states = hidden_states.CreateView( - {hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]}, - hidden_states->dtype); models_[draft_model_id_]->ScatterHiddenStates( hidden_states, draft_token_slots_, &model_workspaces_[verify_model_id_].draft_hidden_states_storage); @@ -326,26 +311,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj { return num_required_pages <= num_available_pages; } - /*! - * \brief Get one item from a hidden_states array, which corresponds to the last token. - * \param hidden_states The hidden_states of all the tokens. - * \param token_pos The desired token position in the sequence. - * \return The desired token's hidden_states - */ - NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { - ICHECK_EQ(hidden_states->ndim, 3); - NDArray last_hidden_on_device = - NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); - - int64_t ndata = hidden_states->shape[2]; - const int16_t* __restrict p_hidden = - static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + - (token_pos * ndata); - - last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); - return last_hidden_on_device; - } - /*! * \brief The model to run decode in. When there are multiple * models, the `Step` function of the created action will not take effect. diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index 57310f7986..80de254ca8 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -83,8 +83,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // - Get embedding and run prefill for each model. std::vector prefill_lengths; prefill_lengths.resize(/*size=*/num_rsentries, /*value=*/-1); - NDArray hidden_states_for_input{nullptr}; - NDArray hidden_states_for_sample{nullptr}; + ObjectRef hidden_states_for_input{nullptr}; + ObjectRef hidden_states_for_sample{nullptr}; NDArray logits_for_sample{nullptr}; // A map used to record the entry and child_idx pair needed to fork sequence. // The base model (id 0) should record all the pairs and all the small models @@ -167,14 +167,17 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { } RECORD_EVENT(trace_recorder_, request_ids, "start prefill"); - ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden( - embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length); - NDArray hidden_states = models_[model_id]->BatchPrefillToLastHidden( - fused_hidden_states, request_internal_ids, prefill_lengths); + ObjectRef embedding_or_hidden_states{nullptr}; + if (model_id == 0) { + embedding_or_hidden_states = embeddings; + } else { + embedding_or_hidden_states = models_[model_id]->FuseEmbedHidden( + embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length); + } + // hidden_states: (b * s, h) + ObjectRef hidden_states = models_[model_id]->BatchPrefillToLastHidden( + embedding_or_hidden_states, request_internal_ids, prefill_lengths); RECORD_EVENT(trace_recorder_, request_ids, "finish prefill"); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], 1); - ICHECK_EQ(hidden_states->shape[1], cum_prefill_length); if (model_id == 0) { // We only need to sample for model 0 in prefill. @@ -183,14 +186,23 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // Whether to use base model to get logits. int sample_model_id = !models_[model_id]->CanGetLogits() ? 0 : model_id; - hidden_states_for_sample = models_[sample_model_id]->BatchSelectLastHidden( - hidden_states, request_internal_ids, prefill_lengths); + + std::vector logit_positions; + { + // Prepare the logit positions + logit_positions.reserve(prefill_lengths.size()); + int total_len = 0; + for (int i = 0; i < prefill_lengths.size(); ++i) { + total_len += prefill_lengths[i]; + logit_positions.push_back(total_len - 1); + } + } + // hidden_states_for_sample: (b * s, h) + hidden_states_for_sample = models_[sample_model_id]->GatherHiddenStates( + hidden_states, logit_positions, &model_workspaces_[model_id].hidden_states); + // logits_for_sample: (b * s, v) logits_for_sample = models_[sample_model_id]->GetLogits(hidden_states_for_sample, 1, num_rsentries); - ICHECK_EQ(hidden_states_for_sample->ndim, 3); - ICHECK_EQ(hidden_states_for_sample->shape[0], 1); - ICHECK_EQ(hidden_states_for_sample->shape[1], num_rsentries); - // - Update logits. ICHECK(logits_for_sample.defined()); Array generation_cfg; @@ -278,11 +290,11 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { rsentry_activated.push_back(true); } } - std::vector prob_dist; + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( probs_on_device, sample_indices, request_ids, generation_cfg); std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); // - Update the committed tokens of states. @@ -311,10 +323,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { models_[model_id]->ScatterDraftProbs(renormalized_probs, draft_token_slots_, &model_workspaces_[0].draft_probs_storage); if (engine_config_->spec_draft_length > 1) { - hidden_states_for_sample = hidden_states_for_sample.CreateView( - {hidden_states_for_sample->shape[0] * hidden_states_for_sample->shape[1], - hidden_states_for_sample->shape[2]}, - hidden_states_for_sample->dtype); models_[model_id]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_, &model_workspaces_[0].draft_hidden_states_storage); } @@ -567,26 +575,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { ICHECK(false) << "Cannot reach here"; } - /*! - * \brief Get one item from a hidden_states array, which corresponds to the last token. - * \param hidden_states The hidden_states of all the tokens. - * \param token_pos The desired token position in the sequence. - * \return The desired token's hidden_states - */ - NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { - ICHECK_EQ(hidden_states->ndim, 3); - NDArray last_hidden_on_device = - NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); - - int64_t ndata = hidden_states->shape[2]; - const int16_t* __restrict p_hidden = - static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + - (token_pos * ndata); - - last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); - return last_hidden_on_device; - } - /*! \brief The models to run prefill in. */ Array models_; /*! \brief The logit processor. */ diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 4e0301eb2d..16db4a8a03 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -218,7 +218,7 @@ void FunctionTable::_InitFunctions() { Module mod = this->use_disco ? this->disco_mod->DebugGetFromRemote(0) : this->local_vm; this->get_logits_func_ = mod_get_func("get_logits"); this->batch_get_logits_func_ = mod_get_func("batch_get_logits"); - this->batch_select_last_hidden_func_ = mod->GetFunction("batch_select_last_hidden_states", true); + this->batch_select_last_hidden_func_ = mod_get_func("batch_select_last_hidden_states"); this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true); this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true); this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true); @@ -259,11 +259,12 @@ void FunctionTable::_InitFunctions() { this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of"); this->nd_copy_embedding_to_offset_func_ = get_global_func("mlc.copy_embedding_to_offset"); support_backtracking_kv_ = true; + this->tuple_getitem_func_ = get_global_func("vm.builtin.tuple_getitem"); this->gather_probs_func_ = mod->GetFunction("gather_probs", true); this->scatter_probs_func_ = mod->GetFunction("scatter_probs", true); - this->gather_hidden_states_func_ = mod->GetFunction("gather_hidden_states", true); - this->scatter_hidden_states_func_ = mod->GetFunction("scatter_hidden_states", true); + this->gather_hidden_states_func_ = mod_get_func("gather_hidden_states"); + this->scatter_hidden_states_func_ = mod_get_func("scatter_hidden_states"); } ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) const { diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index e368edcb9c..2350f3d37a 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -120,6 +120,7 @@ struct FunctionTable { PackedFunc nd_view_func_; PackedFunc nd_get_shape_func_; PackedFunc nd_copy_embedding_to_offset_func_; + PackedFunc tuple_getitem_func_; // Auxiliary functions for speculative decoding. PackedFunc gather_probs_func_; PackedFunc scatter_probs_func_; diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 8918cecdc4..be76b40e2e 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -5,7 +5,6 @@ */ #include "model.h" -#include #include #include #include @@ -137,35 +136,23 @@ class ModelImpl : public ModelObj { return ft_.get_logits_func_.defined() && ft_.batch_get_logits_func_.defined(); } - NDArray GetLogits(const ObjectRef& last_hidden_states, int batch_size, int seq_len) final { + NDArray GetLogits(const ObjectRef& hidden_states, int batch_size, int seq_len) final { NVTXScopedRange nvtx_scope("GetLogits"); CHECK(ft_.get_logits_func_.defined()) << "`get_logits` function is not found in the model."; - ObjectRef hidden_states_dref_or_nd; - CHECK(!last_hidden_states->IsInstance()); - // hidden_states: (b, s, h) - NDArray hidden_states = Downcast(last_hidden_states); - ICHECK_NE(hidden_size_, -1); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], batch_size); - ICHECK_EQ(hidden_states->shape[1], seq_len); - ICHECK_EQ(hidden_states->shape[2], hidden_size_); - ICHECK_EQ(hidden_states->device.device_type, device_.device_type); - ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - - hidden_states = - hidden_states.CreateView({batch_size * seq_len, hidden_size_}, hidden_states->dtype); - - // This copy can be avoided by not copying the hidden states to engine. - hidden_states_dref_or_nd = ft_.CopyToWorker0( - hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); + ObjectRef hidden_states_dref_or_nd{nullptr}; + if (!ft_.use_disco && hidden_states->IsInstance()) { + hidden_states_dref_or_nd = Downcast(hidden_states)->DebugGetFromRemote(0); + } else { + hidden_states_dref_or_nd = hidden_states; + } ObjectRef ret = ft_.get_logits_func_(hidden_states_dref_or_nd, params_); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } NDArray logits{nullptr}; - if (ret->IsInstance()) { + if (ft_.use_disco) { logits = Downcast(ret)->DebugGetFromRemote(0); } else { logits = Downcast(ret); @@ -177,142 +164,11 @@ class ModelImpl : public ModelObj { return logits.CreateView({batch_size, seq_len, logits->shape[1]}, logits->dtype); } - NDArray BatchGetLogits(const ObjectRef& last_hidden_states, const std::vector& seq_ids, - const std::vector& lengths) { - NVTXScopedRange nvtx_scope("BatchGetLogits"); - CHECK(!seq_ids.empty()); - CHECK_EQ(seq_ids.size(), lengths.size()); - int num_sequences = seq_ids.size(); - int total_length = 0; - - int* p_logit_pos = static_cast(logit_pos_arr_->data); - for (int i = 0; i < num_sequences; ++i) { - total_length += lengths[i]; - p_logit_pos[i] = total_length - 1; - } - NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); - ObjectRef logit_pos_dref_or_nd = - ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); - - CHECK(ft_.batch_get_logits_func_.defined()) - << "`batch_get_logits` function is not found in the model."; - - ObjectRef hidden_states_dref_or_nd; - CHECK(!last_hidden_states->IsInstance()); - // hidden_states: (b, s, h) - NDArray hidden_states = Downcast(last_hidden_states); - ICHECK_NE(hidden_size_, -1); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], 1); - ICHECK_EQ(hidden_states->shape[1], total_length); - ICHECK_EQ(hidden_states->shape[2], hidden_size_); - ICHECK_EQ(hidden_states->device.device_type, device_.device_type); - ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - - hidden_states = hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); - - // This copy can be avoided by not copying the hidden states to engine. - hidden_states_dref_or_nd = ft_.CopyToWorker0( - hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); - - ObjectRef ret = - ft_.batch_get_logits_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_); - if (trace_enabled_) { - TVMSynchronize(device_.device_type, device_.device_id, nullptr); - } - - NDArray logits; - logits = Downcast(ret); - CHECK(logits.defined()); - // logits: (b * s, v) - ICHECK_EQ(logits->ndim, 2); - ICHECK_EQ(logits->shape[0], num_sequences); - return logits.CreateView({1, num_sequences, logits->shape[1]}, logits->dtype); - } - - NDArray BatchSelectLastHidden(const ObjectRef& last_hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) { - NVTXScopedRange nvtx_scope("BatchSelectLastHidden"); - CHECK(!seq_ids.empty()); - CHECK_EQ(seq_ids.size(), lengths.size()); - int num_sequences = seq_ids.size(); - int total_length = 0; - - int* p_logit_pos = static_cast(logit_pos_arr_->data); - for (int i = 0; i < num_sequences; ++i) { - total_length += lengths[i]; - p_logit_pos[i] = total_length - 1; - } - NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); - - ObjectRef logit_pos_dref_or_nd = ft_.CopyToWorker0(logit_pos_nd, "logit_pos_local", - {max_num_sequence_}, /*local_only=*/true); - - CHECK(ft_.batch_select_last_hidden_func_.defined()) - << "`batch_select_last_hidden_states` function is not found in the model."; - - ObjectRef hidden_states_dref_or_nd; - CHECK(!last_hidden_states->IsInstance()); - // hidden_states: (b, s, h) - NDArray hidden_states = Downcast(last_hidden_states); - ICHECK_NE(hidden_size_, -1); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], 1); - ICHECK_EQ(hidden_states->shape[1], total_length); - ICHECK_EQ(hidden_states->shape[2], hidden_size_); - ICHECK_EQ(hidden_states->device.device_type, device_.device_type); - ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - - hidden_states_dref_or_nd = - hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); - - ObjectRef ret = - ft_.batch_select_last_hidden_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd); - if (trace_enabled_) { - TVMSynchronize(device_.device_type, device_.device_id, nullptr); - } - - NDArray hidden; - hidden = Downcast(ret); - // hidden: (b * s, v) - ICHECK_EQ(hidden->ndim, 2); - ICHECK_EQ(hidden->shape[0], num_sequences); - return hidden.CreateView({1, num_sequences, hidden->shape[1]}, hidden->dtype); - } - - NDArray ConcatLastHidden(std::vector& hidden_states, ObjectRef* dst) final { - NVTXScopedRange nvtx_scope("ConcatLastHidden"); - - CHECK(dst->defined()); - - int cum_length = 0; - ICHECK_GE(hidden_states.size(), 1); - for (auto hidden : hidden_states) { - ICHECK_EQ(hidden->ndim, 1); - // No ICHECK_EQ(hidden->shape[0], hidden_size_) here to allow different hidden_sizes. - hidden = hidden.CreateView({1, hidden_size_}, hidden->dtype); - // Reuse the copy embedding function - ObjectRef hidden_dref_or_nd = - ft_.CopyToWorker0(hidden, "hidden_for_concat", {1, hidden_size_}); - ft_.nd_copy_embedding_to_offset_func_(hidden_dref_or_nd, *dst, cum_length); - cum_length += 1; - } - NDArray ret{nullptr}; - if ((*dst)->IsInstance()) { - ret = Downcast(*dst)->DebugGetFromRemote(0); - } else { - ret = Downcast(*dst); - } - ret = ret.CreateView({cum_length, hidden_size_}, hidden_states[0]->dtype); - return ret; - } - ObjectRef FuseEmbedHidden(const ObjectRef& embeddings, const ObjectRef& previous_hidden_states, int batch_size, int seq_len) final { NVTXScopedRange nvtx_scope("FuseEmbedHidden"); - ObjectRef embeddings_dref_or_nd; + ObjectRef embeddings_dref_or_nd{nullptr}; if (!embeddings->IsInstance()) { // embeddings: (n, h) NDArray embeddings_nd = Downcast(embeddings); @@ -320,51 +176,33 @@ class ModelImpl : public ModelObj { ICHECK_EQ(embeddings_nd->ndim, 2); ICHECK_GE(embeddings_nd->shape[0], batch_size * seq_len); ICHECK_EQ(embeddings_nd->shape[1], hidden_size_); - ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); - ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); embeddings_dref_or_nd = embeddings_nd.CreateView({batch_size * seq_len, hidden_size_}, embeddings_nd->dtype); - - if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) { - // Model has no support for fuse_embed_hidden_states or this is the first model (base model) - return embeddings_nd.CreateView({batch_size, seq_len, hidden_size_}, embeddings_nd->dtype); - } } else { ShapeTuple embedding_shape{batch_size * seq_len, hidden_size_}; embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); - - if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) { - // Model has no support for fuse_embed_hidden_states or this is the first model (base model) - ShapeTuple embedding_shape{batch_size, seq_len, hidden_size_}; - return ft_.nd_view_func_(embeddings, embedding_shape); - } } - NDArray hidden_states = Downcast(previous_hidden_states); - CHECK(hidden_states.defined()); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], batch_size); - ICHECK_EQ(hidden_states->shape[1], seq_len); - ICHECK_EQ(hidden_states->shape[2], hidden_size_); - ICHECK_EQ(hidden_states->device.device_type, device_.device_type); - ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - NDArray hidden_states_2d = - hidden_states.CreateView({batch_size * seq_len, hidden_size_}, hidden_states->dtype); - auto hidden_states_dref_or_nd = - ft_.CopyToWorker0(hidden_states_2d, "hidden_states_2d", - {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); - - ObjectRef ret = - ft_.fuse_embed_hidden_func_(embeddings_dref_or_nd, hidden_states_dref_or_nd, params_); + ObjectRef previous_hidden_states_dref_or_nd{nullptr}; + if (!ft_.use_disco && previous_hidden_states->IsInstance()) { + previous_hidden_states_dref_or_nd = + Downcast(previous_hidden_states)->DebugGetFromRemote(0); + } else { + previous_hidden_states_dref_or_nd = previous_hidden_states; + } + ObjectRef fused = ft_.fuse_embed_hidden_func_(embeddings_dref_or_nd, + previous_hidden_states_dref_or_nd, params_); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } - if (!ret->IsInstance()) { - NDArray fused = Downcast(ret); - return fused.CreateView({batch_size, seq_len, hidden_size_}, fused->dtype); + ShapeTuple out_shape{batch_size, seq_len, hidden_size_}; + if (ft_.use_disco) { + return ft_.nd_view_func_(fused, out_shape); } else { - ShapeTuple fused_shape{batch_size, seq_len, hidden_size_}; - return ft_.nd_view_func_(ret, fused_shape); + NDArray fused_nd = Downcast(fused); + ICHECK_EQ(fused_nd->ndim, 2); + ICHECK_EQ(fused_nd->shape[0], batch_size * seq_len); + return fused_nd.CreateView(out_shape, fused_nd->dtype); } } @@ -439,9 +277,9 @@ class ModelImpl : public ModelObj { return logits; } - NDArray BatchPrefillToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) final { + ObjectRef BatchPrefillToLastHidden(const ObjectRef& embedding_or_hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) final { NVTXScopedRange nvtx_scope("BatchPrefillToLastHidden"); CHECK(!seq_ids.empty()); CHECK_EQ(seq_ids.size(), lengths.size()); @@ -452,19 +290,15 @@ class ModelImpl : public ModelObj { total_length += lengths[i]; } - ObjectRef hidden_states_dref_or_nd; - if (!hidden_states->IsInstance()) { - // hidden_states: (1, n, h) - NDArray hidden_states_nd = Downcast(hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 3); - ICHECK_EQ(hidden_states_nd->shape[0], 1); - ICHECK_EQ(hidden_states_nd->shape[1], total_length); - ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); - hidden_states_dref_or_nd = - hidden_states_nd.CreateView({1, total_length, hidden_size_}, hidden_states_nd->dtype); + ObjectRef embedding_or_hidden_states_dref_or_nd{nullptr}; + ShapeTuple hidden_states_shape{1, total_length, hidden_size_}; + if (!ft_.use_disco) { + NDArray embedding_or_hidden_states_nd = Downcast(embedding_or_hidden_states); + embedding_or_hidden_states_dref_or_nd = embedding_or_hidden_states_nd.CreateView( + hidden_states_shape, embedding_or_hidden_states_nd->dtype); } else { - ShapeTuple hidden_states_shape{1, total_length, hidden_size_}; - hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); + embedding_or_hidden_states_dref_or_nd = + ft_.nd_view_func_(embedding_or_hidden_states, hidden_states_shape); } CHECK(ft_.prefill_to_last_hidden_func_.defined()) @@ -479,32 +313,34 @@ class ModelImpl : public ModelObj { ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); // args: embeddings, logit_pos, kv_cache, params - ObjectRef ret; + ObjectRef result{nullptr}; if (seq_ids.size() == 1) { CHECK(ft_.single_batch_prefill_to_last_hidden_func_.defined()) << "`single_batch_prefill_to_last_hidden_states` function is not found in the model."; - ret = ft_.single_batch_prefill_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, - params_); + result = ft_.single_batch_prefill_to_last_hidden_func_(embedding_or_hidden_states_dref_or_nd, + kv_cache_, params_); } else { - ret = ft_.prefill_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); - } - NDArray last_hidden_states; - if (ft_.use_disco) { - Array result = Downcast(ret)->DebugGetFromRemote(0); - last_hidden_states = Downcast(result[0]); - } else { - last_hidden_states = Downcast>(ret)[0]; + result = ft_.prefill_to_last_hidden_func_(embedding_or_hidden_states_dref_or_nd, kv_cache_, + params_); } + ObjectRef hidden_states = ft_.tuple_getitem_func_(result, 0); + if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } ft_.kv_cache_end_forward_func_(kv_cache_); - // hidden_states: (1, total_length, v) - ICHECK_EQ(last_hidden_states->ndim, 3); - ICHECK_EQ(last_hidden_states->shape[0], 1); - ICHECK_EQ(last_hidden_states->shape[1], total_length); - return last_hidden_states; + ShapeTuple out_shape{total_length, hidden_size_}; + if (ft_.use_disco) { + return ft_.nd_view_func_(hidden_states, out_shape); + } else { + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], 1); + ICHECK_EQ(hidden_states_nd->shape[1], total_length); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + return hidden_states_nd.CreateView(out_shape, hidden_states_nd->dtype); + } } NDArray BatchDecode(const ObjectRef& embeddings, const std::vector& seq_ids) final { @@ -567,8 +403,8 @@ class ModelImpl : public ModelObj { return logits; } - NDArray BatchDecodeToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids) final { + ObjectRef BatchDecodeToLastHidden(const ObjectRef& hidden_states_dref_or_nd, + const std::vector& seq_ids) final { NVTXScopedRange nvtx_scope("BatchDecodeToLastHidden"); int num_sequence = seq_ids.size(); @@ -578,21 +414,6 @@ class ModelImpl : public ModelObj { ICHECK(ft_.kv_cache_end_forward_func_.defined()); ICHECK(kv_cache_.defined()) << "KV cache has not been initialized."; - ObjectRef hidden_states_dref_or_nd; - if (!hidden_states->IsInstance()) { - // hidden_states: (1, n, h) - NDArray hidden_states_nd = Downcast(hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 3); - ICHECK_EQ(hidden_states_nd->shape[0], num_sequence); - ICHECK_EQ(hidden_states_nd->shape[1], 1); - ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); - hidden_states_dref_or_nd = - hidden_states_nd.CreateView({num_sequence, 1, hidden_size_}, hidden_states_nd->dtype); - } else { - ShapeTuple hidden_states_shape{num_sequence, 1, hidden_size_}; - hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); - } - // Reserve in KV cache for the lengths of the input. // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); @@ -600,32 +421,34 @@ class ModelImpl : public ModelObj { ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); // args: embeddings, kv_cache, params - ObjectRef ret; + ObjectRef result{nullptr}; if (seq_ids.size() == 1) { CHECK(ft_.single_batch_decode_to_last_hidden_func_.defined()) << "`decode_to_last_hidden_states` function is not found in the model."; - ret = ft_.single_batch_decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, - params_); - } else { - ret = ft_.decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); - } - NDArray last_hidden_states; - if (ft_.use_disco) { - Array result = Downcast(ret)->DebugGetFromRemote(0); - last_hidden_states = Downcast(result[0]); + result = ft_.single_batch_decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, + params_); } else { - last_hidden_states = Downcast>(ret)[0]; + result = ft_.decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); } + ft_.kv_cache_end_forward_func_(kv_cache_); + ObjectRef hidden_states = ft_.tuple_getitem_func_(result, 0); + if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } - ft_.kv_cache_end_forward_func_(kv_cache_); - // hidden_states: (b, 1, v) - ICHECK_EQ(last_hidden_states->ndim, 3); - ICHECK_EQ(last_hidden_states->shape[0], num_sequence); - ICHECK_EQ(last_hidden_states->shape[1], 1); - return last_hidden_states; + // hidden_states: (b, 1, v) to (b, v) + ShapeTuple out_shape{num_sequence, hidden_size_}; + if (ft_.use_disco) { + return ft_.nd_view_func_(hidden_states, out_shape); + } else { + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], num_sequence); + ICHECK_EQ(hidden_states_nd->shape[1], 1); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + return hidden_states_nd.CreateView(out_shape, hidden_states_nd->dtype); + } } NDArray BatchVerify(const ObjectRef& embeddings, const std::vector& seq_ids, @@ -688,9 +511,9 @@ class ModelImpl : public ModelObj { return logits; } - NDArray BatchVerifyToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) final { + ObjectRef BatchVerifyToLastHidden(const ObjectRef& embeddings, + const std::vector& seq_ids, + const std::vector& lengths) final { NVTXScopedRange nvtx_scope("BatchVerifyToLastHidden"); CHECK(!seq_ids.empty()); CHECK_EQ(seq_ids.size(), lengths.size()); @@ -706,45 +529,46 @@ class ModelImpl : public ModelObj { ICHECK(ft_.kv_cache_end_forward_func_.defined()); ICHECK(kv_cache_.defined()) << "KV cache has not been initialized."; - ObjectRef hidden_states_dref_or_nd; - if (!hidden_states->IsInstance()) { - // hidden_states: (1, n, h) - NDArray hidden_states_nd = Downcast(hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 3); - ICHECK_EQ(hidden_states_nd->shape[0], 1); - ICHECK_EQ(hidden_states_nd->shape[1], total_length); - ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); - hidden_states_dref_or_nd = - hidden_states_nd.CreateView({1, total_length, hidden_size_}, hidden_states_nd->dtype); + ObjectRef embeddings_dref_or_nd; + if (!embeddings->IsInstance()) { + // embeddings: (1, n, h) + NDArray embeddings_nd = Downcast(embeddings); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(embeddings_nd->ndim, 2); + ICHECK_GE(embeddings_nd->shape[0], total_length); + ICHECK_EQ(embeddings_nd->shape[1], hidden_size_); + ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); + ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); + embeddings_dref_or_nd = + embeddings_nd.CreateView({1, total_length, hidden_size_}, embeddings_nd->dtype); } else { - ShapeTuple hidden_states_shape{1, total_length, hidden_size_}; - hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); + ShapeTuple embedding_shape{1, total_length, hidden_size_}; + embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); } - // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); // args: embeddings, logit_pos, kv_cache, params - ObjectRef ret = ft_.verify_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); - NDArray last_hidden_states; - if (ft_.use_disco) { - Array result = Downcast(ret)->DebugGetFromRemote(0); - last_hidden_states = Downcast(result[0]); - } else { - last_hidden_states = Downcast>(ret)[0]; - } + ObjectRef result = ft_.verify_to_last_hidden_func_(embeddings_dref_or_nd, kv_cache_, params_); + ft_.kv_cache_end_forward_func_(kv_cache_); + ObjectRef hidden_states = ft_.tuple_getitem_func_(result, 0); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } - ft_.kv_cache_end_forward_func_(kv_cache_); - // hidden_states: (1, total_length, v) - ICHECK_EQ(last_hidden_states->ndim, 3); - ICHECK_EQ(last_hidden_states->shape[0], 1); - ICHECK_EQ(last_hidden_states->shape[1], total_length); - return last_hidden_states; + ShapeTuple out_shape{total_length, hidden_size_}; + if (!ft_.use_disco) { + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], 1); + ICHECK_EQ(hidden_states_nd->shape[1], total_length); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + return hidden_states_nd.CreateView(out_shape, hidden_states_nd->dtype); + } else { + return ft_.nd_view_func_(hidden_states, out_shape); + } } /*********************** KV Cache Management ***********************/ @@ -877,8 +701,7 @@ class ModelImpl : public ModelObj { ICHECK_EQ(hidden_states_shape[0], prefill_chunk_size_); this->hidden_size_ = hidden_states_shape[1]; this->hidden_states_dtype_ = hidden_states_nd->dtype; - // TODO(wuwei): We can keep hidden_states on the worker after refactor - return hidden_states_nd; + return hidden_states; } void Reset() final { @@ -897,13 +720,18 @@ class ModelImpl : public ModelObj { ObjectRef GatherHiddenStates(const ObjectRef& input, const std::vector& indices, ObjectRef* dst) final { - NDArray dst_view = Downcast(*dst).CreateView( - {static_cast(indices.size()), hidden_size_}, hidden_states_dtype_); + ObjectRef dst_view{nullptr}; + ShapeTuple out_shape{static_cast(indices.size()), hidden_size_}; + if ((*dst)->IsInstance()) { + dst_view = ft_.nd_view_func_(*dst, out_shape); + } else { + NDArray dst_nd = Downcast(*dst); + dst_view = dst_nd.CreateView(out_shape, hidden_states_dtype_); + } NDArray indices_nd = logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); - ObjectRef indices_device = - ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos", {max_num_sequence_}); ft_.gather_hidden_states_func_(input, indices_device, dst_view); return dst_view; } @@ -913,8 +741,7 @@ class ModelImpl : public ModelObj { NDArray indices_nd = logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); - ObjectRef indices_device = - ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos", {max_num_sequence_}); ft_.scatter_hidden_states_func_(input, indices_device, *dst); } diff --git a/cpp/serve/model.h b/cpp/serve/model.h index d672739581..f587969bfb 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -7,6 +7,7 @@ #ifndef MLC_LLM_SERVE_MODEL_H_ #define MLC_LLM_SERVE_MODEL_H_ +#include #include #include @@ -139,35 +140,6 @@ class ModelObj : public Object { */ virtual NDArray GetLogits(const ObjectRef& last_hidden_states, int batch_size, int seq_len) = 0; - /*! - * \brief Compute logits for last hidden_states in a batch. - * \param last_hidden_states The last hidden_states to compute logits for. - * \param seq_ids The id of the sequence in the KV cache. - * \param lengths The length of each sequence to prefill. - * \return The computed logits. - */ - virtual NDArray BatchGetLogits(const ObjectRef& last_hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) = 0; - - /*! - * \brief Select desired hidden_states for last hidden_states in a batch. - * \param last_hidden_states The last hidden_states to select from. - * \param seq_ids The id of the sequence in the KV cache. - * \param lengths The length of each sequence to prefill. - * \return The last hidden_states for the batch. - */ - virtual NDArray BatchSelectLastHidden(const ObjectRef& last_hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) = 0; - - /*! - * \brief Concat a list of 1D hidden_states to 2D tensor. - * \param hidden_states The hidden_states to concat. - * \param dst The copy destination. - */ - virtual NDArray ConcatLastHidden(std::vector& hidden_states, ObjectRef* dst) = 0; - /*! * \brief Batch prefill function. Embedding in, logits out. * The embedding order of sequences in `embedding_arr` follows @@ -188,9 +160,9 @@ class ModelObj : public Object { * \param lengths The length of each sequence to prefill. * \return The hidden_states for the next token. */ - virtual NDArray BatchPrefillToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) = 0; + virtual ObjectRef BatchPrefillToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; /*! * \brief Batch decode function. Embedding in, logits out. @@ -209,8 +181,8 @@ class ModelObj : public Object { * \param seq_id The id of the sequence in the KV cache. * \return The hidden_states for the next token for each sequence in the batch. */ - virtual NDArray BatchDecodeToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids) = 0; + virtual ObjectRef BatchDecodeToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids) = 0; /*! * \brief Batch verify function. Embedding in, logits out. @@ -236,9 +208,9 @@ class ModelObj : public Object { * That is to say, it does not accept "running a verify step for a subset * of the full batch". */ - virtual NDArray BatchVerifyToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) = 0; + virtual ObjectRef BatchVerifyToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; /*********************** KV Cache Management ***********************/ diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index c6f463eb32..87a9a31d30 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -74,7 +74,6 @@ class GPUSampler : public SamplerObj { sample_indices_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); top_p_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device); top_prob_offsets_device_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device); - draft_probs_device_ = NDArray::Empty({max_num_sample, vocab_size}, dtype_f32_, device); draft_tokens_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); token_tree_first_child_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); token_tree_next_sibling_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); @@ -630,7 +629,6 @@ class GPUSampler : public SamplerObj { NDArray sample_indices_device_; NDArray top_p_device_; NDArray top_prob_offsets_device_; - NDArray draft_probs_device_; NDArray draft_tokens_device_; NDArray token_tree_first_child_device_; NDArray token_tree_next_sibling_device_; diff --git a/python/mlc_llm/interface/compile.py b/python/mlc_llm/interface/compile.py index 4e8bcabd9e..19eadd8206 100644 --- a/python/mlc_llm/interface/compile.py +++ b/python/mlc_llm/interface/compile.py @@ -1,4 +1,5 @@ """Python entrypoint of compilation.""" + import dataclasses import math from io import StringIO @@ -162,7 +163,11 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: logger.info("Running optimizations using TVM Unity") additional_tirs = _apply_preproc_to_params(named_params, model_config) variable_bounds = _get_variable_bounds(model_config) - cuda_graph_symbolic_capture_hints = {"batch_decode": ["batch_size"]} + cuda_graph_symbolic_capture_hints = { + "batch_decode": ["batch_size"], + "batch_decode_to_last_last_hidden_states": ["batch_size"], + "batch_verify_to_last_last_hidden_states": ["batch_size", "seq_len"], + } metadata = { "model_type": args.model.name, "quantization": args.quantization.name, diff --git a/python/mlc_llm/model/eagle/eagle_model.py b/python/mlc_llm/model/eagle/eagle_model.py index 355618df09..9d7820b841 100644 --- a/python/mlc_llm/model/eagle/eagle_model.py +++ b/python/mlc_llm/model/eagle/eagle_model.py @@ -190,8 +190,8 @@ def get_default_spec(self): }, }, "fuse_embed_hidden_states": { - "input_embed": nn.spec.Tensor(["length", self.hidden_size], self.dtype), - "hidden_states": nn.spec.Tensor(["length", self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), + "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index 18238f688e..60c8f138d1 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -248,16 +248,11 @@ def get_logits(self, hidden_states: Tensor): logits = logits.astype("float32") return logits - def batch_get_logits(self, hidden_states: Tensor, logit_positions: Tensor): + def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): op_ext.configure() if self.tensor_parallel_shards > 1: logit_positions = op.ccl_broadcast_from_worker0(logit_positions) hidden_states = op.take(hidden_states, logit_positions, axis=0) - return self.get_logits(hidden_states) - - def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): - op_ext.configure() - hidden_states = op.take(hidden_states, logit_positions, axis=0) return hidden_states def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): @@ -368,14 +363,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "batch_get_logits": { - "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), - "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), - "$": { - "param_mode": "packed", - "effect_mode": "none", - }, - }, "batch_select_last_hidden_states": { "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),