From d861a7d3c67f8e6a24fc97e2bccd389c7f8c19e6 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Tue, 9 Jan 2024 08:57:07 -0500
Subject: [PATCH] [Serving][Refactor] Logit processor and logit bias support

This PR refactors the existing logit processing pipeline into a unified
logit processor class. The logit processor class exposes two functions:
- `InplaceUpdateLogits`, which takes in the raw logits produced by the
  model, and applies logit bias (which is introduced in this PR),
  presence/frequency/repetition penalties, and the token id mask, in
  order, when needed.
- `ComputeProbsFromLogits`, which takes in the updated logits, and
  invokes softmax with temperature to compute the probability
  distribution.

The logit processor runs entirely on GPU. That is, all of the logit
bias / penalty / mask application and the softmax are backed by GPU
kernels. This is a key difference from the logit processing prior to
this PR, where the processing happened on CPU, and softmax also
happened on CPU whenever any logit processing was needed.

With the unified logit processor, we simplified the interface of
handling the model's output logits in engine actions to make it
cleaner. We also simplified the interface of Sampler.

Preliminary results show that LogitProcessor brings a modest
performance improvement when any processing is needed.

---
 cpp/serve/config.cc                           |  22 +
 cpp/serve/config.h                            |   1 +
 cpp/serve/engine.cc                           |  31 +-
 cpp/serve/engine_actions/action.h             |  18 +-
 cpp/serve/engine_actions/batch_decode.cc      |  27 +-
 cpp/serve/engine_actions/batch_draft.cc       |  26 +-
 cpp/serve/engine_actions/batch_verify.cc      |  31 +-
 .../engine_actions/new_request_prefill.cc     |  41 +-
 cpp/serve/function_table.cc                   |  13 +-
 cpp/serve/function_table.h                    |   3 +
 cpp/serve/logit_processor.cc                  | 404 ++++++++++++++
 cpp/serve/logit_processor.h                   |  94 ++++
 cpp/serve/model.cc                            |  41 +-
 cpp/serve/model.h                             |  15 +-
 cpp/serve/request_state.cc                    |  14 +-
 cpp/serve/request_state.h                     |   5 +
 cpp/serve/sampler.cc                          | 512 ++++--------
 cpp/serve/sampler.h                           |  22 +-
 .../compiler_pass/attach_to_ir_module.py      | 113 ++++
 python/mlc_chat/compiler_pass/pipeline.py     |   2 +
 .../mlc_chat/protocol/openai_api_protocol.py  |  46 +-
 python/mlc_chat/serve/config.py               |   6 +-
 tests/python/serve/server/test_server.py      |  71 ++-
 23 files changed, 1008 insertions(+), 550 deletions(-)
 create mode 100644 cpp/serve/logit_processor.cc
 create mode 100644 cpp/serve/logit_processor.h

diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc
index 3c4d77d6a6..804ff9fe93 100644
--- a/cpp/serve/config.cc
+++ b/cpp/serve/config.cc
@@ -52,6 +52,22 @@ GenerationConfig::GenerationConfig(String config_json_str) {
     n->repetition_penalty = config["repetition_penalty"].get();
     CHECK(n->repetition_penalty > 0) << "Repetition penalty must be a positive number!";
   }
+  if (config.count("logit_bias")) {
+    CHECK(config["logit_bias"].is() || config["logit_bias"].is());
+    if (config["logit_bias"].is()) {
+      picojson::object logit_bias_json = config["logit_bias"].get();
+      std::vector> logit_bias;
+      logit_bias.reserve(logit_bias_json.size());
+      for (auto [token_id_str, bias] : logit_bias_json) {
+        CHECK(bias.is());
+        double bias_value = bias.get();
+        CHECK_LE(std::fabs(bias_value), 100.0)
+            << "Logit bias value should be in range [-100, 100].";
+        logit_bias.emplace_back(std::stoi(token_id_str), bias_value);
+      }
+      n->logit_bias = std::move(logit_bias);
+    }
+  }
   if (config.count("max_tokens")) {
     if (config["max_tokens"].is()) {
       n->max_tokens = config["max_tokens"].get();
@@ -115,6 +131,12 @@ String GenerationConfigNode::AsJSONString() const {
config["max_tokens"] = picojson::value(static_cast(this->max_tokens)); config["seed"] = picojson::value(static_cast(this->seed)); + picojson::object logit_bias_obj; + for (auto [token_id, bias] : logit_bias) { + logit_bias_obj[std::to_string(token_id)] = picojson::value(static_cast(bias)); + } + config["logit_bias"] = picojson::value(logit_bias_obj); + picojson::array stop_strs_arr; for (String stop_str : this->stop_strs) { stop_strs_arr.push_back(picojson::value(stop_str)); diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 34bbfc9880..c9ebf0c847 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -25,6 +25,7 @@ class GenerationConfigNode : public Object { double frequency_penalty = 0.0; double presence_penalty = 0.0; double repetition_penalty = 1.0; + std::vector> logit_bias; int seed; bool ignore_eos = false; diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 08376712be..28b1e70006 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -19,6 +19,7 @@ #include "engine_actions/action_commons.h" #include "engine_state.h" #include "event_trace_recorder.h" +#include "logit_processor.h" #include "model.h" #include "request.h" #include "request_state.h" @@ -53,10 +54,10 @@ class EngineImpl : public Engine { this->engine_mode_ = EngineMode(engine_mode_json_str); this->request_stream_callback_ = std::move(request_stream_callback); this->trace_recorder_ = trace_recorder; - this->sampler_ = Sampler::Create(/*sampler_kind=*/"cpu", trace_recorder_); this->tokenizer_ = Tokenizer::FromPath(tokenizer_path); this->token_table_ = tokenizer_->TokenTable(); // Step 2. Initialize each model independently. + // Create the logit processor and sampler. this->models_.clear(); for (const auto& model_info : model_infos) { TVMArgValue model_lib = std::get<0>(model_info); @@ -71,26 +72,35 @@ class EngineImpl : public Engine { << this->max_single_sequence_length_; this->models_.push_back(model); } + int max_logit_processor_num_token = kv_cache_config_->max_num_sequence; + if (engine_mode_->enable_speculative) { + max_logit_processor_num_token *= engine_mode_->spec_draft_length; + } + LogitProcessor logit_processor = + this->models_[0]->CreateLogitProcessor(max_logit_processor_num_token, trace_recorder); + Sampler sampler = Sampler::Create(/*sampler_kind=*/"cpu", trace_recorder_); // Step 3. Initialize engine actions that represent state transitions. if (this->engine_mode_->enable_speculative) { // Speculative decoding is only possible for more than one model. 
ICHECK_GT(this->models_.size(), 1U); this->actions_ = { EngineAction::NewRequestPrefill(this->models_, // - this->sampler_, // + logit_processor, // + sampler, // this->kv_cache_config_, // this->trace_recorder_), - EngineAction::BatchDraft(this->models_, this->sampler_, this->trace_recorder_, + EngineAction::BatchDraft(this->models_, logit_processor, sampler, this->trace_recorder_, this->engine_mode_->spec_draft_length), - EngineAction::BatchVerify(this->models_, this->sampler_, this->kv_cache_config_, + EngineAction::BatchVerify(this->models_, logit_processor, sampler, this->kv_cache_config_, this->trace_recorder_)}; } else { - this->actions_ = { - EngineAction::NewRequestPrefill(this->models_, // - this->sampler_, // - this->kv_cache_config_, // - this->trace_recorder_), - EngineAction::BatchDecode(this->models_, this->sampler_, this->trace_recorder_)}; + this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->kv_cache_config_, // + this->trace_recorder_), + EngineAction::BatchDecode(this->models_, logit_processor, sampler, + this->trace_recorder_)}; } // Step 4. Automatically set the threading backend max concurrency. SetThreadMaxConcurrency(); @@ -196,7 +206,6 @@ class EngineImpl : public Engine { KVCacheConfig kv_cache_config_; EngineMode engine_mode_; int max_single_sequence_length_; - Sampler sampler_; Tokenizer tokenizer_; std::vector token_table_; // Models diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index cd2ef33f99..8e305e26af 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -53,13 +53,14 @@ class EngineAction : public ObjectRef { * \brief Create the action that prefills requests in the `waiting_queue` * of the engine state. * \param models The models to run prefill in. + * \param logit_processor The logit processor. * \param sampler The sampler to sample new tokens. * \param kv_cache_config The KV cache config to help decide prefill is doable. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ - static EngineAction NewRequestPrefill(Array models, Sampler sampler, - KVCacheConfig kv_cache_config, + static EngineAction NewRequestPrefill(Array models, LogitProcessor logit_processor, + Sampler sampler, KVCacheConfig kv_cache_config, Optional trace_recorder); /*! * \brief Create the action that runs one-step decode for requests in the @@ -74,8 +75,8 @@ class EngineAction : public ObjectRef { * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ - static EngineAction BatchDecode(Array models, Sampler sampler, - Optional trace_recorder); + static EngineAction BatchDecode(Array models, LogitProcessor logit_processor, + Sampler sampler, Optional trace_recorder); /*! * \brief Create the action that runs one-step speculative draft proposal for @@ -88,8 +89,9 @@ class EngineAction : public ObjectRef { * \param draft_length The number of draft proposal rounds. * \return The created action object. */ - static EngineAction BatchDraft(Array models, Sampler sampler, - Optional trace_recorder, int draft_length = 4); + static EngineAction BatchDraft(Array models, LogitProcessor logit_processor, + Sampler sampler, Optional trace_recorder, + int draft_length = 4); /*! 
* \brief Create the action that runs one-step speculative verification for requests in the @@ -102,8 +104,8 @@ class EngineAction : public ObjectRef { * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ - static EngineAction BatchVerify(Array models, Sampler sampler, - KVCacheConfig kv_cache_config, + static EngineAction BatchVerify(Array models, LogitProcessor logit_processor, + Sampler sampler, KVCacheConfig kv_cache_config, Optional trace_recorder); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineAction, ObjectRef, EngineActionObj); diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index 410e94d286..627e46bc9a 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -24,9 +24,10 @@ namespace serve { */ class BatchDecodeActionObj : public EngineActionObj { public: - explicit BatchDecodeActionObj(Array models, Sampler sampler, - Optional trace_recorder) + explicit BatchDecodeActionObj(Array models, LogitProcessor logit_processor, + Sampler sampler, Optional trace_recorder) : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), trace_recorder_(std::move(trace_recorder)) {} @@ -92,11 +93,17 @@ class BatchDecodeActionObj : public EngineActionObj { ICHECK_EQ(logits->shape[0], embeddings->shape[0]); ICHECK_EQ(logits->shape[1], 1); + // - Update logits. + logits = logits.CreateView({num_requests, logits->shape[2]}, logits->dtype); + logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); + + // - Compute probability distributions. + NDArray probs_device = + logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids); + // - Sample tokens. - RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); std::vector next_tokens = - sampler_->BatchSampleTokens(logits, models_[0], mstates, generation_cfg, rngs); - RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); + sampler_->BatchSampleTokens(probs_device, request_ids, generation_cfg, rngs); ICHECK_EQ(next_tokens.size(), num_requests); // - Update the committed tokens of states. @@ -122,16 +129,20 @@ class BatchDecodeActionObj : public EngineActionObj { * models, the `Step` function of the created action will not take effect. */ Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; /*! \brief Event trace recorder. 
*/ Optional trace_recorder_; }; -EngineAction EngineAction::BatchDecode(Array models, Sampler sampler, +EngineAction EngineAction::BatchDecode(Array models, LogitProcessor logit_processor, + Sampler sampler, Optional trace_recorder) { - return EngineAction(make_object(std::move(models), std::move(sampler), - std::move(trace_recorder))); + return EngineAction( + make_object(std::move(models), std::move(logit_processor), + std::move(sampler), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index 3f5622cc6d..403350c4af 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -20,9 +20,10 @@ namespace serve { */ class BatchDraftActionObj : public EngineActionObj { public: - explicit BatchDraftActionObj(Array models, Sampler sampler, + explicit BatchDraftActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, Optional trace_recorder, int draft_length) : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), trace_recorder_(std::move(trace_recorder)), draft_length_(draft_length) { @@ -102,13 +103,19 @@ class BatchDraftActionObj : public EngineActionObj { ICHECK_EQ(logits->shape[0], embeddings->shape[0]); ICHECK_EQ(logits->shape[1], 1); + // - Update logits. + logits = logits.CreateView({num_requests, logits->shape[2]}, logits->dtype); + logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); + + // - Compute probability distributions. + NDArray probs_device = + logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids); + // - Sample tokens. - RECORD_EVENT(trace_recorder_, request_ids, "start proposal sampling"); std::vector prob_dist; std::vector token_probs; std::vector next_tokens = sampler_->BatchSampleTokens( - logits, models_[model_id], mstates, generation_cfg, rngs, &prob_dist, &token_probs); - RECORD_EVENT(trace_recorder_, request_ids, "finish proposal sampling"); + probs_device, request_ids, generation_cfg, rngs, &prob_dist, &token_probs); ICHECK_EQ(next_tokens.size(), num_requests); // - Update the draft tokens, prob dist, token probs of states. @@ -143,6 +150,8 @@ class BatchDraftActionObj : public EngineActionObj { /*! \brief The model to run draft generation in speculative decoding. */ Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; /*! \brief Event trace recorder. 
*/ @@ -151,11 +160,12 @@ class BatchDraftActionObj : public EngineActionObj { int draft_length_; }; -EngineAction EngineAction::BatchDraft(Array models, Sampler sampler, - Optional trace_recorder, +EngineAction EngineAction::BatchDraft(Array models, LogitProcessor logit_processor, + Sampler sampler, Optional trace_recorder, int draft_length) { - return EngineAction(make_object(std::move(models), std::move(sampler), - std::move(trace_recorder), draft_length)); + return EngineAction(make_object( + std::move(models), std::move(logit_processor), std::move(sampler), std::move(trace_recorder), + draft_length)); } } // namespace serve diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index ef33449fd7..e4aa836127 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -26,9 +26,11 @@ namespace serve { */ class BatchVerifyActionObj : public EngineActionObj { public: - explicit BatchVerifyActionObj(Array models, Sampler sampler, KVCacheConfig kv_cache_config, + explicit BatchVerifyActionObj(Array models, LogitProcessor logit_processor, + Sampler sampler, KVCacheConfig kv_cache_config, Optional trace_recorder) : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), kv_cache_config_(std::move(kv_cache_config)), trace_recorder_(std::move(trace_recorder)), @@ -103,13 +105,22 @@ class BatchVerifyActionObj : public EngineActionObj { ICHECK_EQ(logits->shape[0], 1); ICHECK_EQ(logits->shape[1], total_draft_length); + // - Update logits. std::vector cum_verify_lengths = {0}; for (int i = 0; i < num_requests; ++i) { cum_verify_lengths.push_back(cum_verify_lengths.back() + draft_lengths[i]); } + logits = logits.CreateView({total_draft_length, logits->shape[2]}, logits->dtype); + logit_processor_->InplaceUpdateLogits(logits, generation_cfg, verify_request_mstates, + request_ids, &cum_verify_lengths, &draft_output_tokens); + + // - Compute probability distributions. + NDArray probs_device = logit_processor_->ComputeProbsFromLogits( + logits, generation_cfg, request_ids, &cum_verify_lengths); + std::vector> accepted_tokens_arr = sampler_->BatchVerifyDraftTokens( - logits, cum_verify_lengths, models_[verify_model_id_], verify_request_mstates, - generation_cfg, rngs, draft_output_tokens, draft_output_token_prob, draft_output_prob_dist); + probs_device, request_ids, cum_verify_lengths, verify_request_mstates, generation_cfg, rngs, + draft_output_tokens, draft_output_token_prob, draft_output_prob_dist); ICHECK_EQ(accepted_tokens_arr.size(), num_requests); for (int i = 0; i < num_requests; ++i) { @@ -222,6 +233,8 @@ class BatchVerifyActionObj : public EngineActionObj { * models, the `Step` function of the created action will not take effect. */ Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; /*! \brief The kv cache config. */ @@ -233,15 +246,15 @@ class BatchVerifyActionObj : public EngineActionObj { /*! \brief The ids of verify/draft models. 
*/ const int verify_model_id_ = 0; const int draft_model_id_ = 1; - const float eps_ = 1e-9; + const float eps_ = 1e-5; }; -EngineAction EngineAction::BatchVerify(Array models, Sampler sampler, - KVCacheConfig kv_cache_config, +EngineAction EngineAction::BatchVerify(Array models, LogitProcessor logit_processor, + Sampler sampler, KVCacheConfig kv_cache_config, Optional trace_recorder) { - return EngineAction(make_object(std::move(models), std::move(sampler), - std::move(kv_cache_config), - std::move(trace_recorder))); + return EngineAction(make_object( + std::move(models), std::move(logit_processor), std::move(sampler), std::move(kv_cache_config), + std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index bf0d607c92..a3f1b2d17c 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -18,10 +18,11 @@ namespace serve { */ class NewRequestPrefillActionObj : public EngineActionObj { public: - explicit NewRequestPrefillActionObj(Array models, Sampler sampler, - KVCacheConfig kv_cache_config, + explicit NewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, + Sampler sampler, KVCacheConfig kv_cache_config, Optional trace_recorder) : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), kv_cache_config_(std::move(kv_cache_config)), trace_recorder_(std::move(trace_recorder)) {} @@ -87,23 +88,31 @@ class NewRequestPrefillActionObj : public EngineActionObj { } } - // - Sample tokens. + // - Update logits. ICHECK(logits_for_sample.defined()); - logits_for_sample = logits_for_sample.CreateView({num_requests, 1, logits_for_sample->shape[2]}, - logits_for_sample->dtype); + Array generation_cfg; Array mstates_for_sample; std::vector rngs; + generation_cfg.reserve(num_requests); mstates_for_sample.reserve(num_requests); rngs.reserve(num_requests); for (int i = 0; i < num_requests; ++i) { + generation_cfg.push_back(requests[i]->generation_cfg); mstates_for_sample.push_back(rstates[i]->mstates[0]); rngs.push_back(&rstates[i]->rng); } - RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); - std::vector next_tokens = sampler_->BatchSampleTokens( - logits_for_sample, models_[0], mstates_for_sample, - requests.Map([](Request request) { return request->generation_cfg; }), rngs); - RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); + logits_for_sample = logits_for_sample.CreateView({num_requests, logits_for_sample->shape[2]}, + logits_for_sample->dtype); + logit_processor_->InplaceUpdateLogits(logits_for_sample, generation_cfg, mstates_for_sample, + request_ids); + + // - Compute probability distributions. + NDArray probs_device = + logit_processor_->ComputeProbsFromLogits(logits_for_sample, generation_cfg, request_ids); + + // - Sample tokens. + std::vector next_tokens = + sampler_->BatchSampleTokens(probs_device, request_ids, generation_cfg, rngs); ICHECK_EQ(next_tokens.size(), num_requests); // - Update the committed tokens of states. @@ -199,6 +208,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { /*! \brief The models to run prefill in. */ Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; /*! \brief The KV cache config to help decide prefill is doable. 
*/ @@ -207,12 +218,12 @@ class NewRequestPrefillActionObj : public EngineActionObj { Optional trace_recorder_; }; -EngineAction EngineAction::NewRequestPrefill(Array models, Sampler sampler, - KVCacheConfig kv_cache_config, +EngineAction EngineAction::NewRequestPrefill(Array models, LogitProcessor logit_processor, + Sampler sampler, KVCacheConfig kv_cache_config, Optional trace_recorder) { - return EngineAction(make_object(std::move(models), std::move(sampler), - std::move(kv_cache_config), - std::move(trace_recorder))); + return EngineAction(make_object( + std::move(models), std::move(logit_processor), std::move(sampler), std::move(kv_cache_config), + std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 6dce770dc6..c4ebbe4be3 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -100,12 +100,9 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object this->get_global_func = [this](const std::string& name) -> PackedFunc { return SessionFuncAsPackedFunc(sess, sess->GetGlobalFunc(name), name); }; + this->model_metadata_ = + ModelMetadata::FromModule(this->disco_mod->DebugGetFromRemote(0), std::move(model_config)); this->_InitFunctions(); - { - Module mod = this->disco_mod->DebugGetFromRemote(0); - this->softmax_func_ = mod->GetFunction("softmax_with_temperature"); - this->model_metadata_ = ModelMetadata::FromModule(mod, std::move(model_config)); - } } else { Module executable{nullptr}; if (reload_lib.type_code() == kTVMModuleHandle) { @@ -193,7 +190,11 @@ void FunctionTable::_InitFunctions() { this->prefill_func_ = mod_get_func("batch_prefill"); this->decode_func_ = mod_get_func("batch_decode"); this->verify_func_ = mod_get_func("batch_verify"); - this->softmax_func_ = mod_get_func("softmax_with_temperature"); + Module mod = this->use_disco ? this->disco_mod->DebugGetFromRemote(0) : this->local_vm; + this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true); + this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true); + this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true); + this->apply_bitmask_func_ = mod->GetFunction("apply_bitmask_inplace", true); this->create_kv_cache_func_ = mod_get_func("create_flashinfer_paged_kv_cache"); if (!this->create_kv_cache_func_.defined()) { this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index 24c6180707..e37b0e6f89 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -71,6 +71,9 @@ struct FunctionTable { PackedFunc decode_func_; PackedFunc verify_func_; PackedFunc softmax_func_; + PackedFunc apply_logit_bias_func_; + PackedFunc apply_penalty_func_; + PackedFunc apply_bitmask_func_; PackedFunc create_kv_cache_func_; PackedFunc reset_kv_cache_func_; bool support_backtracking_kv_; diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc new file mode 100644 index 0000000000..a45c1f9f13 --- /dev/null +++ b/cpp/serve/logit_processor.cc @@ -0,0 +1,404 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/logit_processor.cc + * \brief The implementation of logit processor. 
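+ *        Logits are updated in place in a fixed order (logit bias, then
+ *        presence/frequency/repetition penalties, then the per-request
+ *        vocabulary bitmask), and the updated logits are then turned into
+ *        probabilities with a temperature-scaled softmax; every step is
+ *        backed by GPU kernels.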
+ */ +#include "logit_processor.h" + +#include +#include +#include + +namespace mlc { +namespace llm { +namespace serve { + +inline void CopyArray(NDArray src, NDArray dst) { + DLTensor dl_dst = *(dst.operator->()); + NDArray::CopyFromTo(src.operator->(), &dl_dst); +} + +/***************** LogitProcessor Implementation *****************/ + +TVM_REGISTER_OBJECT_TYPE(LogitProcessorObj); + +class LogitProcessorImpl : public LogitProcessorObj { + public: + /*! * \brief Constructor of LogitProcessorImpl. */ + explicit LogitProcessorImpl(int max_num_token, int vocab_size, FunctionTable* ft, DLDevice device, + Optional trace_recorder) + : max_num_token_(max_num_token), + vocab_size_(vocab_size), + bitmask_size_((vocab_size + 31) / 32), + softmax_func_(ft->softmax_func_), + device_(device), + apply_logit_bias_func_(ft->apply_logit_bias_func_), + apply_penalty_func_(ft->apply_penalty_func_), + apply_bitmask_func_(ft->apply_bitmask_func_), + trace_recorder_(std::move(trace_recorder)) { + DLDevice device_cpu{DLDeviceType::kDLCPU, /*device_id=*/0}; + // Initialize auxiliary arrays on CPU. + seq_ids_host_ = NDArray::Empty({max_num_token}, dtype_i32_, device_cpu); + pos2seq_id_host_ = NDArray::Empty({max_num_token * vocab_size}, dtype_i32_, device_cpu); + token_ids_host_ = NDArray::Empty({max_num_token * vocab_size}, dtype_i32_, device_cpu); + token_cnt_host_ = NDArray::Empty({max_num_token * vocab_size}, dtype_i32_, device_cpu); + token_logit_bias_host_ = NDArray::Empty({max_num_token * vocab_size}, dtype_f32_, device_cpu); + penalties_host_ = NDArray::Empty({max_num_token, 3}, dtype_f32_, device_cpu); + bitmask_host_ = NDArray::Empty({max_num_token, bitmask_size_}, dtype_i32_, device_cpu); + temperature_host_ = NDArray::Empty({max_num_token}, dtype_f32_, device_cpu); + // Initialize auxiliary arrays on GPU. 
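+    // These device arrays mirror the host arrays above: for every batch the host
+    // buffers are filled on CPU, only the used prefix is viewed and copied to the
+    // device, and the views are handed to the GPU kernels, so no device allocation
+    // happens on the hot path. The flat `max_num_token * vocab_size` capacity is an
+    // upper bound (worst case: every position biases or penalizes every vocabulary
+    // entry), and `bitmask_size_` packs one bit per vocabulary entry into int32
+    // words, i.e. (vocab_size + 31) / 32 words per row.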
+ seq_ids_device_ = NDArray::Empty({max_num_token}, dtype_i32_, device); + pos2seq_id_device_ = NDArray::Empty({max_num_token * vocab_size}, dtype_i32_, device); + token_ids_device_ = NDArray::Empty({max_num_token * vocab_size}, dtype_i32_, device); + token_cnt_device_ = NDArray::Empty({max_num_token * vocab_size}, dtype_i32_, device); + token_logit_bias_device_ = NDArray::Empty({max_num_token * vocab_size}, dtype_f32_, device); + penalties_device_ = NDArray::Empty({max_num_token, 3}, dtype_f32_, device); + bitmask_device_ = NDArray::Empty({max_num_token, bitmask_size_}, dtype_i32_, device); + temperature_device_ = NDArray::Empty({max_num_token}, dtype_f32_, device); + + CHECK(apply_logit_bias_func_.defined()) + << "Function \"apply_logit_bias_inplace\" not found in model"; + CHECK(apply_penalty_func_.defined()) << "Function \"apply_penalty_inplace\" not found in model"; + CHECK(apply_bitmask_func_.defined()) << "Function \"apply_bitmask_inplace\" not found in model"; + } + + void InplaceUpdateLogits(NDArray logits, // + const Array& generation_cfg, // + const Array& mstates, // + const Array& request_ids, // + const std::vector* cum_num_token, // + const std::vector>* draft_tokens) final { + CHECK_EQ(logits->ndim, 2); + CHECK_EQ(logits->shape[1], vocab_size_); + CHECK(logits.DataType() == DataType::Float(32)); + CHECK_EQ(generation_cfg.size(), mstates.size()); + CHECK_LE(logits->shape[0], max_num_token_); + int num_total_token = logits->shape[0]; + int num_sequence = generation_cfg.size(); + + CHECK((cum_num_token == nullptr) == (draft_tokens == nullptr)); + if (cum_num_token != nullptr) { + CHECK_EQ(draft_tokens->size(), num_sequence); + CHECK_EQ(cum_num_token->size(), num_sequence + 1); + CHECK_EQ(cum_num_token->back(), num_total_token); + } else { + CHECK_EQ(num_sequence, num_total_token); + } + + RECORD_EVENT(trace_recorder_, request_ids, "start update logits"); + + // Update 1. logit bias + RECORD_EVENT(trace_recorder_, request_ids, "start apply logit bias"); + UpdateWithLogitBias(logits, generation_cfg, cum_num_token); + RECORD_EVENT(trace_recorder_, request_ids, "finish apply logit bias"); + + // Update 2. penalties + RECORD_EVENT(trace_recorder_, request_ids, "start apply penalty"); + UpdateWithPenalty(logits, generation_cfg, mstates, cum_num_token, draft_tokens); + RECORD_EVENT(trace_recorder_, request_ids, "finish apply penalty"); + + // Update 3. Vocabulary mask. + RECORD_EVENT(trace_recorder_, request_ids, "start apply logit mask"); + UpdateWithMask(logits, mstates, cum_num_token, draft_tokens); + RECORD_EVENT(trace_recorder_, request_ids, "finish apply logit mask"); + + RECORD_EVENT(trace_recorder_, request_ids, "finish update logits"); + } + + NDArray ComputeProbsFromLogits(NDArray logits, const Array& generation_cfg, + const Array& request_ids, + const std::vector* cum_num_token) final { + // logits: (n, v) + CHECK_EQ(logits->ndim, 2); + CHECK_LE(logits->shape[0], max_num_token_); + CHECK_EQ(logits->shape[1], vocab_size_); + CHECK(logits.DataType() == DataType::Float(32)); + int num_total_token = logits->shape[0]; + int num_sequence = generation_cfg.size(); + + if (cum_num_token != nullptr) { + CHECK_EQ(cum_num_token->size(), num_sequence + 1); + CHECK_EQ(cum_num_token->back(), num_total_token); + } else { + CHECK_EQ(num_sequence, num_total_token); + } + + RECORD_EVENT(trace_recorder_, request_ids, "start softmax"); + + // Construct: + // - temperature (max_num_token,) float32 + float* p_temperature = static_cast(temperature_host_->data); + + // - Set arrays. 
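+    // One temperature per token position, taken from the owning sequence and clamped
+    // below by eps_ so the softmax kernel never divides by zero; conceptually the
+    // kernel then computes
+    //   prob[i][v] = exp(logits[i][v] / T[i]) / sum_u exp(logits[i][u] / T[i]).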
+ for (int i = 0; i < num_sequence; ++i) { + int num_token_to_process = + cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i)); + int token_offset = cum_num_token == nullptr ? i : cum_num_token->at(i); + for (int j = 0; j < num_token_to_process; ++j) { + p_temperature[token_offset + j] = std::max(generation_cfg[i]->temperature, eps_); + } + } + + // - View arrays. + NDArray temperature_host = temperature_host_.CreateView({num_total_token}, dtype_f32_); + NDArray temperature_device = temperature_device_.CreateView({num_total_token}, dtype_f32_); + + // - Copy arrays to GPU. + CopyArray(/*src=*/temperature_host, /*dst=*/temperature_device); + + // - Call kernel. + NDArray probs = softmax_func_(logits.CreateView({num_total_token, 1, vocab_size_}, dtype_f32_), + temperature_device); + ICHECK_EQ(probs->ndim, 3); + ICHECK_EQ(probs->shape[0], num_total_token); + ICHECK_EQ(probs->shape[1], 1); + ICHECK_EQ(probs->shape[2], vocab_size_); + if (trace_recorder_.defined()) { + TVMSynchronize(device_.device_type, device_.device_id, /*stream=*/nullptr); + } + RECORD_EVENT(trace_recorder_, request_ids, "finish softmax"); + return probs.CreateView({num_total_token, vocab_size_}, probs->dtype); + } + + private: + void UpdateWithLogitBias(NDArray logits, const Array& generation_cfg, + const std::vector* cum_num_token) { + // Construct: + // - pos2seq_id (max_num_token * vocab_size,) int32 + // - token_ids (max_num_token * vocab_size,) int32 + // - token_logit_bias (max_num_token * vocab_size,) float32 + int* p_pos2seq_id = static_cast(pos2seq_id_host_->data); + int* p_token_ids = static_cast(token_ids_host_->data); + float* p_token_logit_bias = static_cast(token_logit_bias_host_->data); + + // - Set arrays. + int num_token_for_bias = 0; + int num_bias_token = 0; + for (int i = 0; i < static_cast(generation_cfg.size()); ++i) { + int num_token_to_process = + cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i)); + int token_offset = cum_num_token == nullptr ? i : cum_num_token->at(i); + for (int j = 0; j < num_token_to_process; ++j) { + if (!generation_cfg[i]->logit_bias.empty()) { + for (auto [token_id, bias] : generation_cfg[i]->logit_bias) { + p_pos2seq_id[num_bias_token] = token_offset + j; + p_token_ids[num_bias_token] = token_id; + p_token_logit_bias[num_bias_token] = bias; + ++num_bias_token; + } + ++num_token_for_bias; + } + } + } + + if (num_token_for_bias == 0) { + return; + } + + // - View arrays. + int num_token = num_bias_token; + NDArray pos2seq_id_host = pos2seq_id_host_.CreateView({num_token}, dtype_i32_); + NDArray pos2seq_id_device = pos2seq_id_device_.CreateView({num_token}, dtype_i32_); + NDArray token_ids_host = token_ids_host_.CreateView({num_token}, dtype_i32_); + NDArray token_ids_device = token_ids_device_.CreateView({num_token}, dtype_i32_); + NDArray token_logit_bias_host = token_logit_bias_host_.CreateView({num_token}, dtype_f32_); + NDArray token_logit_bias_device = token_logit_bias_device_.CreateView({num_token}, dtype_f32_); + + // - Copy arrays to GPU. + CopyArray(/*src=*/pos2seq_id_host, /*dst=*/pos2seq_id_device); + CopyArray(/*src=*/token_ids_host, /*dst=*/token_ids_device); + CopyArray(/*src=*/token_logit_bias_host, /*dst=*/token_logit_bias_device); + + // - Call kernel. 
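+    // Given how the arrays were packed above, the kernel is expected to perform one
+    // scatter-add per prepared entry i:
+    //   logits[pos2seq_id[i]][token_ids[i]] += token_logit_bias[i]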
+ apply_logit_bias_func_(logits, pos2seq_id_device, token_ids_device, token_logit_bias_device); + if (trace_recorder_.defined()) { + TVMSynchronize(device_.device_type, device_.device_id, /*stream=*/nullptr); + } + } + + void UpdateWithPenalty(NDArray logits, const Array& generation_cfg, + const Array& mstates, + const std::vector* cum_num_token, + const std::vector>* draft_tokens) { + // Construct: + // - seq_ids (max_num_token,) int32 + // - pos2seq_id (max_num_token * vocab_size,) int32 + // - token_ids (max_num_token * vocab_size,) int32 + // - token_cnt (max_num_token * vocab_size,) int32 + // - penalties (max_num_token, 3) float32 + int* p_seq_ids = static_cast(seq_ids_host_->data); + int* p_pos2seq_id = static_cast(pos2seq_id_host_->data); + int* p_token_ids = static_cast(token_ids_host_->data); + int* p_token_cnt = static_cast(token_cnt_host_->data); + float* p_penalties = static_cast(penalties_host_->data); + + // - Set arrays. + int num_token_for_penalty = 0; + int num_penalty_appeared_token = 0; + for (int i = 0; i < static_cast(generation_cfg.size()); ++i) { + if (generation_cfg[i]->frequency_penalty != 0.0 || + generation_cfg[i]->presence_penalty != 0.0 || + generation_cfg[i]->repetition_penalty != 1.0) { + int num_token_to_process = + cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i)); + int token_offset = cum_num_token == nullptr ? i : cum_num_token->at(i); + CHECK(num_token_to_process == 1 || mstates[i]->draft_output_tokens.empty()); + for (int j = 0; j < num_token_to_process; ++j) { + p_seq_ids[num_token_for_penalty] = token_offset + j; + for (auto [token_id, cnt] : mstates[i]->appeared_token_ids) { + p_pos2seq_id[num_penalty_appeared_token] = num_token_for_penalty; + p_token_ids[num_penalty_appeared_token] = token_id; + p_token_cnt[num_penalty_appeared_token] = cnt; + ++num_penalty_appeared_token; + } + p_penalties[num_token_for_penalty * 3] = generation_cfg[i]->presence_penalty; + p_penalties[num_token_for_penalty * 3 + 1] = generation_cfg[i]->frequency_penalty; + p_penalties[num_token_for_penalty * 3 + 2] = generation_cfg[i]->repetition_penalty; + ++num_token_for_penalty; + if (j > 0) { + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1]); + } + } + if (num_token_to_process != 1) { + // Roll back. + mstates[i]->RemoveAllDraftTokens(); + } + } + } + + if (num_token_for_penalty == 0) { + return; + } + + // - View arrays. + int num_seq = num_token_for_penalty; + int num_token = num_penalty_appeared_token; + NDArray seq_ids_host = seq_ids_host_.CreateView({num_seq}, dtype_i32_); + NDArray seq_ids_device = seq_ids_device_.CreateView({num_seq}, dtype_i32_); + NDArray pos2seq_id_host = pos2seq_id_host_.CreateView({num_token}, dtype_i32_); + NDArray pos2seq_id_device = pos2seq_id_device_.CreateView({num_token}, dtype_i32_); + NDArray token_ids_host = token_ids_host_.CreateView({num_token}, dtype_i32_); + NDArray token_ids_device = token_ids_device_.CreateView({num_token}, dtype_i32_); + NDArray token_cnt_host = token_cnt_host_.CreateView({num_token}, dtype_i32_); + NDArray token_cnt_device = token_cnt_device_.CreateView({num_token}, dtype_i32_); + NDArray penalties_host = penalties_host_.CreateView({num_seq, 3}, dtype_f32_); + NDArray penalties_device = penalties_device_.CreateView({num_seq, 3}, dtype_f32_); + + // - Copy arrays to GPU. 
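+    // The "apply_penalty_inplace" kernel invoked after the copies is expected to
+    // mirror the CPU-side helpers this PR removes from sampler.cc: for every appeared
+    // token t of a selected row r it subtracts cnt[t] * frequency_penalty +
+    // presence_penalty and applies the sign-dependent repetition scaling (positive
+    // logits are divided by repetition_penalty, non-positive ones multiplied by it).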
+ CopyArray(/*src=*/seq_ids_host, /*dst=*/seq_ids_device); + CopyArray(/*src=*/pos2seq_id_host, /*dst=*/pos2seq_id_device); + CopyArray(/*src=*/token_ids_host, /*dst=*/token_ids_device); + CopyArray(/*src=*/token_cnt_host, /*dst=*/token_cnt_device); + CopyArray(/*src=*/penalties_host, /*dst=*/penalties_device); + + // - Call kernel. + apply_penalty_func_(logits, seq_ids_device, pos2seq_id_device, token_ids_device, + token_cnt_device, penalties_device); + if (trace_recorder_.defined()) { + TVMSynchronize(device_.device_type, device_.device_id, /*stream=*/nullptr); + } + } + + void UpdateWithMask(NDArray logits, const Array& mstates, + const std::vector* cum_num_token, + const std::vector>* draft_tokens) { + // Construct: + // - seq_ids (max_num_token,) int32 + // - bitmask (max_num_token, ceildiv(vocab_size, 32)), int32 + int* p_seq_ids = static_cast(seq_ids_host_->data); + int* p_bitmask = static_cast(bitmask_host_->data); + + // - Set arrays. + int num_token_for_mask = 0; + for (int i = 0; i < static_cast(mstates.size()); ++i) { + int num_token_to_process = + cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i)); + int token_offset = cum_num_token == nullptr ? i : cum_num_token->at(i); + CHECK(num_token_to_process == 1 || mstates[i]->draft_output_tokens.empty()); + for (int j = 0; j < num_token_to_process; ++j) { + std::vector bitmask = mstates[i]->GetTokenBitmask(vocab_size_); + if (!bitmask.empty()) { + p_seq_ids[num_token_for_mask] = token_offset + j; + ICHECK_EQ(bitmask.size(), bitmask_size_); + for (int p = 0; p < bitmask_size_; ++p) { + p_bitmask[num_token_for_mask * bitmask_size_ + p] = bitmask[p]; + } + ++num_token_for_mask; + } + if (j > 0) { + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1]); + } + } + if (num_token_to_process != 1) { + // Roll back. + mstates[i]->RemoveAllDraftTokens(); + } + } + + if (num_token_for_mask == 0) { + return; + } + + // - View arrays. + int num_seq = num_token_for_mask; + NDArray seq_ids_host = seq_ids_host_.CreateView({num_seq}, dtype_i32_); + NDArray seq_ids_device = seq_ids_device_.CreateView({num_seq}, dtype_i32_); + NDArray bitmask_host = bitmask_host_.CreateView({num_seq, bitmask_size_}, dtype_i32_); + NDArray bitmask_device = bitmask_device_.CreateView({num_seq, bitmask_size_}, dtype_i32_); + + // - Copy arrays to GPU. + CopyArray(/*src=*/seq_ids_host, /*dst=*/seq_ids_device); + CopyArray(/*src=*/bitmask_host, /*dst=*/bitmask_device); + + // - Call kernel. + apply_bitmask_func_(logits, seq_ids_device, bitmask_device); + if (trace_recorder_.defined()) { + TVMSynchronize(device_.device_type, device_.device_id, /*stream=*/nullptr); + } + } + + // Model configurations + const int max_num_token_; + const int vocab_size_; + const int bitmask_size_; + const DLDataType dtype_i32_ = DataType::Int(32); + const DLDataType dtype_f32_ = DataType::Float(32); + // Packed functions. 
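+  // (Resolved by FunctionTable::_InitFunctions from the compiled model library:
+  //  "softmax_with_temperature", "apply_logit_bias_inplace", "apply_penalty_inplace"
+  //  and "apply_bitmask_inplace"; the constructor checks that the last three exist.)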
+ Device device_; + PackedFunc softmax_func_; + PackedFunc apply_logit_bias_func_; + PackedFunc apply_penalty_func_; + PackedFunc apply_bitmask_func_; + // Auxiliary NDArrays on CPU + NDArray seq_ids_host_; + NDArray pos2seq_id_host_; + NDArray token_ids_host_; + NDArray token_cnt_host_; + NDArray token_logit_bias_host_; + NDArray penalties_host_; + NDArray bitmask_host_; + NDArray temperature_host_; + // Auxiliary NDArrays on GPU + NDArray seq_ids_device_; + NDArray pos2seq_id_device_; + NDArray token_ids_device_; + NDArray token_cnt_device_; + NDArray token_logit_bias_device_; + NDArray penalties_device_; + NDArray bitmask_device_; + NDArray temperature_device_; + // Event trace recorder. + Optional trace_recorder_; + // A small epsilon. + const double eps_ = 1e-5; +}; + +LogitProcessor::LogitProcessor(int max_num_token, int vocab_size, FunctionTable* ft, + DLDevice device, Optional trace_recorder) { + data_ = make_object(max_num_token, vocab_size, ft, device, + std::move(trace_recorder)); +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/logit_processor.h b/cpp/serve/logit_processor.h new file mode 100644 index 0000000000..2425542731 --- /dev/null +++ b/cpp/serve/logit_processor.h @@ -0,0 +1,94 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/logit_processor.h + * \brief The header for logit processor. + */ + +#ifndef MLC_LLM_SERVE_LOGIT_PROCESSOR_H_ +#define MLC_LLM_SERVE_LOGIT_PROCESSOR_H_ + +#include +#include + +#include "../base.h" +#include "config.h" +#include "event_trace_recorder.h" +#include "function_table.h" +#include "request_state.h" + +namespace mlc { +namespace llm { +namespace serve { + +using tvm::Device; +using namespace tvm::runtime; + +/*! + * \brief The logit processor class that updates logits with regard + * presence/frequency penalties, logit bias, etc.. + */ +class LogitProcessorObj : public Object { + public: + /*! + * \brief In-place update a batch of logits with regard to the given + * generation config and request states. + * \param logits The batch of raw logits, in shape (num_total_token, vocab_size), + * where `num_total_token` may be larger than the number of sequences + * indicated by `generation_cfg`, in which case some sequences may have + * more than one token. + * \param generation_cfg The generation config of each sequence in the batch. + * \param mstates The request states of each sequence in the batch. + * \param request_ids The ids of each request. + * \param cum_num_token The pointer to the cumulative token length of the sequences. + * If the pointer is nullptr, it means each sequence has only one token. + * \param draft_tokens The pointer to the draft tokens of each sequence + * when speculation is enabled, in which case some sequences may have + * more than one token. + */ + virtual void InplaceUpdateLogits(NDArray logits, const Array& generation_cfg, + const Array& mstates, + const Array& request_ids, + const std::vector* cum_num_token = nullptr, + const std::vector>* draft_tokens = nullptr) = 0; + + /*! + * \brief Compute probability distributions for the input batch of logits. + * \param logits The batch of updated logits. + * \param generation_cfg The generation config of each sequence in the batch. + * \param request_ids The ids of each request. + * \param cum_num_token The pointer to the cumulative token length of the sequences. + * If the pointer is nullptr, it means each sequence has only one token. + * \return The batch of computed probability distributions on GPU. 
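+ *
+ * For reference, a typical engine action (e.g. BatchDecode in batch_decode.cc) is
+ * expected to drive the processor as follows:
+ *
+ * \code
+ *   logit_processor->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids);
+ *   NDArray probs = logit_processor->ComputeProbsFromLogits(logits, generation_cfg, request_ids);
+ *   std::vector<int32_t> tokens = sampler->BatchSampleTokens(probs, request_ids, generation_cfg, rngs);
+ * \endcode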
+ */ + virtual NDArray ComputeProbsFromLogits(NDArray logits, + const Array& generation_cfg, + const Array& request_ids, + const std::vector* cum_num_token = nullptr) = 0; + + static constexpr const char* _type_key = "mlc.serve.LogitProcessor"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(LogitProcessorObj, Object); +}; + +class LogitProcessor : public ObjectRef { + public: + /*! + * \brief Constructor. + * \param max_num_token The max number of tokens in the token processor. + * \param vocab_size The model's vocabulary size. + * \param ft The packed function table. + * \param device The device that the model runs on. + * \param trace_recorder The event trace recorder. + */ + explicit LogitProcessor(int max_num_token, int vocab_size, FunctionTable* ft, DLDevice device, + Optional trace_recorder); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogitProcessor, ObjectRef, LogitProcessorObj); +}; + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_LOGIT_PROCESSOR_H_ diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 48ff463667..ecaa5276d8 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -11,6 +11,8 @@ #include +#include "logit_processor.h" + namespace mlc { namespace llm { namespace serve { @@ -350,34 +352,14 @@ class ModelImpl : public ModelObj { return logits; } - NDArray SoftmaxWithTemperature(NDArray logits, Array generation_cfg) final { - // logits: (b, n, v) - CHECK_EQ(logits->ndim, 3); - CHECK_EQ(logits->shape[0], generation_cfg.size()); - CHECK_EQ(logits->device.device_type, device_.device_type); - CHECK_EQ(logits->device.device_id, device_.device_id); - - int batch_size = logits->shape[0]; - std::vector temperatures; - temperatures.reserve(batch_size); - for (GenerationConfig cfg : generation_cfg) { - temperatures.push_back(cfg->temperature); - } - NDArray temperatures_nd = - CopyArrayToDevice(temperatures, &temperature_arr_, logits->dtype, 32, device_); - ICHECK_EQ(temperatures_nd->ndim, 1); - ICHECK_EQ(temperatures_nd->shape[0], batch_size); - - NDArray probs = ft_.softmax_func_(logits, temperatures_nd); - ICHECK_EQ(probs->ndim, 3); - ICHECK_EQ(probs->shape[0], logits->shape[0]); - ICHECK_EQ(probs->shape[1], logits->shape[1]); - ICHECK_EQ(probs->shape[2], logits->shape[2]); - return probs; - } - /*********************** KV Cache Management ***********************/ + LogitProcessor CreateLogitProcessor(int max_num_token, + Optional trace_recorder) { + return LogitProcessor(max_num_token, vocab_size_, &this->ft_, device_, + std::move(trace_recorder)); + } + void CreateKVCache(KVCacheConfig kv_cache_config) final { IntTuple max_num_sequence{kv_cache_config->max_num_sequence}; IntTuple max_total_sequence_length{kv_cache_config->max_total_sequence_length}; @@ -451,6 +433,12 @@ class ModelImpl : public ModelObj { } else { LOG(FATAL) << "Key \"tensor_parallel_shards\" not found."; } + if (config.count("vocab_size")) { + CHECK(config["vocab_size"].is()); + this->vocab_size_ = config["vocab_size"].get(); + } else { + LOG(FATAL) << "Key \"vocab_size\" not found."; + } return config; } @@ -460,6 +448,7 @@ class ModelImpl : public ModelObj { int max_window_size_ = -1; int num_shards_ = -1; int max_num_sequence_ = -1; + int vocab_size_ = -1; //---------------------------- // TVM related states //---------------------------- diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 72a869198e..b561b7895e 100644 --- 
a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -12,7 +12,9 @@ #include "../base.h" #include "config.h" +#include "event_trace_recorder.h" #include "function_table.h" +#include "logit_processor.h" namespace mlc { namespace llm { @@ -92,15 +94,6 @@ class ModelObj : public Object { virtual NDArray BatchVerify(const NDArray& embeddings, const std::vector& seq_ids, const std::vector& lengths) = 0; - /*! - * \brief Computing probabilities from logits with softmax and temperatures. - * \param logits The logits to compute from. - * \param generation_cfg The generation config which contains the temperatures. - * \return The computed probabilities distribution. - */ - virtual NDArray SoftmaxWithTemperature(NDArray logits, - Array generation_cfg) = 0; - /*********************** KV Cache Management ***********************/ /*! @@ -123,6 +116,10 @@ class ModelObj : public Object { /*********************** Utilities ***********************/ + /*! \brief Create a logit processor from this model. */ + virtual LogitProcessor CreateLogitProcessor(int max_num_token, + Optional trace_recorder) = 0; + /*! * \brief Estimate number of CPU units required to drive the model * executing during TP. diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index a4b5297337..b721d32ac6 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -31,6 +31,11 @@ int RequestModelStateNode::GetInputLength() const { return total_length; } +std::vector RequestModelStateNode::GetTokenBitmask(int vocab_size) const { + // TODO(mlc-team): implement this function. + return std::vector(); +} + void RequestModelStateNode::CommitToken(int32_t token_id) { committed_tokens.push_back(token_id); appeared_token_ids[token_id] += 1; @@ -43,14 +48,17 @@ void RequestModelStateNode::AddDraftToken(int32_t token_id) { void RequestModelStateNode::RemoveLastDraftToken() { ICHECK(!draft_output_tokens.empty()); - appeared_token_ids[draft_output_tokens.back()] -= 1; + auto it = appeared_token_ids.find(draft_output_tokens.back()); draft_output_tokens.pop_back(); + CHECK(it != appeared_token_ids.end()); + if (--it->second == 0) { + appeared_token_ids.erase(it); + } } void RequestModelStateNode::RemoveAllDraftTokens() { while (!draft_output_tokens.empty()) { - appeared_token_ids[draft_output_tokens.back()] -= 1; - draft_output_tokens.pop_back(); + RemoveLastDraftToken(); } } diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 82835d01df..ea0b688810 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -81,6 +81,11 @@ class RequestModelStateNode : public Object { /*! \brief Return the total length of the input data. */ int GetInputLength() const; + /*! + * \brief Return the token bitmask induced by the current state. + * The returned vector should have size "ceildiv(vocab_size, 32)". + */ + std::vector GetTokenBitmask(int vocab_size) const; /*! \brief Commit a new token into committed_tokens. Update appeared_token_ids. */ void CommitToken(int32_t token_id); /*! \brief Add a draft token into draft_output_tokens. Update appeared_token_ids. */ diff --git a/cpp/serve/sampler.cc b/cpp/serve/sampler.cc index 8ddfca527a..502bde72e6 100644 --- a/cpp/serve/sampler.cc +++ b/cpp/serve/sampler.cc @@ -18,128 +18,6 @@ namespace mlc { namespace llm { namespace serve { -/***** Utility function for in-place logits/prob update on CPU *****/ - -/*! - * \brief In-place apply repetition penalty to logits based on history tokens. - * \param logits The logits (a batch) to be in-place mutated. 
- * \param token_offset The offset of the token in the batch - * whose logits will be updated. - * \param state The request state that contains history tokens. - * \param repetition_penalty The value of repetition penalty. - */ -void ApplyRepetitionPenaltyOnCPU(NDArray logits, int token_offset, RequestModelState state, - double repetition_penalty) { - // logits: (n, v) - CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, DLDeviceType::kDLCPU); - int vocab_size = logits->shape[1]; - - float* logits_raw_data = static_cast(logits->data) + (token_offset * vocab_size); - for (const auto& it : state->appeared_token_ids) { - int token_id = it.first; - ICHECK_GE(token_id, 0); - ICHECK_LT(token_id, vocab_size); - if (logits_raw_data[token_id] <= 0) { - logits_raw_data[token_id] *= repetition_penalty; - } else { - logits_raw_data[token_id] /= repetition_penalty; - } - } -} - -/*! - * \brief In-place apply frequency and presence penalty to logits based on history tokens. - * \param logits The logits (a batch) to be in-place mutated. - * \param token_offset The offset of the token in the batch - * whose logits will be updated. - * \param state The request state that contains history tokens. - * \param frequency_penalty The value of frequency penalty. - * \param presence_penalty The value of presence penalty. - */ -void ApplyFrequencyAndPresencePenaltyOnCPU(NDArray logits, int token_offset, - RequestModelState state, double frequency_penalty, - double presence_penalty) { - // logits: (n, v) - CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, DLDeviceType::kDLCPU); - int vocab_size = logits->shape[1]; - - float* logits_raw_data = static_cast(logits->data) + (token_offset * vocab_size); - for (const auto& it : state->appeared_token_ids) { - int token_id = it.first; - int occurrences = it.second; - ICHECK_GE(token_id, 0); - ICHECK_LT(token_id, vocab_size); - logits_raw_data[token_id] -= occurrences * frequency_penalty + presence_penalty; - } -} - -/*! - * \brief In-place compute softmax with temperature on CPU. - * \param logits The logits (a batch) to compute softmax from. - * \param token_offset The offset of the token in the batch - * to compute softmax for. Only the logits of the specified - * token will be updated to probability after softmax. - * \param temperature The temperature to apply before softmax. - */ -void ApplySoftmaxWithTemperatureOnCPU(NDArray logits, int token_offset, double temperature) { - // logits: (n, v) - CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, DLDeviceType::kDLCPU); - int vocab_size = logits->shape[1]; - - float* __restrict logits_raw_data = - static_cast(__builtin_assume_aligned(logits->data, 4)) + (token_offset * vocab_size); - float m = std::numeric_limits::min(); - float inv_temp = 1.0f / temperature; - double d = 0.0f; - for (int i = 0; i < vocab_size; ++i) { - float x = logits_raw_data[i] * inv_temp; - float m_prev = m; - m = std::max(m, x); - d = d * std::exp(m_prev - m) + std::exp(x - m); - } - for (int i = 0; i < vocab_size; ++i) { - float x = logits_raw_data[i] * inv_temp; - logits_raw_data[i] = std::exp(x - m) / d; - } -} - -/*! - * \brief In-place set probability via argmax. - * This is used for zero-temperature sampling cases. 
- * \param logits The logits (a batch) to set probability. - * \param token_offset The offset of the token in the batch - * to set probability for. Only the logits of the specified - * token will be updated to probability. - */ -void SetProbWithArgmaxOnCPU(NDArray logits, int token_offset) { - // logits: (n, v) - CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, kDLCPU); - int vocab_size = logits->shape[1]; - - float* logits_raw_data = static_cast(logits->data) + (token_offset * vocab_size); - int argmax_pos = -1; - float max_logits = std::numeric_limits::lowest(); - for (int i = 0; i < vocab_size; ++i) { - if (logits_raw_data[i] > max_logits) { - max_logits = logits_raw_data[i]; - argmax_pos = i; - } - } - - ICHECK_NE(argmax_pos, -1); - for (int i = 0; i < vocab_size; ++i) { - logits_raw_data[i] = i == argmax_pos ? 1.0f : 0.0f; - } -} - /*! * \brief Sample a value from the input probability distribution with top-p. * The input is a batch of distributions, and we use `unit_offset` to specify @@ -181,6 +59,30 @@ std::pair SampleTopPFromProb(NDArray prob, int unit_offset, doub if (!(*output_prob_dist)[unit_offset].defined()) { (*output_prob_dist)[unit_offset] = NDArray::Empty({ndata}, prob->dtype, DLDevice{kDLCPU, 0}); } + } + + if (top_p == 0) { + // Specially handle case where top_p == 0. + // This case is equivalent to doing argmax. + int argmax_pos = -1; + float max_prob = 0.0; + for (int i = 0; i < ndata; ++i) { + if (p_prob[i] > max_prob) { + max_prob = p_prob[i]; + argmax_pos = i; + } + } + if (output_prob_dist) { + float* __restrict p_output_prob = + static_cast(__builtin_assume_aligned((*output_prob_dist)[unit_offset]->data, 4)); + for (int i = 0; i < ndata; ++i) { + p_output_prob[i] = i == argmax_pos ? 1.0 : 0.0; + } + } + return std::make_pair(1.0, argmax_pos); + } + + if (output_prob_dist) { (*output_prob_dist)[unit_offset].CopyFromBytes(p_prob, ndata * sizeof(float)); } @@ -193,7 +95,6 @@ std::pair SampleTopPFromProb(NDArray prob, int unit_offset, doub return std::make_pair(p_prob[i], i); } } - LOG(INFO) << "prob sum = " << prob_sum << ", sample = " << uniform_sample; ICHECK(false) << "Possibly prob distribution contains NAN."; } @@ -278,37 +179,6 @@ std::pair SampleTopPFromProb(NDArray prob, int unit_offset, doub return sampled_index; } -/*! - * \brief Copy logits or prob distributions from device to CPU. - * The input array is in layout (b, n, v). - * This function flattens the first dimension, returns an NDArray - * in shape (b * n, v). - */ -NDArray CopyLogitsOrProbsToCPU(NDArray arr_on_device, NDArray* arr_on_cpu) { - // arr_on_device: (b, n, v) - ICHECK_EQ(arr_on_device->ndim, 3); - ICHECK(!arr_on_cpu->defined() || (*arr_on_cpu)->ndim == 2); - ICHECK(arr_on_device->device.device_type != kDLCPU); - if (arr_on_cpu->defined()) { - ICHECK_EQ((*arr_on_cpu)->shape[1], arr_on_device->shape[2]); - } - - int64_t init_size = arr_on_cpu->defined() ? 
(*arr_on_cpu)->shape[0] : 32; - int64_t num_tokens = arr_on_device->shape[0] * arr_on_device->shape[1]; - int64_t vocab_size = arr_on_device->shape[2]; - while (init_size < num_tokens) { - init_size *= 2; - } - if (!arr_on_cpu->defined() || init_size != (*arr_on_cpu)->shape[0]) { - (*arr_on_cpu) = - NDArray::Empty({init_size, vocab_size}, arr_on_device->dtype, DLDevice{kDLCPU, 0}); - } - ICHECK_LE(num_tokens, (*arr_on_cpu)->shape[0]); - NDArray view = arr_on_cpu->CreateView({num_tokens, vocab_size}, arr_on_device->dtype); - view.CopyFrom(arr_on_device); - return view; -} - /********************* CPU Sampler *********************/ class CPUSampler : public SamplerObj { @@ -323,44 +193,68 @@ class CPUSampler : public SamplerObj { } } - std::vector BatchSampleTokens(NDArray logits_on_device, Model model, - Array request_mstates, - Array generation_cfg, + std::vector BatchSampleTokens(NDArray probs_device, // + const Array& request_ids, + const Array& generation_cfg, const std::vector& rngs, std::vector* output_prob_dist, std::vector* output_token_probs) final { - NDArray probs_on_cpu = BatchComputeProb(logits_on_device, /*cum_sequence_length=*/nullptr, - model, request_mstates, generation_cfg); + // probs_device: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); + CHECK_EQ(probs_device->ndim, 2); + // - Copy probs to CPU + RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); + NDArray probs_host = CopyProbsToCPU(probs_device); + RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); + // - Sample tokens from probabilities. - // NOTE: Though we have the probability field in RequestModelState, - // we do not save the probabilities right now. - // We will handle this in the future when we work on speculation. - std::vector output_tokens = SampleTokensFromProbs( - probs_on_cpu, request_mstates, generation_cfg, rngs, output_prob_dist, output_token_probs); - return output_tokens; + ICHECK_EQ(probs_host->shape[0], request_ids.size()); + ICHECK_EQ(probs_host->shape[0], generation_cfg.size()); + ICHECK_EQ(probs_host->shape[0], rngs.size()); + int n = probs_host->shape[0]; + + std::vector sampled_tokens; + sampled_tokens.resize(n); + if (output_prob_dist) { + output_prob_dist->resize(n); + } + if (output_token_probs) { + output_token_probs->resize(n); + } + + tvm::runtime::parallel_for_with_threading_backend( + [this, &sampled_tokens, &probs_host, &generation_cfg, &rngs, &request_ids, output_prob_dist, + output_token_probs](int i) { + RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); + // Sample top p from probability. + std::pair sample_result = SampleTopPFromProb( + probs_host, i, generation_cfg[i]->temperature < eps_ ? 
0.0 : generation_cfg[i]->top_p, + rngs[i]->GetRandomNumber(), output_prob_dist); + sampled_tokens[i] = sample_result.second; + if (output_token_probs) { + (*output_token_probs)[i] = sample_result.first; + } + RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish sample token"); + }, + 0, n); + RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); + return sampled_tokens; } std::vector> BatchVerifyDraftTokens( - NDArray logits_on_device, const std::vector& cum_verify_lengths, Model model, - const Array& request_mstates, + NDArray probs_device, const Array& request_ids, + const std::vector& cum_verify_lengths, const Array& request_mstates, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, const std::vector>& draft_output_token_prob, const std::vector>& draft_output_prob_dist) final { - bool can_compute_prob_in_parallel = CanComputeProbInParallel(generation_cfg); - NDArray logits_or_probs_on_cpu{nullptr}; - Array request_ids = - request_mstates.Map([](const RequestModelState& mstate) { return mstate->request->id; }); - if (can_compute_prob_in_parallel) { - logits_or_probs_on_cpu = BatchComputeProb(logits_on_device, &cum_verify_lengths, model, - request_mstates, generation_cfg); - } else { - RECORD_EVENT(trace_recorder_, request_ids, "start copy logits to CPU"); - logits_or_probs_on_cpu = CopyLogitsOrProbsToCPU(logits_on_device, &logits_or_probs_on_cpu_); - RECORD_EVENT(trace_recorder_, request_ids, "finish copy logits to CPU"); - } - ICHECK(logits_or_probs_on_cpu->device.device_type == kDLCPU); - ICHECK_EQ(logits_or_probs_on_cpu->ndim, 2); + // probs_device: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); + CHECK_EQ(probs_device->ndim, 2); + // - Copy probs to CPU + RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); + NDArray probs_host = CopyProbsToCPU(probs_device); + RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); int num_sequence = static_cast(cum_verify_lengths.size()) - 1; CHECK_EQ(rngs.size(), num_sequence); @@ -372,20 +266,14 @@ class CPUSampler : public SamplerObj { accepted_tokens.resize(num_sequence); float* __restrict global_p_probs = - static_cast(__builtin_assume_aligned(logits_or_probs_on_cpu->data, 4)); - int vocab_size = logits_or_probs_on_cpu->shape[1]; + static_cast(__builtin_assume_aligned(probs_host->data, 4)); + int vocab_size = probs_host->shape[1]; tvm::runtime::parallel_for_with_threading_backend( [&](int i) { int verify_start = cum_verify_lengths[i]; int verify_end = cum_verify_lengths[i + 1]; for (int cur_token_idx = 0; cur_token_idx < verify_end - verify_start; ++cur_token_idx) { - if (!can_compute_prob_in_parallel) { - SinglePosComputeProbsFromLogitsInplace(logits_or_probs_on_cpu, - verify_start + cur_token_idx, - request_mstates[i], generation_cfg[i]); - } - float* p_probs = global_p_probs + (verify_start + cur_token_idx) * vocab_size; int cur_token = draft_output_tokens[i][cur_token_idx]; float q_value = draft_output_token_prob[i][cur_token_idx]; @@ -422,8 +310,10 @@ class CPUSampler : public SamplerObj { // sample a new token from the new distribution int32_t new_token = - SampleTopPFromProb(logits_or_probs_on_cpu, verify_start + cur_token_idx, - generation_cfg[i]->top_p, rngs[i]->GetRandomNumber()) + SampleTopPFromProb( + probs_host, verify_start + cur_token_idx, + generation_cfg[i]->temperature < eps_ ? 
0.0 : generation_cfg[i]->top_p, + rngs[i]->GetRandomNumber()) .second; request_mstates[i]->CommitToken(new_token); accepted_tokens[i].push_back(cur_token); @@ -431,238 +321,42 @@ class CPUSampler : public SamplerObj { } }, 0, num_sequence); + RECORD_EVENT(trace_recorder_, request_ids, "finish draft verification"); return accepted_tokens; } private: - /*! - * \brief Given the generation config of a batch, check if the - * probability distributions needs to be computed on device via softmax. - * \param generation_cfg The input generation config. - * \return A boolean flag indicating if the check result. - */ - bool RequireGPUSoftmax(Array generation_cfg) { - // - Return false if there is customized probability compute function. - if (flogits_to_probs_inplace_.defined()) { - return false; - } - // - Return false if any sampling param has frequency/presence penalty other than 0.0. - // - Return false if any sampling param has repetition penalty other than 1.0. - // - Return false if any sampling param has zero temperature. - for (GenerationConfig cfg : generation_cfg) { - if (cfg->frequency_penalty != 0.0 || cfg->presence_penalty != 0.0 || - cfg->repetition_penalty != 1.0 || cfg->temperature < 1e-6) { - return false; - } - } - return true; - } - - /*! - * \brief Given the generation config of a batch, check if the - * probability distributions need to be computed serially. - */ - bool CanComputeProbInParallel(const Array& generation_cfg) { - for (const GenerationConfig& cfg : generation_cfg) { - if (cfg->frequency_penalty != 0.0 || cfg->presence_penalty != 0.0 || - cfg->repetition_penalty != 1.0) { - return false; - } + /*! \brief Copy prob distributions from device to CPU. */ + NDArray CopyProbsToCPU(NDArray probs_device) { + // probs_device: (n, v) + ICHECK(probs_device->device.device_type != kDLCPU); + if (probs_host_.defined()) { + ICHECK_EQ(probs_host_->shape[1], probs_device->shape[1]); } - return true; - } - /*! - * \brief Compute the probability distribution of the input logits. - * \param logits_on_device The logits to compute probability distribution for. - * \param model The LLM model which contains the softmax - * function on device that might be used to compute probability distribution. - * \param request_mstates The request states of each sequence in - * the batch with regard to the given model. - * \param generation_cfg The generation config of each request - * in the input batch. - * \return The probability distribution of the input logits. - */ - NDArray BatchComputeProb(NDArray logits_on_device, const std::vector* cum_sequence_length, - Model model, const Array& request_mstates, - const Array& generation_cfg) { - ICHECK(logits_on_device.defined()); - ICHECK_EQ(logits_on_device->ndim, 3); - int num_sequence; - if (cum_sequence_length == nullptr) { - ICHECK_EQ(logits_on_device->shape[1], 1) - << "Multi-token sampling for one sequence requiring `cum_sequence_length`."; - num_sequence = logits_on_device->shape[0]; - } else { - ICHECK(!cum_sequence_length->empty()); - num_sequence = static_cast(cum_sequence_length->size()) - 1; - ICHECK_EQ(logits_on_device->shape[0], 1); - ICHECK_EQ(logits_on_device->shape[1], cum_sequence_length->back()); + int64_t init_size = probs_host_.defined() ? 
probs_host_->shape[0] : 32; + int64_t num_tokens = probs_device->shape[0]; + int64_t vocab_size = probs_device->shape[1]; + while (init_size < num_tokens) { + init_size *= 2; } - ICHECK_EQ(generation_cfg.size(), num_sequence); - ICHECK_EQ(request_mstates.size(), num_sequence); - - Array request_ids = - request_mstates.Map([](const RequestModelState& mstate) { return mstate->request->id; }); - - RECORD_EVENT(trace_recorder_, request_ids, "start query need GPU softmax"); - bool require_gpu_softmax = RequireGPUSoftmax(generation_cfg); - RECORD_EVENT(trace_recorder_, request_ids, "finish query need GPU softmax"); - - // - Compute probabilities from logits. - NDArray logits_or_probs_on_cpu{nullptr}; - if (require_gpu_softmax) { - RECORD_EVENT(trace_recorder_, request_ids, "start GPU softmax"); - Array generation_cfg_for_softmax; - if (cum_sequence_length == nullptr) { - generation_cfg_for_softmax = generation_cfg; - } else { - logits_on_device = logits_on_device.CreateView( - {logits_on_device->shape[1], 1, logits_on_device->shape[2]}, logits_on_device->dtype); - generation_cfg_for_softmax.reserve(logits_on_device->shape[1]); - for (int i = 0; i < num_sequence; ++i) { - for (int pos = cum_sequence_length->at(i); pos < cum_sequence_length->at(i + 1); ++pos) { - generation_cfg_for_softmax.push_back(generation_cfg[i]); - } - } - } - NDArray probs_on_device = - model->SoftmaxWithTemperature(logits_on_device, generation_cfg_for_softmax); - RECORD_EVENT(trace_recorder_, request_ids, "finish GPU softmax"); - RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); - logits_or_probs_on_cpu = CopyLogitsOrProbsToCPU(probs_on_device, &logits_or_probs_on_cpu_); - RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); - } else { - RECORD_EVENT(trace_recorder_, request_ids, "start copy logits to CPU"); - logits_or_probs_on_cpu = CopyLogitsOrProbsToCPU(logits_on_device, &logits_or_probs_on_cpu_); - RECORD_EVENT(trace_recorder_, request_ids, "finish copy logits to CPU"); - // The "BatchComputeProbsFromLogitsInplace" function updates - // `logits_or_probs_on_cpu` in place. - BatchComputeProbsFromLogitsInplace(logits_or_probs_on_cpu, cum_sequence_length, - std::move(request_mstates), generation_cfg); + if (!probs_host_.defined() || init_size != probs_host_->shape[0]) { + probs_host_ = + NDArray::Empty({init_size, vocab_size}, probs_device->dtype, DLDevice{kDLCPU, 0}); } - // `CopyLogitsOrProbsToCPU` flattens the first two dimensions. - ICHECK_EQ(logits_or_probs_on_cpu->ndim, 2); - return logits_or_probs_on_cpu; - } - - /*! - * \brief Compute the probability distribution from on-cpu logits for - * a batch of tokens **in place**. - * \param logits The input logits on CPU. - * \param states The request states, which contains the history generated tokens. - * \param generation_cfg The generation config. - * \note The function returns nothing. It in-place updates the input logits array. - */ - void BatchComputeProbsFromLogitsInplace(NDArray logits, - const std::vector* cum_sequence_length, - Array states, - Array generation_cfg) { - // logits: (n, v) - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, kDLCPU); - - // - Invoke environment compute function if exists. 
- if (flogits_to_probs_inplace_.defined()) { - IntTuple cum_sequence_length_obj; - if (cum_sequence_length != nullptr) { - cum_sequence_length_obj = - IntTuple{cum_sequence_length->begin(), cum_sequence_length->end()}; - } - flogits_to_probs_inplace_(logits, cum_sequence_length_obj, states, generation_cfg); - return; - } - - tvm::runtime::parallel_for_with_threading_backend( - [this, &logits, cum_sequence_length, &states, &generation_cfg](int i) { - int offset_start = cum_sequence_length == nullptr ? i : cum_sequence_length->at(i); - int offset_end = cum_sequence_length == nullptr ? i + 1 : cum_sequence_length->at(i + 1); - for (int offset = offset_start; offset < offset_end; ++offset) { - SinglePosComputeProbsFromLogitsInplace(logits, offset, states[i], generation_cfg[i]); - } - }, - 0, logits->shape[0]); - } - - void SinglePosComputeProbsFromLogitsInplace(NDArray logits, int offset, - const RequestModelState& state, - const GenerationConfig& generation_cfg) { - // - Apply frequency/presence penalty or repetition penalty (inplace). - if (generation_cfg->frequency_penalty != 0.0 || generation_cfg->presence_penalty != 0.0) { - RECORD_EVENT(trace_recorder_, state->request->id, "start frequency/presence penalty"); - ApplyFrequencyAndPresencePenaltyOnCPU(logits, offset, state, - generation_cfg->frequency_penalty, - generation_cfg->presence_penalty); - RECORD_EVENT(trace_recorder_, state->request->id, "finish frequency/presence penalty"); - } else if (generation_cfg->repetition_penalty != 1.0) { - RECORD_EVENT(trace_recorder_, state->request->id, "start repetition penalty"); - ApplyRepetitionPenaltyOnCPU(logits, offset, state, generation_cfg->repetition_penalty); - RECORD_EVENT(trace_recorder_, state->request->id, "finish repetition penalty"); - } - // - Compute probability (inplace) from logits. - // Using softmax if temperature is non-zero. - // Or set probability of the max-logit position to 1. - if (generation_cfg->temperature >= 1e-6) { - RECORD_EVENT(trace_recorder_, state->request->id, "start CPU softmax"); - ApplySoftmaxWithTemperatureOnCPU(logits, offset, generation_cfg->temperature); - RECORD_EVENT(trace_recorder_, state->request->id, "finish CPU softmax"); - } else { - RECORD_EVENT(trace_recorder_, state->request->id, "start argmax"); - SetProbWithArgmaxOnCPU(logits, offset); - RECORD_EVENT(trace_recorder_, state->request->id, "finish argmax"); - } - } - - std::vector SampleTokensFromProbs(NDArray probs, - Array request_mstates, - Array generation_cfg, - const std::vector& rngs, - std::vector* output_prob_dist, - std::vector* output_token_probs) { - // probs: (n, v) - CHECK_EQ(probs->ndim, 2); - CHECK_EQ(probs->device.device_type, kDLCPU); - ICHECK_EQ(probs->shape[0], request_mstates.size()); - ICHECK_EQ(probs->shape[0], generation_cfg.size()); - ICHECK_EQ(probs->shape[0], rngs.size()); - - Array request_ids = - request_mstates.Map([](const RequestModelState& mstate) { return mstate->request->id; }); - - int n = probs->shape[0]; - std::vector sampled_tokens; - sampled_tokens.resize(n); - if (output_prob_dist) { - output_prob_dist->resize(n); - } - if (output_token_probs) { - output_token_probs->resize(n); - } - - tvm::runtime::parallel_for_with_threading_backend( - [this, &sampled_tokens, &probs, &generation_cfg, &rngs, &request_ids, output_prob_dist, - output_token_probs](int i) { - RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); - // Sample top p from probability. 
- std::pair sample_result = SampleTopPFromProb( - probs, i, generation_cfg[i]->top_p, rngs[i]->GetRandomNumber(), output_prob_dist); - sampled_tokens[i] = sample_result.second; - if (output_token_probs) { - (*output_token_probs)[i] = sample_result.first; - } - RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish sample token"); - }, - 0, n); - return sampled_tokens; + ICHECK_LE(num_tokens, probs_host_->shape[0]); + NDArray view = probs_host_.CreateView({num_tokens, vocab_size}, probs_device->dtype); + view.CopyFrom(probs_device); + return view; } /*! \brief The event trace recorder for requests. */ Optional trace_recorder_; /*! \brief Customized function which computes prob distribution from logits */ PackedFunc flogits_to_probs_inplace_; - /*! \brief Shared array for logits and probability distributions on cpu. */ - NDArray logits_or_probs_on_cpu_{nullptr}; - const float eps_ = 1e-9; + /*! \brief Probability distribution array on CPU. */ + NDArray probs_host_{nullptr}; + const float eps_ = 1e-5; }; /*********************** Sampler ***********************/ diff --git a/cpp/serve/sampler.h b/cpp/serve/sampler.h index d74a7ef400..ac4820db64 100644 --- a/cpp/serve/sampler.h +++ b/cpp/serve/sampler.h @@ -32,12 +32,9 @@ using namespace tvm::runtime; class SamplerObj : public Object { public: /*! - * \brief Sample tokens from the input batch of logits. - * \param logits_on_device The logits to sample tokens from. - * \param model The LLM model which contains the softmax - * function on device that might be used to compute probability distribution. - * \param request_mstates The request states of each sequence in - * the batch with regard to the given model. + * \brief Sample tokens from the input batch of prob distribution on device. + * \param probs_device The prob distributions on GPU to sample tokens from. + * \param request_ids The id of each request. * \param generation_cfg The generation config of each request * in the input batch. * \param rngs The random number generator of each sequence. @@ -46,18 +43,17 @@ class SamplerObj : public Object { * \return The sampled tokens, one for each request in the batch. */ virtual std::vector BatchSampleTokens( - NDArray logits_on_device, Model model, Array request_mstates, - Array generation_cfg, const std::vector& rngs, + NDArray probs_device, const Array& request_ids, + const Array& generation_cfg, const std::vector& rngs, std::vector* output_prob_dist = nullptr, std::vector* output_token_probs = nullptr) = 0; /*! * \brief Verify draft tokens generated by small models in the large model * in speculative decoding. The input corresponds to a batch of sequences. - * \param logits_on_device The logits of the large model. + * \param probs_device The prob distributions on GPU to sample tokens from. + * \param request_ids The id of each request. * \param cum_verify_lengths The cumulative draft lengths to verify of all sequences. - * \param model The LLM model which contains the softmax - * function on device that might be used to compute probability distribution. * \param request_mstates The request states of each sequence in * the batch with regard to the large model. * \param generation_cfg The generation config of each request @@ -72,8 +68,8 @@ class SamplerObj : public Object { * \return The list of accepted tokens for each request. 
*/ virtual std::vector> BatchVerifyDraftTokens( - NDArray logits_on_device, const std::vector& cum_verify_lengths, Model model, - const Array& request_mstates, + NDArray probs_device, const Array& request_ids, + const std::vector& cum_verify_lengths, const Array& request_mstates, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, const std::vector>& draft_output_token_prob, diff --git a/python/mlc_chat/compiler_pass/attach_to_ir_module.py b/python/mlc_chat/compiler_pass/attach_to_ir_module.py index 84a6c76243..625c4ebfee 100644 --- a/python/mlc_chat/compiler_pass/attach_to_ir_module.py +++ b/python/mlc_chat/compiler_pass/attach_to_ir_module.py @@ -1,8 +1,10 @@ """A couple of passes that simply attach additional information onto the IRModule.""" + from typing import Dict import tvm from tvm import IRModule, relax, tir +from tvm.script import tir as T @tvm.transform.module_pass(opt_level=0, name="AttachVariableBounds") @@ -44,3 +46,114 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR if isinstance(func, relax.Function): mod[g_var] = func.with_attr("relax.memory_plan_dynamic_func_output", True) return mod + + +@tvm.transform.module_pass(opt_level=0, name="AttachLogitProcessFunc") +class AttachLogitProcessFunc: # pylint: disable=too-few-public-methods + """Attach logit processing TIR functions to IRModule.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + mod = mod.clone() + mod["apply_logit_bias_inplace"] = _apply_logit_bias_inplace + mod["apply_penalty_inplace"] = _apply_penalty_inplace + mod["apply_bitmask_inplace"] = _apply_bitmask_inplace + return mod + + +@T.prim_func +def _apply_logit_bias_inplace( + var_logits: T.handle, + var_pos2seq_id: T.handle, + var_token_ids: T.handle, + var_logit_bias: T.handle, +) -> None: + """Function that applies logit bias in place.""" + T.func_attr( + {"global_symbol": "apply_logit_bias_inplace", "tir.noalias": True, "tir.is_scheduled": True} + ) + batch_size = T.int32(is_size_var=True) + vocab_size = T.int32(is_size_var=True) + num_token = T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") + # seq_ids + pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") + token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") + logit_bias = T.match_buffer(var_logit_bias, (num_token,), "float32") + + for p0 in T.thread_binding(0, (num_token + 1023) // 1024, "blockIdx.x"): + for p1 in T.thread_binding(0, 1024, "threadIdx.x"): + with T.block("block"): + vp = T.axis.spatial(num_token, p0 * 1024 + p1) + T.where(p0 * 1024 + p1 < num_token) + logits[pos2seq_id[vp], token_ids[vp]] += logit_bias[vp] + + +@T.prim_func +def _apply_penalty_inplace( + var_logits: T.handle, + var_seq_ids: T.handle, + var_pos2seq_id: T.handle, + var_token_ids: T.handle, + var_token_cnt: T.handle, + var_penalties: T.handle, +) -> None: + """Function that applies penalties in place.""" + T.func_attr( + {"global_symbol": "apply_penalty_inplace", "tir.noalias": True, "tir.is_scheduled": True} + ) + batch_size = T.int32(is_size_var=True) + vocab_size = T.int32(is_size_var=True) + num_token = T.int32(is_size_var=True) + num_seq = T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") + seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") + pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") + token_ids = 
T.match_buffer(var_token_ids, (num_token,), "int32")
+    token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
+    penalties = T.match_buffer(var_penalties, (num_seq, 3), "float32")
+
+    for p0 in T.thread_binding(0, (num_token + 1023) // 1024, "blockIdx.x"):
+        for p1 in T.thread_binding(0, 1024, "threadIdx.x"):
+            with T.block("block"):
+                vp = T.axis.spatial(num_token, p0 * 1024 + p1)
+                T.where(p0 * 1024 + p1 < num_token)
+                # Penalties: (presence_penalty, frequency_penalty, repetition_penalty)
+                logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= (
+                    penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1]
+                )
+                logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(
+                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < 0,
+                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2],
+                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2],
+                )
+
+
+@T.prim_func
+def _apply_bitmask_inplace(
+    var_logits: T.handle,
+    var_seq_ids: T.handle,
+    var_bitmask: T.handle,
+) -> None:
+    """Function that applies vocabulary masking in place."""
+    T.func_attr(
+        {"global_symbol": "apply_bitmask_inplace", "tir.noalias": True, "tir.is_scheduled": True}
+    )
+    batch_size = T.int32(is_size_var=True)
+    vocab_size = T.int32(is_size_var=True)
+    num_seq = T.int32(is_size_var=True)
+    logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
+    seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
+    bitmask = T.match_buffer(var_bitmask, (num_seq, (vocab_size + 31) // 32), "int32")
+
+    for fused_s_v_0 in T.thread_binding(0, (num_seq * vocab_size + 1023) // 1024, "blockIdx.x"):
+        for fused_s_v_1 in T.thread_binding(0, 1024, "threadIdx.x"):
+            with T.block("block"):
+                vs = T.axis.spatial(num_seq, (fused_s_v_0 * 1024 + fused_s_v_1) // vocab_size)
+                vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size)
+                T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size)
+                logits[seq_ids[vs], vv] = T.if_then_else(
+                    (bitmask[vs, vv // 32] >> (vv % 32)) & 1 == 1,
+                    logits[seq_ids[vs], vv],
+                    T.float32(-1e10),
+                )
diff --git a/python/mlc_chat/compiler_pass/pipeline.py b/python/mlc_chat/compiler_pass/pipeline.py
index 20676187bd..98922c6139 100644
--- a/python/mlc_chat/compiler_pass/pipeline.py
+++ b/python/mlc_chat/compiler_pass/pipeline.py
@@ -13,6 +13,7 @@
 from .attach_to_ir_module import (
     AttachAdditionalPrimFuncs,
+    AttachLogitProcessFunc,
     AttachMemoryPlanAttr,
     AttachVariableBounds,
 )
@@ -89,6 +90,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
             # Phase 0. 
Add additional information for compilation and remove unused Relax func RewriteKVCacheCreation(target, flashinfer, metadata), AttachVariableBounds(variable_bounds), + AttachLogitProcessFunc(), AttachAdditionalPrimFuncs(additional_tirs), AttachMemoryPlanAttr(), tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)), diff --git a/python/mlc_chat/protocol/openai_api_protocol.py b/python/mlc_chat/protocol/openai_api_protocol.py index 128d7e99d7..36b75f81a5 100644 --- a/python/mlc_chat/protocol/openai_api_protocol.py +++ b/python/mlc_chat/protocol/openai_api_protocol.py @@ -63,7 +63,7 @@ class CompletionRequest(BaseModel): echo: bool = False frequency_penalty: float = 0.0 presence_penalty: float = 0.0 - logit_bias: Optional[Dict[str, float]] = None + logit_bias: Optional[Dict[int, float]] = None logprobs: Optional[int] = None max_tokens: int = 16 n: int = 1 @@ -84,6 +84,22 @@ def check_penalty_range(cls, penalty_value: float) -> float: raise ValueError("Penalty value should be in range [-2, 2].") return penalty_value + @field_validator("logit_bias") + @classmethod + def check_logit_bias( + cls, logit_bias_value: Optional[Dict[int, float]] + ) -> Optional[Dict[int, float]]: + """Check if the logit bias key is given as an integer.""" + if logit_bias_value is None: + return None + for token_id, bias in logit_bias_value.items(): + if abs(bias) > 100: + raise ValueError( + "Logit bias value should be in range [-100, 100], while value " + f"{bias} is given for token id {token_id}" + ) + return logit_bias_value + class CompletionResponseChoice(BaseModel): finish_reason: Optional[Literal["stop", "length"]] = None @@ -149,7 +165,7 @@ class ChatCompletionRequest(BaseModel): model: str frequency_penalty: float = 0.0 presence_penalty: float = 0.0 - logit_bias: Optional[Dict[str, float]] = None + logit_bias: Optional[Dict[int, float]] = None max_tokens: Optional[int] = None n: int = 1 response_format: Literal["text", "json_object"] = "text" @@ -163,6 +179,30 @@ class ChatCompletionRequest(BaseModel): user: Optional[str] = None ignore_eos: bool = False + @field_validator("frequency_penalty", "presence_penalty") + @classmethod + def check_penalty_range(cls, penalty_value: float) -> float: + """Check if the penalty value is in range [-2, 2].""" + if penalty_value < -2 or penalty_value > 2: + raise ValueError("Penalty value should be in range [-2, 2].") + return penalty_value + + @field_validator("logit_bias") + @classmethod + def check_logit_bias( + cls, logit_bias_value: Optional[Dict[int, float]] + ) -> Optional[Dict[int, float]]: + """Check if the logit bias key is given as an integer.""" + if logit_bias_value is None: + return None + for token_id, bias in logit_bias_value.items(): + if abs(bias) > 100: + raise ValueError( + "Logit bias value should be in range [-100, 100], while value " + f"{bias} is given for token id {token_id}" + ) + return logit_bias_value + class ChatCompletionResponseChoice(BaseModel): finish_reason: Optional[Literal["stop", "length", "tool_calls", "error"]] = None @@ -214,7 +254,6 @@ def openai_api_get_unsupported_fields( """Get the unsupported fields in the request.""" unsupported_field_default_values: List[Tuple[str, Any]] = [ ("best_of", 1), - ("logit_bias", None), ("logprobs", None), ("n", 1), ("response_format", "text"), @@ -238,6 +277,7 @@ def openai_api_get_generation_config( "max_tokens", "frequency_penalty", "presence_penalty", + "logit_bias", "seed", "ignore_eos", ] diff --git a/python/mlc_chat/serve/config.py b/python/mlc_chat/serve/config.py index 
4223148e8e..1962b61215 100644 --- a/python/mlc_chat/serve/config.py +++ b/python/mlc_chat/serve/config.py @@ -1,7 +1,7 @@ """Configuration dataclasses used in MLC LLM serving""" import json from dataclasses import asdict, dataclass, field -from typing import List, Optional +from typing import Dict, List, Optional @dataclass @@ -31,6 +31,9 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes It will be suppressed when any of frequency_penalty and presence_penalty is non-zero. + logit_bias : Optional[Dict[int, float]] + The bias logit value added to selected tokens prior to sampling. + max_tokens : Optional[int] The maximum number of generated tokens, or None, in which case the generation will not stop @@ -56,6 +59,7 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes frequency_penalty: float = 0.0 presence_penalty: float = 0.0 repetition_penalty: float = 1.0 + logit_bias: Optional[Dict[int, float]] = field(default_factory=dict) max_tokens: Optional[int] = 128 seed: Optional[int] = None diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index 65c63c2166..0721e97190 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -484,6 +484,54 @@ def test_openai_v1_completions_temperature( ) +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_logit_bias( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + # NOTE: This test only tests that the system does not break on logit bias. + # The test does not promise the correctness of logit bias handling. + + prompt = "What's the meaning of life?" + max_tokens = 128 + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": max_tokens, + "stream": stream, + "logit_bias": {338: -100}, # 338 is " is" in Llama tokenizer. + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + ) + + @pytest.mark.parametrize("stream", [False, True]) def test_openai_v1_completions_presence_frequency_penalty( served_model: Tuple[str, str], @@ -889,26 +937,6 @@ def test_openai_v1_chat_completions_system_prompt_wrong_pos( assert num_chunks == 1 -def test_openai_v1_chat_completions_unsupported_args( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - # Right now "tool_choice" is unsupported. - tool_choice = "auto" - payload = { - "model": served_model[0], - "messages": CHAT_COMPLETION_MESSAGES[0], - "tool_choice": tool_choice, - } - - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) - error_msg_prefix = 'Request fields "tool_choice" are not supported right now.' 
- expect_error(response.json(), msg_prefix=error_msg_prefix) - - def test_debug_dump_event_trace( served_model: Tuple[str, str], launch_server, # pylint: disable=unused-argument @@ -946,6 +974,8 @@ def test_debug_dump_event_trace( test_openai_v1_completions_stop_str(MODEL, None, stream=True) test_openai_v1_completions_temperature(MODEL, None, stream=False) test_openai_v1_completions_temperature(MODEL, None, stream=True) + test_openai_v1_completions_logit_bias(MODEL, None, stream=False) + test_openai_v1_completions_logit_bias(MODEL, None, stream=True) test_openai_v1_completions_presence_frequency_penalty(MODEL, None, stream=False) test_openai_v1_completions_presence_frequency_penalty(MODEL, None, stream=True) test_openai_v1_completions_seed(MODEL, None) @@ -965,6 +995,5 @@ def test_debug_dump_event_trace( test_openai_v1_chat_completions_ignore_eos(MODEL, None, stream=True) test_openai_v1_chat_completions_system_prompt_wrong_pos(MODEL, None, stream=False) test_openai_v1_chat_completions_system_prompt_wrong_pos(MODEL, None, stream=True) - test_openai_v1_chat_completions_unsupported_args(MODEL, None) test_debug_dump_event_trace(MODEL, None)
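
For reference, here is a minimal, self-contained NumPy sketch that composes the three kinds of in-place logit updates attached above (logit bias, presence/frequency/repetition penalties, and the vocabulary bitmask) with softmax-with-temperature and the CPU sampler's top-p path, including the argmax shortcut taken when top_p is forced to 0 for near-zero temperature. It is an illustration only, not code from this patch: the function and variable names are invented, the composition order is an assumption of the sketch, it works on a single unbatched row rather than the seq_ids/pos2seq_id/packed-bitmask layout of the TIR kernels, and it runs with NumPy on CPU.

import numpy as np


def process_logits(logits, logit_bias, token_cnt, presence, frequency, repetition, allowed):
    """Apply logit bias, penalties and vocabulary masking to one row of logits in place."""
    # Logit bias: add the user-provided bias to the chosen token ids.
    for token_id, bias in logit_bias.items():
        logits[token_id] += bias
    # Presence/frequency penalties on tokens that already appeared, plus repetition penalty.
    for token_id, cnt in token_cnt.items():
        logits[token_id] -= presence + cnt * frequency
        if logits[token_id] < 0:
            logits[token_id] *= repetition
        else:
            logits[token_id] /= repetition
    # Vocabulary mask: disallowed tokens get a very negative logit.
    logits[~allowed] = -1e10
    return logits


def probs_from_logits(logits, temperature, eps=1e-5):
    """Softmax with temperature; near-zero temperature degenerates to argmax."""
    if temperature < eps:
        probs = np.zeros_like(logits)
        probs[np.argmax(logits)] = 1.0
        return probs
    x = logits / temperature
    x -= x.max()  # for numerical stability
    e = np.exp(x)
    return e / e.sum()


def sample_top_p(probs, top_p, rng):
    """Sample from the smallest set of most-probable tokens whose mass reaches top_p."""
    if top_p == 0:
        return int(np.argmax(probs))  # degenerate case: plain argmax
    order = np.argsort(-probs)
    cumsum = np.cumsum(probs[order])
    kept = order[: np.searchsorted(cumsum, top_p) + 1]
    renorm = probs[kept].astype(np.float64)
    renorm /= renorm.sum()
    return int(rng.choice(kept, p=renorm))


rng = np.random.default_rng(0)
logits = rng.normal(size=32).astype("float32")
allowed = np.ones(32, dtype=bool)
allowed[5] = False  # e.g. forbid token id 5
logits = process_logits(
    logits,
    logit_bias={3: 50.0, 7: -100.0},  # OpenAI-style {token_id: bias}, |bias| <= 100
    token_cnt={1: 2, 4: 1},  # token ids generated so far and their counts
    presence=0.1,
    frequency=0.2,
    repetition=1.1,
    allowed=allowed,
)
probs = probs_from_logits(logits, temperature=0.7)
print(sample_top_p(probs, top_p=0.95, rng=rng))

As in BatchSampleTokens above, the sampler itself only ever sees probabilities: when generation_cfg[i]->temperature < eps_ it passes top_p = 0 to SampleTopPFromProb, which then takes the argmax branch sketched here.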