Skip to content

[Eagle] Make eagle disco compatible #2197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,18 @@ class EngineImpl : public Engine {
ICHECK_GT(this->models_.size(), 1U);
switch (engine_config->speculative_mode) {
case SpeculativeMode::kEagle:
this->actions_ = {
EngineAction::EagleNewRequestPrefill(this->models_, //
logit_processor, //
sampler, //
this->model_workspaces_, //
engine_config, //
this->trace_recorder_),
EngineAction::EagleBatchDraft(this->models_, logit_processor, sampler,
this->model_workspaces_, this->trace_recorder_),
EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler,
this->model_workspaces_, engine_config,
this->trace_recorder_)};
this->actions_ = {EngineAction::EagleNewRequestPrefill(this->models_, //
logit_processor, //
sampler, //
this->model_workspaces_, //
engine_config, //
this->trace_recorder_),
EngineAction::EagleBatchDraft(
this->models_, logit_processor, sampler, this->model_workspaces_,
this->trace_recorder_, engine_config->spec_draft_length),
EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler,
this->model_workspaces_, engine_config,
this->trace_recorder_)};
break;
default:
this->actions_ = {EngineAction::NewRequestPrefill(this->models_, //
Expand Down
4 changes: 2 additions & 2 deletions cpp/serve/function_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ void FunctionTable::_InitFunctions() {
this->verify_to_last_hidden_func_ = mod_get_func("batch_verify_to_last_hidden_states");
this->fuse_embed_hidden_func_ = mod_get_func("fuse_embed_hidden_states");
Module mod = this->use_disco ? this->disco_mod->DebugGetFromRemote(0) : this->local_vm;
this->get_logits_func_ = mod->GetFunction("get_logits", true);
this->batch_get_logits_func_ = mod->GetFunction("batch_get_logits", true);
this->get_logits_func_ = mod_get_func("get_logits");
this->batch_get_logits_func_ = mod_get_func("batch_get_logits");
this->batch_select_last_hidden_func_ = mod->GetFunction("batch_select_last_hidden_states", true);
this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true);
this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true);
Expand Down
42 changes: 33 additions & 9 deletions cpp/serve/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,16 +136,23 @@ class ModelImpl : public ModelObj {
ICHECK_EQ(hidden_states->device.device_type, device_.device_type);
ICHECK_EQ(hidden_states->device.device_id, device_.device_id);

hidden_states_dref_or_nd =
hidden_states =
hidden_states.CreateView({batch_size * seq_len, hidden_size_}, hidden_states->dtype);

// This copy can be avoided by not copying the hidden states to engine.
hidden_states_dref_or_nd = ft_.CopyToWorker0(
hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_});
ObjectRef ret = ft_.get_logits_func_(hidden_states_dref_or_nd, params_);
if (trace_enabled_) {
TVMSynchronize(device_.device_type, device_.device_id, nullptr);
}

NDArray logits;
logits = Downcast<NDArray>(ret);
NDArray logits{nullptr};
if (ret->IsInstance<DRefObj>()) {
logits = Downcast<DRef>(ret)->DebugGetFromRemote(0);
} else {
logits = Downcast<NDArray>(ret);
}
CHECK(logits.defined());
// logits: (b * s, v)
ICHECK_EQ(logits->ndim, 2);
Expand Down Expand Up @@ -185,8 +192,11 @@ class ModelImpl : public ModelObj {
ICHECK_EQ(hidden_states->device.device_type, device_.device_type);
ICHECK_EQ(hidden_states->device.device_id, device_.device_id);

hidden_states_dref_or_nd =
hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype);
hidden_states = hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype);

// This copy can be avoided by not copying the hidden states to engine.
hidden_states_dref_or_nd = ft_.CopyToWorker0(
hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_});

ObjectRef ret =
ft_.batch_get_logits_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_);
Expand Down Expand Up @@ -218,8 +228,15 @@ class ModelImpl : public ModelObj {
p_logit_pos[i] = total_length - 1;
}
NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32));

// This step runs on the engine thread.
// By temporarily turning off the disco flag, this copies the logit_pos_nd to the cached device
// tensor without actually copying to the worker.
bool use_disco = ft_.use_disco;
ft_.use_disco = false;
ObjectRef logit_pos_dref_or_nd =
ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_});
ft_.use_disco = use_disco;

CHECK(ft_.batch_select_last_hidden_func_.defined())
<< "`batch_select_last_hidden_states` function is not found in the model.";
Expand All @@ -240,7 +257,7 @@ class ModelImpl : public ModelObj {
hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype);

ObjectRef ret =
ft_.batch_select_last_hidden_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_);
ft_.batch_select_last_hidden_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd);
if (trace_enabled_) {
TVMSynchronize(device_.device_type, device_.device_id, nullptr);
}
Expand All @@ -265,10 +282,17 @@ class ModelImpl : public ModelObj {
// No ICHECK_EQ(hidden->shape[0], hidden_size_) here to allow different hidden_sizes.
hidden = hidden.CreateView({1, hidden_size_}, hidden->dtype);
// Reuse the copy embedding function
ft_.nd_copy_embedding_to_offset_func_(hidden, *dst, cum_length);
ObjectRef hidden_dref_or_nd =
ft_.CopyToWorker0(hidden, "hidden_for_concat", {1, hidden_size_});
ft_.nd_copy_embedding_to_offset_func_(hidden_dref_or_nd, *dst, cum_length);
cum_length += 1;
}
NDArray ret = Downcast<NDArray>(*dst);
NDArray ret{nullptr};
if ((*dst)->IsInstance<DRefObj>()) {
ret = Downcast<DRef>(*dst)->DebugGetFromRemote(0);
} else {
ret = Downcast<NDArray>(*dst);
}
ret = ret.CreateView({cum_length, hidden_size_}, hidden_states[0]->dtype);
return ret;
}
Expand All @@ -295,7 +319,7 @@ class ModelImpl : public ModelObj {
return embeddings_nd.CreateView({batch_size, seq_len, hidden_size_}, embeddings_nd->dtype);
}
} else {
ShapeTuple embedding_shape{batch_size, seq_len, hidden_size_};
ShapeTuple embedding_shape{batch_size * seq_len, hidden_size_};
embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape);

if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) {
Expand Down
4 changes: 1 addition & 3 deletions python/mlc_llm/model/llama/llama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,6 @@ def batch_get_logits(self, hidden_states: Tensor, logit_positions: Tensor):

def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor):
op_ext.configure()
if self.tensor_parallel_shards > 1:
logit_positions = op.ccl_broadcast_from_worker0(logit_positions)
hidden_states = op.take(hidden_states, logit_positions, axis=0)
return hidden_states

Expand Down Expand Up @@ -382,7 +380,7 @@ def get_default_spec(self):
"hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype),
"logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
"$": {
"param_mode": "packed",
"param_mode": "none",
"effect_mode": "none",
},
},
Expand Down