Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serving] Creating EngineConfig from JSON #2237

Merged
merged 1 commit into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/json_ffi/json_ffi_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
TVM_MODULE_VTABLE_END();

void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config,
Optional<PackedFunc> request_stream_callback,
Device device, Optional<PackedFunc> request_stream_callback,
Optional<EventTraceRecorder> trace_recorder) {
std::optional<Conversation> conv_template =
Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_);
Expand All @@ -150,7 +150,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
};

request_stream_callback = PackedFunc(frequest_stream_callback_wrapper);
this->engine_->InitBackgroundEngine(std::move(request_stream_callback),
this->engine_->InitBackgroundEngine(device, std::move(request_stream_callback),
std::move(trace_recorder));
this->engine_->Reload(std::move(engine_config));
}
Expand Down
61 changes: 53 additions & 8 deletions cpp/serve/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,16 @@ String GenerationConfigNode::AsJSONString() const {
TVM_REGISTER_OBJECT_TYPE(EngineConfigNode);

EngineConfig::EngineConfig(String model, String model_lib_path, Array<String> additional_models,
Array<String> additional_model_lib_paths, DLDevice device,
int kv_cache_page_size, int max_num_sequence,
int max_total_sequence_length, int max_single_sequence_length,
int prefill_chunk_size, int max_history_size, KVStateKind kv_state_kind,
Array<String> additional_model_lib_paths, int kv_cache_page_size,
int max_num_sequence, int max_total_sequence_length,
int max_single_sequence_length, int prefill_chunk_size,
int max_history_size, KVStateKind kv_state_kind,
SpeculativeMode speculative_mode, int spec_draft_length) {
ObjectPtr<EngineConfigNode> n = make_object<EngineConfigNode>();
n->model = std::move(model);
n->model_lib_path = std::move(model_lib_path);
n->additional_models = std::move(additional_models);
n->additional_model_lib_paths = std::move(additional_model_lib_paths);
n->device = device;
n->kv_cache_page_size = kv_cache_page_size;
n->max_num_sequence = max_num_sequence;
n->max_total_sequence_length = max_total_sequence_length;
Expand All @@ -267,14 +266,60 @@ EngineConfig::EngineConfig(String model, String model_lib_path, Array<String> ad
data_ = std::move(n);
}

EngineConfig EngineConfig::FromJSONString(const std::string& json_str) {
  // Parse the raw JSON text; abort immediately on malformed input.
  picojson::value parsed;
  std::string parse_err = picojson::parse(parsed, json_str);
  if (!parse_err.empty()) {
    LOG(FATAL) << parse_err;
  }
  picojson::object config = parsed.get<picojson::object>();

  // Required scalar fields of the engine configuration.
  String model = json::Lookup<std::string>(config, "model");
  String model_lib_path = json::Lookup<std::string>(config, "model_lib_path");
  int kv_cache_page_size = json::Lookup<int64_t>(config, "kv_cache_page_size");
  int max_num_sequence = json::Lookup<int64_t>(config, "max_num_sequence");
  int max_total_sequence_length = json::Lookup<int64_t>(config, "max_total_sequence_length");
  int max_single_sequence_length = json::Lookup<int64_t>(config, "max_single_sequence_length");
  int prefill_chunk_size = json::Lookup<int64_t>(config, "prefill_chunk_size");
  int max_history_size = json::Lookup<int64_t>(config, "max_history_size");
  // Enum-valued fields are serialized as integers in the JSON.
  KVStateKind kv_state_kind =
      static_cast<KVStateKind>(json::Lookup<int64_t>(config, "kv_state_kind"));
  SpeculativeMode speculative_mode =
      static_cast<SpeculativeMode>(json::Lookup<int64_t>(config, "speculative_mode"));
  int spec_draft_length = json::Lookup<int64_t>(config, "spec_draft_length");

  // The additional models and their libraries come as two parallel arrays;
  // every model entry must have a matching library path entry.
  picojson::array models_arr = json::Lookup<picojson::array>(config, "additional_models");
  picojson::array lib_paths_arr =
      json::Lookup<picojson::array>(config, "additional_model_lib_paths");
  CHECK_EQ(models_arr.size(), lib_paths_arr.size())
      << "The number of additional model lib paths does not match the number of additional models";
  int num_additional = models_arr.size();
  std::vector<String> additional_models;
  std::vector<String> additional_model_lib_paths;
  additional_models.reserve(num_additional);
  additional_model_lib_paths.reserve(num_additional);
  for (int i = 0; i < num_additional; ++i) {
    additional_models.push_back(json::Lookup<std::string>(models_arr, i));
    additional_model_lib_paths.push_back(json::Lookup<std::string>(lib_paths_arr, i));
  }

  return EngineConfig(std::move(model), std::move(model_lib_path), additional_models,
                      additional_model_lib_paths, kv_cache_page_size, max_num_sequence,
                      max_total_sequence_length, max_single_sequence_length, prefill_chunk_size,
                      max_history_size, kv_state_kind, speculative_mode, spec_draft_length);
}

TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig")
.set_body_typed([](String model, String model_lib_path, Array<String> additional_models,
Array<String> additional_model_lib_paths, DLDevice device,
int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length,
Array<String> additional_model_lib_paths, int kv_cache_page_size,
int max_num_sequence, int max_total_sequence_length,
int max_single_sequence_length, int prefill_chunk_size, int max_history_size,
int kv_state_kind, int speculative_mode, int spec_draft_length) {
return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models),
std::move(additional_model_lib_paths), device, kv_cache_page_size,
std::move(additional_model_lib_paths), kv_cache_page_size,
max_num_sequence, max_total_sequence_length, max_single_sequence_length,
prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind),
SpeculativeMode(speculative_mode), spec_draft_length);
Expand Down
12 changes: 5 additions & 7 deletions cpp/serve/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,6 @@ class EngineConfigNode : public Object {
/*! \brief The path to the additional models' libraries. */
Array<String> additional_model_lib_paths;

/*************** Device ***************/

/*! \brief The device where the models run. */
DLDevice device;

/*************** KV cache config and engine capacities ***************/

/*! \brief The number of consecutive tokens handled in each page in paged KV cache. */
Expand Down Expand Up @@ -152,12 +147,15 @@ class EngineConfigNode : public Object {
class EngineConfig : public ObjectRef {
public:
explicit EngineConfig(String model, String model_lib_path, Array<String> additional_models,
Array<String> additional_model_lib_paths, DLDevice device,
int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length,
Array<String> additional_model_lib_paths, int kv_cache_page_size,
int max_num_sequence, int max_total_sequence_length,
int max_single_sequence_length, int prefill_chunk_size,
int max_history_size, KVStateKind kv_state_kind,
SpeculativeMode speculative_mode, int spec_draft_length);

/*! \brief Create EngineConfig from JSON string. */
static EngineConfig FromJSONString(const std::string& json_str);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode);
};

Expand Down
20 changes: 11 additions & 9 deletions cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class EngineImpl : public Engine {
public:
/********************** Engine Management **********************/

explicit EngineImpl(EngineConfig engine_config, Optional<PackedFunc> request_stream_callback,
explicit EngineImpl(EngineConfig engine_config, DLDevice device,
Optional<PackedFunc> request_stream_callback,
Optional<EventTraceRecorder> trace_recorder) {
// Step 1. Initialize metadata and singleton states inside the engine
this->estate_->Reset();
Expand All @@ -62,9 +63,9 @@ class EngineImpl : public Engine {
this->models_.clear();
this->model_workspaces_.clear();

auto f_create_model = [this, &engine_config, &trace_recorder](const String& model_path,
const String& model_lib_path) {
Model model = Model::Create(model_lib_path, std::move(model_path), engine_config->device,
auto f_create_model = [this, &engine_config, &device, &trace_recorder](
const String& model_path, const String& model_lib_path) {
Model model = Model::Create(model_lib_path, std::move(model_path), device,
engine_config->max_num_sequence,
/*trace_enabled=*/trace_recorder.defined());
model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence,
Expand Down Expand Up @@ -339,10 +340,11 @@ class EngineImpl : public Engine {
Optional<EventTraceRecorder> trace_recorder_;
};

std::unique_ptr<Engine> Engine::Create(EngineConfig engine_config,
std::unique_ptr<Engine> Engine::Create(EngineConfig engine_config, Device device,
Optional<PackedFunc> request_stream_callback,
Optional<EventTraceRecorder> trace_recorder) {
return std::make_unique<EngineImpl>(std::move(engine_config), std::move(request_stream_callback),
return std::make_unique<EngineImpl>(std::move(engine_config), device,
std::move(request_stream_callback),
std::move(trace_recorder));
}

Expand All @@ -368,10 +370,10 @@ class EngineModule : public ModuleNode {
TVM_MODULE_VTABLE_END();

/*! \brief Initialize the engine with config and other fields. */
void Init(EngineConfig engine_config, Optional<PackedFunc> request_stream_callback,
void Init(EngineConfig engine_config, Device device, Optional<PackedFunc> request_stream_callback,
Optional<EventTraceRecorder> trace_recorder) {
this->engine_ = Engine::Create(std::move(engine_config), std::move(request_stream_callback),
std::move(trace_recorder));
this->engine_ = Engine::Create(std::move(engine_config), device,
std::move(request_stream_callback), std::move(trace_recorder));
}
/*! \brief Construct an EngineModule. */
static tvm::runtime::Module Create() { return Module(make_object<EngineModule>()); }
Expand Down
3 changes: 2 additions & 1 deletion cpp/serve/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ class Engine {
/*!
* \brief Create an engine in unique pointer.
* \param engine_config The engine config.
 * \param device The device where the models run.
* \param request_stream_callback The request stream callback function to.
* \param trace_recorder Event trace recorder for requests.
* \return The created Engine in pointer.
*/
static std::unique_ptr<Engine> Create(EngineConfig engine_config,
static std::unique_ptr<Engine> Create(EngineConfig engine_config, Device device,
Optional<PackedFunc> request_stream_callback,
Optional<EventTraceRecorder> trace_recorder);

Expand Down
7 changes: 5 additions & 2 deletions cpp/serve/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ enum class InstructionKind : int {
/*! \brief The implementation of ThreadedEngine. */
class ThreadedEngineImpl : public ThreadedEngine {
public:
void InitBackgroundEngine(Optional<PackedFunc> request_stream_callback,
void InitBackgroundEngine(Device device, Optional<PackedFunc> request_stream_callback,
Optional<EventTraceRecorder> trace_recorder) final {
device_ = device;
CHECK(request_stream_callback.defined())
<< "ThreadedEngine requires request stream callback function, but it is not given.";
request_stream_callback_ = request_stream_callback.value();
Expand Down Expand Up @@ -231,7 +232,7 @@ class ThreadedEngineImpl : public ThreadedEngine {
};

Optional<PackedFunc> request_stream_callback = PackedFunc(frequest_stream_callback_wrapper);
background_engine_ = Engine::Create(std::move(engine_config),
background_engine_ = Engine::Create(std::move(engine_config), device_,
std::move(request_stream_callback), trace_recorder_);
}

Expand All @@ -247,6 +248,8 @@ class ThreadedEngineImpl : public ThreadedEngine {
}
}

/*! \brief The device to run models on. */
Device device_;
/*! \brief The background normal engine for request processing. */
std::unique_ptr<Engine> background_engine_;
/*! \brief The request stream callback. */
Expand Down
3 changes: 2 additions & 1 deletion cpp/serve/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ class ThreadedEngine {

/*!
* \brief Initialize the threaded engine from packed arguments in TVMArgs.
 * \param device The device on which to run the models.
* \param request_stream_callback The request stream callback function to.
* \param trace_recorder Event trace recorder for requests.
*/
virtual void InitBackgroundEngine(Optional<PackedFunc> request_stream_callback,
virtual void InitBackgroundEngine(Device device, Optional<PackedFunc> request_stream_callback,
Optional<EventTraceRecorder> trace_recorder) = 0;

/*!
Expand Down
5 changes: 0 additions & 5 deletions python/mlc_llm/serve/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,6 @@ class EngineConfig(tvm.runtime.Object):
additional_model_lib_paths : List[str]
The path to the additional models' libraries.

device : tvm.runtime.Device
The device where the models run.

kv_cache_page_size : int
The number of consecutive tokens handled in each page in paged KV cache.

Expand Down Expand Up @@ -203,7 +200,6 @@ def __init__( # pylint: disable=too-many-arguments
model_lib_path: str,
additional_models: List[str],
additional_model_lib_paths: List[str],
device: tvm.runtime.Device,
kv_cache_page_size: int,
max_num_sequence: int,
max_total_sequence_length: int,
Expand All @@ -220,7 +216,6 @@ def __init__( # pylint: disable=too-many-arguments
model_lib_path,
additional_models,
additional_model_lib_paths,
device,
kv_cache_page_size,
max_num_sequence,
max_total_sequence_length,
Expand Down
2 changes: 1 addition & 1 deletion python/mlc_llm/serve/engine_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
}
self.tokenizer = Tokenizer(model_args[0][0])
self._ffi["init_background_engine"](
device,
self.state.get_request_stream_callback(kind),
self.state.trace_recorder,
)
Expand All @@ -1079,7 +1080,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
model_lib_path=model_args[0][1],
additional_models=[model_arg[0] for model_arg in model_args[1:]],
additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]],
device=device,
kv_cache_page_size=16,
max_num_sequence=max_batch_size,
max_total_sequence_length=max_total_sequence_length,
Expand Down
2 changes: 1 addition & 1 deletion python/mlc_llm/serve/sync_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
model_lib_path=model_args[0][1],
additional_models=[model_arg[0] for model_arg in model_args[1:]],
additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]],
device=device,
kv_cache_page_size=16,
max_num_sequence=max_batch_size,
max_total_sequence_length=max_total_sequence_length,
Expand All @@ -177,6 +176,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
speculative_mode=speculative_mode,
spec_draft_length=spec_draft_length,
),
device,
request_stream_callback,
self.trace_recorder,
)
Expand Down
6 changes: 0 additions & 6 deletions tests/python/json_ffi/_ffi_api.py

This file was deleted.

44 changes: 18 additions & 26 deletions tests/python/json_ffi/test_json_ffi_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union

import tvm
from tests.python.json_ffi import _ffi_api

from mlc_llm.protocol import openai_api_protocol
from mlc_llm.serve import engine_utils
Expand Down Expand Up @@ -61,30 +60,23 @@
]


@tvm._ffi.register_object(
"mlc.json_ffi.ModelDefinedGenerationConfig"
) # pylint: disable=protected-access
class ModelDefinedGenerationConfig(tvm.runtime.Object):
def __init__( # pylint: disable=too-many-arguments
self, temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ModelDefinedGenerationConfig,
temperature,
top_p,
frequency_penalty,
presence_penalty,
)
def create_model_defined_generation_config(
temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float
) -> tvm.runtime.Object:
return tvm.get_global_func("mlc.json_ffi.ModelDefinedGenerationConfig")(
temperature,
top_p,
frequency_penalty,
presence_penalty,
)


@tvm._ffi.register_object("mlc.json_ffi.JSONFFIEngineConfig") # pylint: disable=protected-access
class JSONFFIEngineConfig(tvm.runtime.Object):
def __init__( # pylint: disable=too-many-arguments
self, conv_template: str, model_generation_cfgs: Dict[str, ModelDefinedGenerationConfig]
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.JSONFFIEngineConfig, conv_template, model_generation_cfgs
)
def create_json_ffi_engine_config(
conv_template: str, model_generation_cfgs: Dict[str, tvm.runtime.Object]
) -> tvm.runtime.Object:
return tvm.get_global_func("mlc.json_ffi.JSONFFIEngineConfig")(
conv_template, model_generation_cfgs
)


class EngineState:
Expand Down Expand Up @@ -187,7 +179,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
model_lib_path=model_args[0][1],
additional_models=[model_arg[0] for model_arg in model_args[1:]],
additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]],
device=device,
kv_cache_page_size=16,
max_num_sequence=max_batch_size,
max_total_sequence_length=max_total_sequence_length,
Expand All @@ -199,10 +190,10 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
spec_draft_length=spec_draft_length,
)

self.json_ffi_engine_config = JSONFFIEngineConfig(
self.json_ffi_engine_config = create_json_ffi_engine_config(
conv_template=self.conv_template.model_dump_json(),
model_generation_cfgs={
model.model: ModelDefinedGenerationConfig(
model.model: create_model_defined_generation_config(
temperature=model_config["temperature"],
top_p=model_config["top_p"],
frequency_penalty=model_config["frequency_penalty"],
Expand All @@ -215,6 +206,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
self._ffi["init_background_engine"](
self.json_ffi_engine_config,
self.engine_config,
device,
self.state.get_request_stream_callback(),
None,
)
Expand Down
Loading