Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Sampler] Enable GPU sampler for draft verification #2198

Merged
merged 7 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cpp/serve/engine_actions/eagle_batch_verify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
generation_cfg.push_back(rsentries[i]->request->generation_cfg);
rngs.push_back(&rsentries[i]->rng);
draft_output_tokens.push_back(draft_mstate->draft_output_tokens);
CHECK(draft_mstate->draft_output_prob_dist[0]->device.device_type == kDLCPU);
draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist);
}

Expand Down
1 change: 1 addition & 0 deletions cpp/serve/function_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ void FunctionTable::_InitFunctions() {
gpu_argsort_probs_func_ = mod->GetFunction("argsort_probs", true);
gpu_sample_with_top_p_func_ = mod->GetFunction("sample_with_top_p", true);
gpu_sampler_take_probs_func_ = mod->GetFunction("sampler_take_probs", true);
gpu_verify_draft_tokens_func_ = mod->GetFunction("sampler_verify_draft_tokens", true);
}
this->nd_view_func_ = get_global_func("vm.builtin.reshape");
this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of");
Expand Down
1 change: 1 addition & 0 deletions cpp/serve/function_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ struct FunctionTable {
PackedFunc gpu_argsort_probs_func_;
PackedFunc gpu_sample_with_top_p_func_;
PackedFunc gpu_sampler_take_probs_func_;
PackedFunc gpu_verify_draft_tokens_func_;
PackedFunc nd_view_func_;
PackedFunc nd_get_shape_func_;
PackedFunc nd_copy_embedding_to_offset_func_;
Expand Down
4 changes: 1 addition & 3 deletions cpp/serve/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -768,9 +768,7 @@ class ModelImpl : public ModelObj {

Sampler CreateSampler(int max_num_sample, int num_models,
Optional<EventTraceRecorder> trace_recorder) {
if (num_models > 1) { // speculative decoding uses cpu sampler
return Sampler::CreateCPUSampler(std::move(trace_recorder));
} else if (Sampler::SupportGPUSampler(device_)) {
if (Sampler::SupportGPUSampler(device_)) {
return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_,
std::move(trace_recorder));
} else {
Expand Down
140 changes: 138 additions & 2 deletions cpp/serve/sampler/gpu_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class GPUSampler : public SamplerObj {
gpu_argsort_probs_func_(ft->gpu_argsort_probs_func_),
gpu_sample_with_top_p_func_(ft->gpu_sample_with_top_p_func_),
gpu_sampler_take_probs_func_(ft->gpu_sampler_take_probs_func_),
gpu_verify_draft_tokens_func_(ft->gpu_verify_draft_tokens_func_),
trace_recorder_(std::move(trace_recorder)) {
ICHECK(gpu_multinomial_from_uniform_func_.defined());
ICHECK(gpu_argsort_probs_func_.defined());
Expand Down Expand Up @@ -92,11 +93,20 @@ class GPUSampler : public SamplerObj {
NVTXScopedRange nvtx_scope("BatchSampleTokens");
// probs_on_device: (n, v)
RECORD_EVENT(trace_recorder_, request_ids, "start sampling");
CHECK(output_prob_dist == nullptr) << "GPU sampler does not support collecting output probs.";
CHECK_EQ(probs_on_device->ndim, 2);
int num_samples = sample_indices.size();
int num_probs = probs_on_device->shape[0];
int vocab_size = probs_on_device->shape[1];
if (output_prob_dist != nullptr) {
ICHECK(output_prob_dist->empty());
output_prob_dist->reserve(num_probs);
for (int i = 0; i < num_probs; ++i) {
NDArray prob_dist = NDArray::Empty({vocab_size}, dtype_f32_, device_);
float* p_prob = static_cast<float*>(probs_on_device->data) + i * vocab_size;
prob_dist.CopyFromBytes(p_prob, vocab_size * sizeof(float));
output_prob_dist->push_back(std::move(prob_dist));
}
}
ICHECK_EQ(request_ids.size(), num_samples);
ICHECK_EQ(generation_cfg.size(), num_samples);
ICHECK_EQ(rngs.size(), num_samples);
Expand Down Expand Up @@ -132,7 +142,132 @@ class GPUSampler : public SamplerObj {
const std::vector<RandomGenerator*>& rngs,
const std::vector<std::vector<SampleResult>>& draft_output_tokens,
const std::vector<std::vector<NDArray>>& draft_output_prob_dist) final {
LOG(FATAL) << "GPU sampler does not support batch verification for now.";
std::vector<std::vector<SampleResult>> sample_results;
// probs_on_device: (n, v)
RECORD_EVENT(trace_recorder_, request_ids, "start draft verification");
CHECK_EQ(probs_on_device->ndim, 2);

int num_sequence = static_cast<int>(cum_verify_lengths.size()) - 1;
CHECK_EQ(rngs.size(), num_sequence);
CHECK_EQ(draft_output_tokens.size(), num_sequence);
CHECK_EQ(draft_output_prob_dist.size(), num_sequence);
sample_results.resize(num_sequence);

int num_nodes = cum_verify_lengths.back();
NDArray uniform_samples_host = uniform_samples_host_.CreateView({num_nodes}, dtype_f32_);
NDArray uniform_samples_device = uniform_samples_device_.CreateView({num_nodes}, dtype_f32_);
NDArray draft_probs_device = NDArray::Empty({num_nodes, vocab_size_}, dtype_f32_, device_);
NDArray draft_tokens_device = NDArray::Empty({num_nodes}, dtype_i32_, device_);
NDArray draft_tokens_host =
NDArray::Empty({num_nodes}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0});

// Concat draft prob distributions to a ragged tensor (num_nodes, vocab_size)
for (int i = 0; i < num_sequence; i++) {
const std::vector<SampleResult>& draft_output_tokens_i = draft_output_tokens[i];
const std::vector<NDArray>& draft_output_prob_dist_i = draft_output_prob_dist[i];
int start = cum_verify_lengths[i];
int end = cum_verify_lengths[i + 1];
// start/end is the range of the sequence i in probs_on_device, which includes the prob dist
// of the draft tokens and the last committed token
ICHECK_EQ(draft_output_tokens_i.size() + 1, end - start);
ICHECK_EQ(draft_output_prob_dist_i.size() + 1, end - start);
for (int j = 0; j < end - start - 1; j++) {
// Copy prob dist
ICHECK_EQ(draft_probs_device->dtype.bits, 32);
float* p_draft_probs =
static_cast<float*>(draft_probs_device->data) +
(j + start + 1) *
vocab_size_; // shift by one, q of the last committed token is undefined
// Copy sampled token id
draft_output_prob_dist_i[j].CopyToBytes(p_draft_probs, vocab_size_ * sizeof(float));
*(static_cast<int*>(draft_tokens_host->data) + j + start + 1) =
draft_output_tokens_i[j].sampled_token_id.first;
}
}
CopyArray(draft_tokens_host, draft_tokens_device, copy_stream_);

float* p_uniform_samples = static_cast<float*>(uniform_samples_host->data);
for (int i = 0; i < num_sequence; ++i) {
int start = cum_verify_lengths[i];
int end = cum_verify_lengths[i + 1];
for (int j = start; j < end; j++) {
p_uniform_samples[j] = rngs[i]->GetRandomNumber();
}
}
CopyArray(uniform_samples_host, uniform_samples_device, copy_stream_);

// This should be refactored to use the cached tensors
NDArray token_tree_first_child_device = NDArray::Empty({num_nodes}, dtype_i32_, device_);
NDArray token_tree_next_sibling_device = NDArray::Empty({num_nodes}, dtype_i32_, device_);
NDArray token_tree_parent_ptr_device = NDArray::Empty({num_sequence}, dtype_i32_, device_);
NDArray token_tree_first_child_host =
NDArray::Empty({num_nodes}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0});
NDArray token_tree_next_sibling_host =
NDArray::Empty({num_nodes}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0});
NDArray token_tree_parent_ptr_host =
NDArray::Empty({num_sequence}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0});
NDArray token_tree_child_to_parent_host =
NDArray::Empty({num_nodes}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0});

// Build the tree structure on CPU
for (int i = 0; i < num_sequence; i++) {
// Assuming no tree structure for now
int start = cum_verify_lengths[i];
int end = cum_verify_lengths[i + 1];
ICHECK_EQ(end - start, 2); // one committed token and assuming only one draft token
static_cast<int*>(token_tree_child_to_parent_host->data)[start] = -1; // root has no parent
for (int j = 0; j < end - start; j++) {
int cur_node = j + start;
int child_node = j + 1 >= end - start ? -1 : cur_node + 1;
static_cast<int*>(token_tree_first_child_host->data)[cur_node] = child_node;
if (child_node != -1) {
static_cast<int*>(token_tree_child_to_parent_host->data)[child_node] = cur_node;
}
static_cast<int*>(token_tree_next_sibling_host->data)[cur_node] = -1;
}
static_cast<int*>(token_tree_parent_ptr_host->data)[i] = start; // point to the root
}
// Copy token tree structure to GPU
CopyArray(token_tree_first_child_host, token_tree_first_child_device, copy_stream_);
CopyArray(token_tree_next_sibling_host, token_tree_next_sibling_device, copy_stream_);
CopyArray(token_tree_parent_ptr_host, token_tree_parent_ptr_device, copy_stream_);

SyncCopyStream(device_, compute_stream_, copy_stream_);

gpu_verify_draft_tokens_func_(draft_probs_device, draft_tokens_device, probs_on_device,
token_tree_first_child_device, token_tree_next_sibling_device,
uniform_samples_device, token_tree_parent_ptr_device);

CopyArray(token_tree_parent_ptr_device, token_tree_parent_ptr_host, compute_stream_);
TVMSynchronize(device_.device_type, device_.device_id, compute_stream_);

std::vector<int> sample_indices;

for (int i = 0; i < num_sequence; i++) {
int start = cum_verify_lengths[i];
int end = cum_verify_lengths[i + 1];
int last_accepted = static_cast<int*>(token_tree_parent_ptr_host->data)[i];
int num_accepted = 0;
for (int cur_node = last_accepted; cur_node != start;
cur_node = static_cast<int*>(token_tree_child_to_parent_host->data)[cur_node]) {
sample_results[i].push_back(draft_output_tokens[i][cur_node - start - 1]);
num_accepted++;
}
std::reverse(sample_results[i].rbegin(), sample_results[i].rbegin() + num_accepted);
sample_indices.push_back(last_accepted);
}
std::vector<SampleResult> additional_sample_result;
// This only works for top-p = 1. To enable top-p, we need to normalize the probs before
// verifying.
additional_sample_result = this->BatchSampleTokens(probs_on_device, sample_indices, request_ids,
generation_cfg, rngs, nullptr);
ICHECK_EQ(additional_sample_result.size(), num_sequence);
for (int i = 0; i < num_sequence; i++) {
sample_results[i].push_back(additional_sample_result[i]);
}

RECORD_EVENT(trace_recorder_, request_ids, "finish draft verification");
return sample_results;
}

private:
Expand Down Expand Up @@ -370,6 +505,7 @@ class GPUSampler : public SamplerObj {
PackedFunc gpu_argsort_probs_func_;
PackedFunc gpu_sample_with_top_p_func_;
PackedFunc gpu_sampler_take_probs_func_;
PackedFunc gpu_verify_draft_tokens_func_;
// Auxiliary NDArrays on CPU
NDArray uniform_samples_host_;
NDArray sample_indices_host_;
Expand Down
50 changes: 50 additions & 0 deletions python/mlc_llm/compiler_pass/attach_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from tvm.relax.frontend import nn
from tvm.script import tir as T

from ..op.batch_spec_verify import batch_spec_verify


@tvm.transform.module_pass(opt_level=0, name="AttachGPUSamplingFunc")
class AttachGPUSamplingFunc: # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -46,6 +48,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
_attach_argsort_func(bb, vocab_size),
_attach_sample_with_top_p(bb, vocab_size),
_attach_take_probs_func(bb, vocab_size),
_attach_batch_verifier(bb, vocab_size),
]
]

Expand Down Expand Up @@ -289,3 +292,50 @@ def sampler_take_probs_tir( # pylint: disable=too-many-locals,too-many-argument
bb.emit_output(taken_probs_indices)
gv = bb.emit_func_output(taken_probs_indices)
return gv


def _attach_batch_verifier(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr):
    """Attach the "sampler_verify_draft_tokens" Relax function to the module.

    The attached function wraps the ``batch_spec_verify`` TIR kernel in an
    in-place ``call_tir`` so that draft-token verification runs on GPU.
    ``model_probs`` and ``token_tree_parent_ptr`` are updated in place by the
    kernel and returned as the outputs.

    Parameters
    ----------
    bb : relax.BlockBuilder
        The block builder used to emit the function.
    vocab_size : tir.PrimExpr
        The vocabulary size, baked into the generated kernel.
    """
    # Symbolic dimensions: total number of token-tree nodes and batch size.
    num_nodes = tir.Var("num_nodes", "int64")
    nbatch = tir.Var("nbatch", "int64")

    def _f32_var(name, shape):
        # Helper: a float32 tensor parameter with the given symbolic shape.
        return relax.Var(name, relax.TensorStructInfo(shape, "float32"))

    def _i32_var(name, shape):
        # Helper: an int32 tensor parameter with the given symbolic shape.
        return relax.Var(name, relax.TensorStructInfo(shape, "int32"))

    draft_probs = _f32_var("draft_probs", (num_nodes, vocab_size))
    draft_tokens = _i32_var("draft_tokens", (num_nodes,))
    model_probs = _f32_var("model_probs", (num_nodes, vocab_size))
    token_tree_first_child = _i32_var("token_tree_first_child", (num_nodes,))
    token_tree_next_sibling = _i32_var("token_tree_next_sibling", (num_nodes,))
    uniform_samples = _f32_var("uniform_samples", (num_nodes,))
    token_tree_parent_ptr = _i32_var("token_tree_parent_ptr", (nbatch,))

    # Argument order must match the parameter order of the TIR kernel.
    args = [
        draft_probs,
        draft_tokens,
        model_probs,
        token_tree_first_child,
        token_tree_next_sibling,
        uniform_samples,
        token_tree_parent_ptr,
    ]
    with bb.function("sampler_verify_draft_tokens", args):
        with bb.dataflow():
            kernel = bb.add_func(
                batch_spec_verify(vocab_size), "batch_verify_on_gpu_single_kernel"
            )
            # The kernel mutates model_probs and token_tree_parent_ptr in place,
            # hence the in-place call with those two arguments as outputs.
            result = bb.emit(
                relax.call_tir_inplace(
                    kernel,
                    args,
                    inplace_indices=[args.index(model_probs), args.index(token_tree_parent_ptr)],
                    out_sinfo=[
                        model_probs.struct_info,  # pylint: disable=no-member
                        token_tree_parent_ptr.struct_info,  # pylint: disable=no-member
                    ],
                )
            )
            bb.emit_output(result)
        gv = bb.emit_func_output(result)
    return gv
Loading