Skip to content

Commit

Permalink
[Serving][Refactor] Logit processor and logit bias support
Browse files Browse the repository at this point in the history
This PR refactors the existing logit processing pipeline
with a unified logit processor class. The logit processor class
exposes two functions:
- `InplaceUpdateLogits`, which takes in the raw logits produced
by the model, and applies logit bias (which is introduced in this PR),
presence/frequency/repetition penalties, and the token id mask, in
that order, when needed.
- `ComputeProbsFromLogits`, which takes in the updated logits,
and invokes softmax with temperature to compute the probability
distribution.

The logit processor runs entirely on GPU. That is, all the
logit bias / penalty / mask applications and the softmax
are backed by GPU kernels. This is a key difference compared
with the logit processing prior to this PR, where the processing
happened on CPU, and softmax also ran on CPU whenever any logit
processing was needed.

With the unified logit processor, we simplified the interface
of handling model's output logits in engine actions to make it
cleaner. We also simplified the interface of Sampler.

Preliminary results show that LogitProcessor brings a modest
performance improvement when any processing is needed.
  • Loading branch information
MasterJH5574 committed Feb 23, 2024
1 parent 43d38ee commit d861a7d
Show file tree
Hide file tree
Showing 23 changed files with 1,008 additions and 550 deletions.
22 changes: 22 additions & 0 deletions cpp/serve/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@ GenerationConfig::GenerationConfig(String config_json_str) {
n->repetition_penalty = config["repetition_penalty"].get<double>();
CHECK(n->repetition_penalty > 0) << "Repetition penalty must be a positive number!";
}
if (config.count("logit_bias")) {
CHECK(config["logit_bias"].is<picojson::null>() || config["logit_bias"].is<picojson::object>());
if (config["logit_bias"].is<picojson::object>()) {
picojson::object logit_bias_json = config["logit_bias"].get<picojson::object>();
std::vector<std::pair<int, float>> logit_bias;
logit_bias.reserve(logit_bias_json.size());
for (auto [token_id_str, bias] : logit_bias_json) {
CHECK(bias.is<double>());
double bias_value = bias.get<double>();
CHECK_LE(std::fabs(bias_value), 100.0)
<< "Logit bias value should be in range [-100, 100].";
logit_bias.emplace_back(std::stoi(token_id_str), bias_value);
}
n->logit_bias = std::move(logit_bias);
}
}
if (config.count("max_tokens")) {
if (config["max_tokens"].is<int64_t>()) {
n->max_tokens = config["max_tokens"].get<int64_t>();
Expand Down Expand Up @@ -115,6 +131,12 @@ String GenerationConfigNode::AsJSONString() const {
config["max_tokens"] = picojson::value(static_cast<int64_t>(this->max_tokens));
config["seed"] = picojson::value(static_cast<int64_t>(this->seed));

picojson::object logit_bias_obj;
for (auto [token_id, bias] : logit_bias) {
logit_bias_obj[std::to_string(token_id)] = picojson::value(static_cast<double>(bias));
}
config["logit_bias"] = picojson::value(logit_bias_obj);

picojson::array stop_strs_arr;
for (String stop_str : this->stop_strs) {
stop_strs_arr.push_back(picojson::value(stop_str));
Expand Down
1 change: 1 addition & 0 deletions cpp/serve/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class GenerationConfigNode : public Object {
double frequency_penalty = 0.0;
double presence_penalty = 0.0;
double repetition_penalty = 1.0;
std::vector<std::pair<int, float>> logit_bias;
int seed;
bool ignore_eos = false;

Expand Down
31 changes: 20 additions & 11 deletions cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "engine_actions/action_commons.h"
#include "engine_state.h"
#include "event_trace_recorder.h"
#include "logit_processor.h"
#include "model.h"
#include "request.h"
#include "request_state.h"
Expand Down Expand Up @@ -53,10 +54,10 @@ class EngineImpl : public Engine {
this->engine_mode_ = EngineMode(engine_mode_json_str);
this->request_stream_callback_ = std::move(request_stream_callback);
this->trace_recorder_ = trace_recorder;
this->sampler_ = Sampler::Create(/*sampler_kind=*/"cpu", trace_recorder_);
this->tokenizer_ = Tokenizer::FromPath(tokenizer_path);
this->token_table_ = tokenizer_->TokenTable();
// Step 2. Initialize each model independently.
// Create the logit processor and sampler.
this->models_.clear();
for (const auto& model_info : model_infos) {
TVMArgValue model_lib = std::get<0>(model_info);
Expand All @@ -71,26 +72,35 @@ class EngineImpl : public Engine {
<< this->max_single_sequence_length_;
this->models_.push_back(model);
}
int max_logit_processor_num_token = kv_cache_config_->max_num_sequence;
if (engine_mode_->enable_speculative) {
max_logit_processor_num_token *= engine_mode_->spec_draft_length;
}
LogitProcessor logit_processor =
this->models_[0]->CreateLogitProcessor(max_logit_processor_num_token, trace_recorder);
Sampler sampler = Sampler::Create(/*sampler_kind=*/"cpu", trace_recorder_);
// Step 3. Initialize engine actions that represent state transitions.
if (this->engine_mode_->enable_speculative) {
// Speculative decoding is only possible for more than one model.
ICHECK_GT(this->models_.size(), 1U);
this->actions_ = {
EngineAction::NewRequestPrefill(this->models_, //
this->sampler_, //
logit_processor, //
sampler, //
this->kv_cache_config_, //
this->trace_recorder_),
EngineAction::BatchDraft(this->models_, this->sampler_, this->trace_recorder_,
EngineAction::BatchDraft(this->models_, logit_processor, sampler, this->trace_recorder_,
this->engine_mode_->spec_draft_length),
EngineAction::BatchVerify(this->models_, this->sampler_, this->kv_cache_config_,
EngineAction::BatchVerify(this->models_, logit_processor, sampler, this->kv_cache_config_,
this->trace_recorder_)};
} else {
this->actions_ = {
EngineAction::NewRequestPrefill(this->models_, //
this->sampler_, //
this->kv_cache_config_, //
this->trace_recorder_),
EngineAction::BatchDecode(this->models_, this->sampler_, this->trace_recorder_)};
this->actions_ = {EngineAction::NewRequestPrefill(this->models_, //
logit_processor, //
sampler, //
this->kv_cache_config_, //
this->trace_recorder_),
EngineAction::BatchDecode(this->models_, logit_processor, sampler,
this->trace_recorder_)};
}
// Step 4. Automatically set the threading backend max concurrency.
SetThreadMaxConcurrency();
Expand Down Expand Up @@ -196,7 +206,6 @@ class EngineImpl : public Engine {
KVCacheConfig kv_cache_config_;
EngineMode engine_mode_;
int max_single_sequence_length_;
Sampler sampler_;
Tokenizer tokenizer_;
std::vector<std::string> token_table_;
// Models
Expand Down
18 changes: 10 additions & 8 deletions cpp/serve/engine_actions/action.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,14 @@ class EngineAction : public ObjectRef {
* \brief Create the action that prefills requests in the `waiting_queue`
* of the engine state.
* \param models The models to run prefill in.
* \param logit_processor The logit processor.
* \param sampler The sampler to sample new tokens.
* \param kv_cache_config The KV cache config to help decide prefill is doable.
* \param trace_recorder The event trace recorder for requests.
* \return The created action object.
*/
static EngineAction NewRequestPrefill(Array<Model> models, Sampler sampler,
KVCacheConfig kv_cache_config,
static EngineAction NewRequestPrefill(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, KVCacheConfig kv_cache_config,
Optional<EventTraceRecorder> trace_recorder);
/*!
* \brief Create the action that runs one-step decode for requests in the
Expand All @@ -74,8 +75,8 @@ class EngineAction : public ObjectRef {
* \param trace_recorder The event trace recorder for requests.
* \return The created action object.
*/
static EngineAction BatchDecode(Array<Model> models, Sampler sampler,
Optional<EventTraceRecorder> trace_recorder);
static EngineAction BatchDecode(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, Optional<EventTraceRecorder> trace_recorder);

/*!
* \brief Create the action that runs one-step speculative draft proposal for
Expand All @@ -88,8 +89,9 @@ class EngineAction : public ObjectRef {
* \param draft_length The number of draft proposal rounds.
* \return The created action object.
*/
static EngineAction BatchDraft(Array<Model> models, Sampler sampler,
Optional<EventTraceRecorder> trace_recorder, int draft_length = 4);
static EngineAction BatchDraft(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, Optional<EventTraceRecorder> trace_recorder,
int draft_length = 4);

/*!
* \brief Create the action that runs one-step speculative verification for requests in the
Expand All @@ -102,8 +104,8 @@ class EngineAction : public ObjectRef {
* \param trace_recorder The event trace recorder for requests.
* \return The created action object.
*/
static EngineAction BatchVerify(Array<Model> models, Sampler sampler,
KVCacheConfig kv_cache_config,
static EngineAction BatchVerify(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, KVCacheConfig kv_cache_config,
Optional<EventTraceRecorder> trace_recorder);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineAction, ObjectRef, EngineActionObj);
Expand Down
27 changes: 19 additions & 8 deletions cpp/serve/engine_actions/batch_decode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ namespace serve {
*/
class BatchDecodeActionObj : public EngineActionObj {
public:
explicit BatchDecodeActionObj(Array<Model> models, Sampler sampler,
Optional<EventTraceRecorder> trace_recorder)
explicit BatchDecodeActionObj(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, Optional<EventTraceRecorder> trace_recorder)
: models_(std::move(models)),
logit_processor_(std::move(logit_processor)),
sampler_(std::move(sampler)),
trace_recorder_(std::move(trace_recorder)) {}

Expand Down Expand Up @@ -92,11 +93,17 @@ class BatchDecodeActionObj : public EngineActionObj {
ICHECK_EQ(logits->shape[0], embeddings->shape[0]);
ICHECK_EQ(logits->shape[1], 1);

// - Update logits.
logits = logits.CreateView({num_requests, logits->shape[2]}, logits->dtype);
logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids);

// - Compute probability distributions.
NDArray probs_device =
logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids);

// - Sample tokens.
RECORD_EVENT(trace_recorder_, request_ids, "start sampling");
std::vector<int32_t> next_tokens =
sampler_->BatchSampleTokens(logits, models_[0], mstates, generation_cfg, rngs);
RECORD_EVENT(trace_recorder_, request_ids, "finish sampling");
sampler_->BatchSampleTokens(probs_device, request_ids, generation_cfg, rngs);
ICHECK_EQ(next_tokens.size(), num_requests);

// - Update the committed tokens of states.
Expand All @@ -122,16 +129,20 @@ class BatchDecodeActionObj : public EngineActionObj {
* models, the `Step` function of the created action will not take effect.
*/
Array<Model> models_;
/*! \brief The logit processor. */
LogitProcessor logit_processor_;
/*! \brief The sampler to sample new tokens. */
Sampler sampler_;
/*! \brief Event trace recorder. */
Optional<EventTraceRecorder> trace_recorder_;
};

EngineAction EngineAction::BatchDecode(Array<Model> models, Sampler sampler,
EngineAction EngineAction::BatchDecode(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler,
Optional<EventTraceRecorder> trace_recorder) {
return EngineAction(make_object<BatchDecodeActionObj>(std::move(models), std::move(sampler),
std::move(trace_recorder)));
return EngineAction(
make_object<BatchDecodeActionObj>(std::move(models), std::move(logit_processor),
std::move(sampler), std::move(trace_recorder)));
}

} // namespace serve
Expand Down
26 changes: 18 additions & 8 deletions cpp/serve/engine_actions/batch_draft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ namespace serve {
*/
class BatchDraftActionObj : public EngineActionObj {
public:
explicit BatchDraftActionObj(Array<Model> models, Sampler sampler,
explicit BatchDraftActionObj(Array<Model> models, LogitProcessor logit_processor, Sampler sampler,
Optional<EventTraceRecorder> trace_recorder, int draft_length)
: models_(std::move(models)),
logit_processor_(std::move(logit_processor)),
sampler_(std::move(sampler)),
trace_recorder_(std::move(trace_recorder)),
draft_length_(draft_length) {
Expand Down Expand Up @@ -102,13 +103,19 @@ class BatchDraftActionObj : public EngineActionObj {
ICHECK_EQ(logits->shape[0], embeddings->shape[0]);
ICHECK_EQ(logits->shape[1], 1);

// - Update logits.
logits = logits.CreateView({num_requests, logits->shape[2]}, logits->dtype);
logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids);

// - Compute probability distributions.
NDArray probs_device =
logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids);

// - Sample tokens.
RECORD_EVENT(trace_recorder_, request_ids, "start proposal sampling");
std::vector<NDArray> prob_dist;
std::vector<float> token_probs;
std::vector<int32_t> next_tokens = sampler_->BatchSampleTokens(
logits, models_[model_id], mstates, generation_cfg, rngs, &prob_dist, &token_probs);
RECORD_EVENT(trace_recorder_, request_ids, "finish proposal sampling");
probs_device, request_ids, generation_cfg, rngs, &prob_dist, &token_probs);
ICHECK_EQ(next_tokens.size(), num_requests);

// - Update the draft tokens, prob dist, token probs of states.
Expand Down Expand Up @@ -143,6 +150,8 @@ class BatchDraftActionObj : public EngineActionObj {

/*! \brief The model to run draft generation in speculative decoding. */
Array<Model> models_;
/*! \brief The logit processor. */
LogitProcessor logit_processor_;
/*! \brief The sampler to sample new tokens. */
Sampler sampler_;
/*! \brief Event trace recorder. */
Expand All @@ -151,11 +160,12 @@ class BatchDraftActionObj : public EngineActionObj {
int draft_length_;
};

EngineAction EngineAction::BatchDraft(Array<Model> models, Sampler sampler,
Optional<EventTraceRecorder> trace_recorder,
EngineAction EngineAction::BatchDraft(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, Optional<EventTraceRecorder> trace_recorder,
int draft_length) {
return EngineAction(make_object<BatchDraftActionObj>(std::move(models), std::move(sampler),
std::move(trace_recorder), draft_length));
return EngineAction(make_object<BatchDraftActionObj>(
std::move(models), std::move(logit_processor), std::move(sampler), std::move(trace_recorder),
draft_length));
}

} // namespace serve
Expand Down
31 changes: 22 additions & 9 deletions cpp/serve/engine_actions/batch_verify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ namespace serve {
*/
class BatchVerifyActionObj : public EngineActionObj {
public:
explicit BatchVerifyActionObj(Array<Model> models, Sampler sampler, KVCacheConfig kv_cache_config,
explicit BatchVerifyActionObj(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, KVCacheConfig kv_cache_config,
Optional<EventTraceRecorder> trace_recorder)
: models_(std::move(models)),
logit_processor_(std::move(logit_processor)),
sampler_(std::move(sampler)),
kv_cache_config_(std::move(kv_cache_config)),
trace_recorder_(std::move(trace_recorder)),
Expand Down Expand Up @@ -103,13 +105,22 @@ class BatchVerifyActionObj : public EngineActionObj {
ICHECK_EQ(logits->shape[0], 1);
ICHECK_EQ(logits->shape[1], total_draft_length);

// - Update logits.
std::vector<int> cum_verify_lengths = {0};
for (int i = 0; i < num_requests; ++i) {
cum_verify_lengths.push_back(cum_verify_lengths.back() + draft_lengths[i]);
}
logits = logits.CreateView({total_draft_length, logits->shape[2]}, logits->dtype);
logit_processor_->InplaceUpdateLogits(logits, generation_cfg, verify_request_mstates,
request_ids, &cum_verify_lengths, &draft_output_tokens);

// - Compute probability distributions.
NDArray probs_device = logit_processor_->ComputeProbsFromLogits(
logits, generation_cfg, request_ids, &cum_verify_lengths);

std::vector<std::vector<int32_t>> accepted_tokens_arr = sampler_->BatchVerifyDraftTokens(
logits, cum_verify_lengths, models_[verify_model_id_], verify_request_mstates,
generation_cfg, rngs, draft_output_tokens, draft_output_token_prob, draft_output_prob_dist);
probs_device, request_ids, cum_verify_lengths, verify_request_mstates, generation_cfg, rngs,
draft_output_tokens, draft_output_token_prob, draft_output_prob_dist);
ICHECK_EQ(accepted_tokens_arr.size(), num_requests);

for (int i = 0; i < num_requests; ++i) {
Expand Down Expand Up @@ -222,6 +233,8 @@ class BatchVerifyActionObj : public EngineActionObj {
* models, the `Step` function of the created action will not take effect.
*/
Array<Model> models_;
/*! \brief The logit processor. */
LogitProcessor logit_processor_;
/*! \brief The sampler to sample new tokens. */
Sampler sampler_;
/*! \brief The kv cache config. */
Expand All @@ -233,15 +246,15 @@ class BatchVerifyActionObj : public EngineActionObj {
/*! \brief The ids of verify/draft models. */
const int verify_model_id_ = 0;
const int draft_model_id_ = 1;
const float eps_ = 1e-9;
const float eps_ = 1e-5;
};

EngineAction EngineAction::BatchVerify(Array<Model> models, Sampler sampler,
KVCacheConfig kv_cache_config,
EngineAction EngineAction::BatchVerify(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, KVCacheConfig kv_cache_config,
Optional<EventTraceRecorder> trace_recorder) {
return EngineAction(make_object<BatchVerifyActionObj>(std::move(models), std::move(sampler),
std::move(kv_cache_config),
std::move(trace_recorder)));
return EngineAction(make_object<BatchVerifyActionObj>(
std::move(models), std::move(logit_processor), std::move(sampler), std::move(kv_cache_config),
std::move(trace_recorder)));
}

} // namespace serve
Expand Down
Loading

0 comments on commit d861a7d

Please sign in to comment.