Skip to content

[SpecDecode] Support Eagle in speculative decoding #2080

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions cpp/serve/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,16 +299,16 @@ String KVCacheConfigNode::AsJSONString() const {

TVM_REGISTER_OBJECT_TYPE(EngineModeNode);

EngineMode::EngineMode(bool enable_speculative, int spec_draft_length) {
EngineMode::EngineMode(int spec_draft_length, int speculative_mode) {
ObjectPtr<EngineModeNode> n = make_object<EngineModeNode>();
n->enable_speculative = enable_speculative;
n->spec_draft_length = spec_draft_length;
n->speculative_mode = SpeculativeMode(speculative_mode);
data_ = std::move(n);
}

EngineMode::EngineMode(const std::string& config_str) {
bool enable_speculative = false;
int spec_draft_length = 4;
int speculative_mode = 0;

picojson::value config_json;
std::string err = picojson::parse(config_json, config_str);
Expand All @@ -318,25 +318,25 @@ EngineMode::EngineMode(const std::string& config_str) {

// Get json fields.
picojson::object config = config_json.get<picojson::object>();
if (config.count("enable_speculative")) {
CHECK(config["enable_speculative"].is<bool>());
enable_speculative = config["enable_speculative"].get<bool>();
}
if (config.count("spec_draft_length")) {
CHECK(config["spec_draft_length"].is<int64_t>());
spec_draft_length = config["spec_draft_length"].get<int64_t>();
}
if (config.count("speculative_mode")) {
CHECK(config["speculative_mode"].is<int64_t>());
speculative_mode = config["speculative_mode"].get<int64_t>();
}

ObjectPtr<EngineModeNode> n = make_object<EngineModeNode>();
n->enable_speculative = enable_speculative;
n->spec_draft_length = spec_draft_length;
n->speculative_mode = SpeculativeMode(speculative_mode);
data_ = std::move(n);
}

String EngineModeNode::AsJSONString() const {
picojson::object config;
config["enable_speculative"] = picojson::value(static_cast<bool>(this->enable_speculative));
config["spec_draft_length"] = picojson::value(static_cast<int64_t>(this->spec_draft_length));
config["speculative_mode"] = picojson::value(static_cast<int64_t>(this->speculative_mode));
return picojson::value(config).serialize(true);
}

Expand Down
13 changes: 10 additions & 3 deletions cpp/serve/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,20 @@ class KVCacheConfig : public ObjectRef {

/****************** Engine Mode ******************/

/*! \brief The speculative mode. */
enum class SpeculativeMode : int {
kDisable = 0,
kSmallDraft = 1,
kEagle = 2,
};

/*! \brief The configuration of engine execution mode. */
class EngineModeNode : public Object {
public:
/* Whether the speculative decoding mode is enabled */
bool enable_speculative;
/* The number of tokens to generate in speculative proposal (draft) */
int spec_draft_length;
/* The speculative mode. */
SpeculativeMode speculative_mode;

String AsJSONString() const;

Expand All @@ -116,7 +123,7 @@ class EngineModeNode : public Object {

class EngineMode : public ObjectRef {
public:
explicit EngineMode(bool enable_speculative, int spec_draft_length);
explicit EngineMode(int spec_draft_length, int speculative_mode);

explicit EngineMode(const std::string& config_str);

Expand Down
50 changes: 35 additions & 15 deletions cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,32 +80,52 @@ class EngineImpl : public Engine {
<< ", is smaller than the pre-defined max single sequence length, "
<< this->max_single_sequence_length_;
this->models_.push_back(model);
this->model_workspaces_.push_back(ModelWorkspace{model->AllocEmbeddingTensor()});
this->model_workspaces_.push_back(
ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()});
}
int max_num_tokens = kv_cache_config_->max_num_sequence;
if (engine_mode_->enable_speculative) {
if (engine_mode_->speculative_mode != SpeculativeMode::kDisable) {
max_num_tokens *= engine_mode_->spec_draft_length;
}
LogitProcessor logit_processor =
this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder);
Sampler sampler = this->models_[0]->CreateSampler(
max_num_tokens, static_cast<int>(this->models_.size()), trace_recorder);
// Step 3. Initialize engine actions that represent state transitions.
if (this->engine_mode_->enable_speculative) {
if (this->engine_mode_->speculative_mode != SpeculativeMode::kDisable) {
// Speculative decoding is only possible for more than one model.
ICHECK_GT(this->models_.size(), 1U);
this->actions_ = {
EngineAction::NewRequestPrefill(this->models_, //
logit_processor, //
sampler, //
this->model_workspaces_, //
this->kv_cache_config_, //
this->engine_mode_, //
this->trace_recorder_),
EngineAction::BatchDraft(this->models_, logit_processor, sampler, this->trace_recorder_,
this->engine_mode_->spec_draft_length),
EngineAction::BatchVerify(this->models_, logit_processor, sampler, this->kv_cache_config_,
this->trace_recorder_)};
switch (this->engine_mode_->speculative_mode) {
case SpeculativeMode::kEagle:
this->actions_ = {EngineAction::EagleNewRequestPrefill(this->models_, //
logit_processor, //
sampler, //
this->model_workspaces_, //
this->kv_cache_config_, //
this->engine_mode_, //
this->trace_recorder_),
EngineAction::EagleBatchDraft(
this->models_, logit_processor, sampler, this->model_workspaces_,
this->trace_recorder_, this->engine_mode_->spec_draft_length),
EngineAction::EagleBatchVerify(
this->models_, logit_processor, sampler, this->model_workspaces_,
this->kv_cache_config_, this->trace_recorder_)};
break;
default:
this->actions_ = {
EngineAction::NewRequestPrefill(this->models_, //
logit_processor, //
sampler, //
this->model_workspaces_, //
this->kv_cache_config_, //
this->engine_mode_, //
this->trace_recorder_),
EngineAction::BatchDraft(this->models_, logit_processor, sampler,
this->trace_recorder_,
this->engine_mode_->spec_draft_length),
EngineAction::BatchVerify(this->models_, logit_processor, sampler,
this->kv_cache_config_, this->trace_recorder_)};
}
} else {
this->actions_ = {EngineAction::NewRequestPrefill(this->models_, //
logit_processor, //
Expand Down
52 changes: 52 additions & 0 deletions cpp/serve/engine_actions/action.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,23 @@ class EngineAction : public ObjectRef {
std::vector<ModelWorkspace> model_workspaces,
KVCacheConfig kv_cache_config, EngineMode engine_mode,
Optional<EventTraceRecorder> trace_recorder);
/*!
* \brief Create the action that prefills requests in the `waiting_queue`
* of the engine state. This is the prefill action the engine installs when
* its speculative mode is `SpeculativeMode::kEagle`.
* \param models The models to run prefill in.
* \param logit_processor The logit processor.
* \param sampler The sampler to sample new tokens.
* \param model_workspaces The workspace of each model.
* \param kv_cache_config The KV cache config to help decide prefill is doable.
* \param engine_mode The engine operation mode.
* \param trace_recorder The event trace recorder for requests.
* \return The created action object.
*/
static EngineAction EagleNewRequestPrefill(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler,
std::vector<ModelWorkspace> model_workspaces,
KVCacheConfig kv_cache_config, EngineMode engine_mode,
Optional<EventTraceRecorder> trace_recorder);
/*!
* \brief Create the action that runs one-step decode for requests in the
* `running_queue` of engine state. Preempt low-priority requests
Expand Down Expand Up @@ -97,6 +114,23 @@ class EngineAction : public ObjectRef {
Sampler sampler, Optional<EventTraceRecorder> trace_recorder,
int draft_length = 4);

/*!
* \brief Create the action that runs one-step speculative draft proposal for
* requests in the `running_queue` of engine state. Preempt low-priority requests
* accordingly when it is impossible to decode all the running requests.
* This is the draft action the engine installs when its speculative mode is
* `SpeculativeMode::kEagle`.
* \param models The model to run decode in. When there are multiple
* models, the `Step` function of the created action will not take effect.
* \param logit_processor The logit processor.
* \param sampler The sampler to sample new tokens.
* \param model_workspaces The workspace of each model.
* \param trace_recorder The event trace recorder for requests.
* \param draft_length The number of draft proposal rounds (defaults to 4).
* \return The created action object.
* \note NOTE(review): the engine only builds this action after checking that
* more than one model is present, so the "multiple models ... will not take
* effect" sentence on `models` looks inconsistent with how the action is
* used - confirm and reword.
*/
static EngineAction EagleBatchDraft(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
Optional<EventTraceRecorder> trace_recorder,
int draft_length = 4);

/*!
* \brief Create the action that runs one-step speculative verification for requests in the
* `running_queue` of engine state. Preempt low-priority requests
Expand All @@ -112,6 +146,24 @@ class EngineAction : public ObjectRef {
Sampler sampler, KVCacheConfig kv_cache_config,
Optional<EventTraceRecorder> trace_recorder);

/*!
* \brief Create the action that runs one-step speculative verification for requests in the
* `running_queue` of engine state. Preempt low-priority requests
* accordingly when it is impossible to decode all the running requests.
* This is the verify action the engine installs when its speculative mode is
* `SpeculativeMode::kEagle`.
* \param models The model to run decode in. When there are multiple
* models, the `Step` function of the created action will not take effect.
* \param logit_processor The logit processor.
* \param sampler The sampler to sample new tokens.
* \param model_workspaces The workspace of each model.
* \param kv_cache_config The KV cache config to help decide verify is doable.
* \param trace_recorder The event trace recorder for requests.
* \return The created action object.
* \note NOTE(review): the engine only builds this action after checking that
* more than one model is present, so the "multiple models ... will not take
* effect" sentence on `models` looks inconsistent with how the action is
* used - confirm and reword.
*/
static EngineAction EagleBatchVerify(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler,
std::vector<ModelWorkspace> model_workspaces,
KVCacheConfig kv_cache_config,
Optional<EventTraceRecorder> trace_recorder);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineAction, ObjectRef, EngineActionObj);
};

Expand Down
Loading