From d448fdb782cca9e2dfe6b4f416d156c1579ecc2d Mon Sep 17 00:00:00 2001
From: Sunghyun Park
Date: Mon, 22 Apr 2024 08:19:47 -0600
Subject: [PATCH] Merge with `mlc-ai/main` (`835223541d4135e511a50cba1deca06731b03abd`, April 18th 2024) (#260)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [Attn] Making decode attn kernel be aware of webgpu target (#1817)
This PR enables the decode attn kernel to be aware of the webgpu backend, so that it helps make sure the total number of threads does not exceed the 256-thread limit of WebGPU.
Co-authored-by: Bohan Hou

* [Serving][Refactor] Logit processor and logit bias support (#1828)
This PR refactors the existing logit processing pipeline with a unified logit processor class. The logit processor class exposes two functions:
- `InplaceUpdateLogits`, which takes in the raw logits produced by the model and applies logit bias (introduced in this PR), presence/frequency/repetition penalties, and the token id mask, in that order, when needed.
- `ComputeProbsFromLogits`, which takes in the updated logits and invokes softmax with temperature to compute the probability distribution.
The logit processor runs entirely on GPU: all of the logit bias / penalty / mask application and the softmax are backed by GPU kernels. This is a key difference from the logit processing prior to this PR, where the processing happened on CPU, and softmax also happened on CPU whenever any logit processing was needed.
With the unified logit processor, we simplified the interface for handling the model's output logits in engine actions to make it cleaner. We also simplified the interface of Sampler.
Preliminary results show that the LogitProcessor brings a slight performance improvement when any processing is needed.

* [Serving][Grammar] BNF grammar simplifier and matcher (#1801)

* [Serving] LogProbs support (#1832)
This PR introduces logprobs support with OpenAI API compatibility. It enhances the sampler with a function to get the top-probability tokens (supporting at most 5 tokens as of now).
To make it easy to pass logprob results back from the serving engine to the frontend, we pass logprob results as JSON strings following the OpenAI API spec.
Unit tests are added to ensure the correctness of logprobs. The logprobs support also works with speculative decoding.

* [Serving] Support Mixtral in MLC Serve (#1840)
This PR supports Mixtral in MLC serve. The main change is introducing the Mistral conversation template to the Python registry so that MLC Serve can use it.
Besides that, this PR updates the KV cache capacity analysis to make the usage calculation more accurate, while staying conservative since there is a known issue regarding batch-prefill embedding lookup which may lead to OOM. We will follow up on the issue with a fix in the future and then enable the estimation to use more GPU VRAM.

* [Fix] Fix `u_char` for Windows build (#1848)
Prior to this PR, `u_char` was used even though it is not a standard type in C++, which caused Windows build failures. This PR fixes it by using `unsigned char`.

* Auto updated submodule references

* [Fix] Add phi lm head name to is_final_fc, add q4f16_ft to CI (#1849)
[Fix] Add phi lm head name to is_final_fc

* [Build] Replace mod_transform_before_build with IRModule pass (#1852)
Instead of a Python function that returns an updated `IRModule`, the new `optimize_mod_pipeline` function returns a `tvm.ir.transform.Pass` which can be applied to an `IRModule`.
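As a rough illustration of this pass-based flow (not the actual MLC pipeline; the pass body and names below are placeholders), a module-level transformation can be wrapped as a `tvm.ir.transform.Pass` and composed with other passes:

```python
# Minimal sketch only: a no-op module pass standing in for the real
# optimization pipeline, showing how a Pass composes via Sequential.
import tvm
from tvm.ir import IRModule


@tvm.transform.module_pass(opt_level=0, name="ExamplePass")
def example_pass(mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule:
    # The real pipeline would rewrite `mod`; this placeholder returns it unchanged.
    return mod


def build_pipeline() -> tvm.ir.transform.Pass:
    # Several passes can be bundled into a single Pass and applied to an IRModule.
    return tvm.ir.transform.Sequential([example_pass])
```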
* [SLM] Add support for InternLM architecture (#1835)
  * Create __init__.py
  * Add files via upload
  * Update model.py
  * Update model_preset.py
  * Update conv_templates.cc
  * Update internlm_loader.py
  * Update internlm_quantization.py
  * fix name of notes
  * Update model.py
  * Migration
  * fix pylint issue
  * fix pylint issue
  * fix pylint error
  * Update internlm_loader.py
  * Update __init__.py
  * Update __init__.py
  * Delete python/mlc_chat/model/internlm/__init__.py
  * Add files via upload

* [Bugfix] Handle model names with multiple path components (#1851)
Prior to this commit, a model name with multiple path components (e.g. `dist/models/group_name/model_name`) would have duplicated path components (e.g. `dist/group_name/artifact_path/group_name/libname.so`). This commit resolves the duplication.

* [KVCache] Add max num threads awareness to KVCache kernels (#1822)
  * [KVCache] Add max num threads to KVCache kernels, fix WebGPU
  * Read max_num_threads_per_block when available
  * Change merge state in place kernel
  * Make attention decode aware of max num threads, not just webgpu
    Co-authored-by: Egor Churaev
  * Change util function name
  ---------
  Co-authored-by: Egor Churaev

* [KVCache] Migrate Baichuan model to PagedKVCache (#1854)

* [Python] Lazy import of transformers for tiktoken conversion (#1860)
This PR moves the import of transformers into the function body of tiktoken tokenizer conversion, so we do not have a forced dependency on transformers.

* [SLM] RWKV5 World Support (#1787)
This PR adds RWKV5 support with RNNState, a similar interface to PagedAttention.
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>

* [Serving] Register the ChatML conversation template (#1862)
Following #1854, this PR registers the ChatML conversation template.

* [Utils][Transform] Added SetEntryFuncs transform (#1855)
Sets the entry functions for a module. This utility is intended for cases where a module contains several externally exposed functions and only one is desired for use (e.g. separating out a `transform_params` function from an `IRModule` that also contains inference functions). This commit only updates the external visibility, after which `relax.transform.DeadCodeElimination()` can be applied.

* [Build] Update transform_params_for_each_rank to IRModule pass (#1856)
This allows it to be used as part of an optimization pipeline specified as a `tvm.ir.transform.Sequential`.

* [Serving][Grammar] Integrate JSON grammar into the generation pipeline (#1867)
This PR is the 3rd part of the grammar-guided generation. It integrates the grammar framework into the generation process and supports JSON output for now.
The API this PR provides is compatible with the OpenAI API.

### APIs

#### Python API
```
@dataclass
class ResponseFormat:
    type: Literal["text", "json_object"] = "text"
    json_schema: Optional[str] = None

@dataclass
class GenerationConfig:
    response_format: ResponseFormat = ResponseFormat(type="text")
```

#### Rest API
```
response_format: { "type": "text" }                           # text generation, by default
response_format: { "type": "json_object" }                    # json generation
response_format: { "type": "json_object", json_schema="..."}  # json generation with schema
```

JSON generation with schema is not supported yet, but is planned for the future.
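For reference, a request using the `response_format` field above could look like the following against the OpenAI-compatible REST endpoint (the model path and server address are placeholders, mirroring the other examples in this log):

```python
# Illustrative only: ask the server for JSON-constrained output.
import requests

payload = {
    "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/",
    "messages": [{"role": "user", "content": "Describe a cat as a JSON object."}],
    "response_format": {"type": "json_object"},
    "stream": False,
}
r = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
print(r.json()["choices"][0]["message"]["content"])
```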
### Performance

#### Without JSON
```
Single token prefill latency: 891.2234 ms/tok
Single token decode latency: 31.3399 ms/tok
Prefill token throughput: 4693.3077 tok/s
Decode token throughput: 226.4406 tok/s
Overall token throughput: 470.3180 tok/s
```

#### With JSON
```
Single token prefill latency: 219.2287 ms/tok
Single token decode latency: 29.1399 ms/tok
Prefill token throughput: 7392.1555 tok/s
Decode token throughput: 179.2296 tok/s
Overall token throughput: 1052.1996 tok/s
```

We observed a slight decrease in performance under JSON mode. This will be further optimized in the future.

* [Serving] Support "n" for parallel generation (#1868)
This PR brings the field `n` to the generation config and thereby supports parallel generation. The parallel generation effectively leverages the "fork" functionality of the paged KV cache.
This PR supports specifying the number of parallel generations `n` in the standard OpenAI ChatCompletion API. This is the last feature needed for OpenAI API feature completeness.

* [CI] Add retry to scm checkout (#1869)
Sometimes the scm checkout can time out; this PR adds a retry to it.

* [Attn] Use float32 accumulation in attention kernel (#1870)
Prior to this PR, the TIR attention kernels did not cast matmul operands to fp32 before multiplying. For models like Phi-2, which may have large Q/K/V values (at the level of a few hundred), the fp16 multiplication exceeds the fp16 range and sometimes leads to NaN attention results. This PR fixes the issue (a short numeric sketch illustrating this follows below).

* [Utils] Allow ReorderTransformFunc to be used without param manager (#1857)
Prior to this commit, the `ReorderTransformFunc` required several components of the `ParamManager` in order to be used. The functionality it provides, reordering dataflow blocks to minimize the liveset, is useful outside the context of the `ParamManager`. This commit makes the following changes, allowing it to be used independently of the `ParamManager`:
- Generate the `pidx2binname` dictionary outside of `ReorderTransformFunc`
- Allow parameters to be separate `func.params`, rather than a single bundled tuple parameter.

* [SLM] Migrate Phi-2 to paged KV Cache #1871 (#1872)
This PR migrates Phi-2 to paged KV cache attention, as part of the model definition migration tracked in #1749.
Co-authored-by: Shrey Gupta

* [Fix] Fix the use of "call_inplace_packed" and "call_pure_packed" (#1874)
The use of `call_inplace_packed` and `call_pure_packed` in the old flow is outdated due to signature changes. This PR fixes the issue.

* [Fix] Add the missing BundleModelParams pass (#1875)
PR #1852 missed applying the BundleModelParams pass and thus made the compiled models not runnable through ChatModule (#1864). This PR fixes the issue.

* [Docs] Update Android APK download link (#1876)
As pointed out by #1830, this PR fixes the Android app download link in the docs.

* Fix MLC-LLM website link weight convert not accessible (#1877)
Fix website link not accessible

* [Serving][Grammar] Support termination state in GrammarStateMatcher (#1884)

* [Serving] Make RequestState as a standalone object class (#1878)
This PR adopts suggestions from the support of OpenAI API parallel generation `n` in #1868. The main update in this PR is to make the RequestState a standalone object class, which was a typedef of `std::vector` before. This PR also fixes a bug in prefill that would cause engine failure when `n` is large.
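To make the fp32-accumulation change above concrete, here is a small self-contained numeric sketch (not MLC code) showing how fp16 products overflow for values of a few hundred while fp32 accumulation stays finite:

```python
import numpy as np

q = np.full((128,), 300.0, dtype=np.float16)
k = np.full((128,), 300.0, dtype=np.float16)

# 300 * 300 = 90000 exceeds the fp16 maximum (~65504), so the product is inf.
print(q[0] * k[0])

# Casting to fp32 before multiply-accumulate keeps the result finite.
print(np.dot(q.astype(np.float32), k.astype(np.float32)))  # 11520000.0
```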
* [SLM] Update StableLM model and migrate it to paged KV Cache (#1882)

* [KVCache] Qwen 1.0 Model PagedKV Support (#1887)
Support Qwen1.0 Paged KV Cache

* [Serving] Estimate KV cache memory usage with metadata (#1888)
Prior to this PR, the serving engine memory usage estimation reads the model config for fields such as `num_key_value_heads`, `num_hidden_layers`, etc. However, since not every model shares the same set of config names (#1854), the estimation fails for models that do not have these config field names.
This PR makes the following changes. First, it attaches these field values to the model's metadata, which effectively unifies the field names across different models. Then, when estimating memory usage, we read these fields from the metadata rather than the model config, so we are safe from the naming inconsistency.

* [KVCache] Migrate bigcode arch to PagedKVCache (#1891)
Compilation and runtime are smooth. I will open follow-up PRs to enable starcoder2 support in the same model definition file.

* [Serving] Add Phi-2 conv template to mlc serve (#1890)
This PR adds the phi-2 model template to MLC serve.
For testing:
1. Start the server
```
python -m mlc_chat.serve.server --model ./dist/phi-2-q4f16_1-MLC/ --model-lib-path ./dist/phi-2-q4f16_1-MLC/phi-2-q4f16_1-cuda.so --device auto --max-batch-size 2 --enable-tracing --host 127.0.0.1 --port 8000 --max-total-seq-length 8000
```
2. Send a request
```
python test_server_rest_api.py
```
```python
# test_server_rest_api.py
import requests
import json

model = "./dist/phi-2-q4f16_1-MLC/"
port = 8000
payload = {
    "model": f"{model}",
    "messages": [{"role": "user", "content": "Tell me about Machine Learning in 200 words."}],
    "stream": False,
}
r = requests.post(f"http://127.0.0.1:{port}/v1/chat/completions", json=payload)
if r.status_code != 200:
    print(r.json())
else:
    print(r.json()["choices"][0]["message"]["content"])
```

* [Attn] Fix attention kernel for head dim not divisible by 32 (#1889)
Prior to this PR, our TIR prefill attention kernel assumed the head dim to be a multiple of 32. As reported by #1826, this assumption does not always hold. This PR fixes the issue so that models with other head dims can also compile.

* [Python] Enable "thrust" for CUDA by default (#1866)
This PR enables thrust for CUDA targets so that we can dispatch some operators (e.g., cumsum) to thrust.

* [Serving] Fix loading presharded weights (#1894)

* [Serving] Address embedding lookup OOM issue (#1899)
This PR addresses the OOM issue that may be caused by embedding lookup when the batch size of a prefill action is large. Prior to this PR, a large embedding tensor was created for each sequence in the prefilled batch, which may take unexpectedly large memory when the batch size is large.

* [Model] Remove redundant `batch_forward` and move broadcast (#1900)
This PR contains four changes:
1. It removes the duplicate `batch_forward` defined in model definitions. This function was widely used prior to our migration to PagedKVCache, since before the migration the attention code paths of single-sequence forward and batch forward differed. Since the migration, the code paths are unified into one, and therefore we can safely remove most `batch_forward` functions.
2. It moves `op.ccl_broadcast_from_worker0` from the model's main forward (which is called at the beginning of prefill/decode) to embedding. This change has two benefits.
Firstly, the token ids taken by `embed` were not broadcast across workers, so it was possible for workers other than 0 to have illegal token ids outside the range of the vocab size; moving the broadcast to `embed` addresses this issue. Secondly, broadcasting token ids in `embed` is more lightweight than broadcasting embeddings in `prefill`/`decode`, since the tensor size of token ids is much smaller.
3. It adds `max_batch_size` to the config class of models, so that they are potentially compatible with batching and MLC serve.
4. It removes the `k_cache` and `v_cache` effects from the models that have switched to PagedKVCache support.
We randomly picked a few models (below) to run the engine test, and all of them passed:
* phi-2 with tp=2,
* RedPajama with tp=2,
* stablelm with tp=1 (since stablelm does not support TP right now).

* [KVCache] Migrate Qwen2 model to PagedKVCache (#1903)

* [CI] Skip not supported quantization in model compilation test (#1904)
This PR updates the model compilation test so that it will now skip a quantization when the model does not support it.

* [Serving] Add missing header for `std::iota` (#1905)
The header `<numeric>` was missing, which may have caused build failures on Windows. This PR adds the header.

* [Serving] Fix Model TokenEmbed function with TP (#1906)
This PR fixes a severe bug introduced by #1899. Since #1899, we no longer copy the embedding back from worker 0 when using tensor parallelism. However, we did not synchronize with worker 0.
This causes the following issue: in batch prefill, we call TokenEmbed multiple times in a row. Each time, we copy the token ids to the `token_ids` NDArray on worker 0. If we do not synchronize with worker 0, it is possible that the local token ids have been updated multiple times before the first `CopyToWorker0` actually starts to execute on the worker 0 side. As a result, at the time of executing the token id copy to worker 0, the local token ids might be wrong (by "wrong", say we are executing the copy of seq 0's token ids, then the actual local token id array might already hold seq 3's token ids). Consequently, the issue causes batch prefill to behave completely incorrectly.
This PR adds an explicit synchronization with worker 0.

* [SLM] Add support for Orion architecture. (#1883)
This is a PR for supporting [OrionStarAI/Orion-14B-Chat](https://huggingface.co/OrionStarAI/Orion-14B-Chat).

* [Model] Eliminate the reshape in embedding func (#1908)
Prior to this PR, there was a trailing reshape kernel at the end of the embedding func. The reshape does not need to be a kernel, and as a kernel it consumed extra time during execution. This PR eliminates the reshape in the embedding function by updating the signature of the embedding func, so that it now only takes plain 1D token ids as input.

* [Pass] Low batch GEMM using GEMV-like schedule (#1769)
When the batch size is small, the GEMM in the MLP of the decode stage can be dispatched to a specialized GEMV-like schedule to improve efficiency. A GEMM with a dynamic var in its spatial axis will now be lowered into
```python
if dyn_var <= 8:
    low_batch_gemv()
else:
    normal_gemm()
```

* Auto updated submodule references

* [Serving] Avoid unnecessary worker sync in Model (#1909)
Following up on #1906, this PR removes the synchronization given it is avoidable. We use another approach to avoid the write-after-write issue.
The key to addressing the issue is to make sure the addresses to be copied to worker 0 are not rewritten before the copy actually happens. So we pre-allocate a large host array to hold all the token ids, and for each sequence, we copy its token ids to the offset given when calling TokenEmbed, so that we can make sure an address is not written twice before the copy happens.

* [Serving][Grammar] Enhance GrammarStateMatcher to support general grammar (#1917)

* [Android] Improve perf of TIR PagedAttn kernel on Android (#1915)
  * android perf
  * Update kv_cache.py

* Deprecate old flow (#1928)
  * Deprecate old flow
    This PR deprecates the old flow. As of today most of the efforts are centralized around the new flow with SLM compilation. Additionally, we are bringing model definitions through the unified KV interface so we can have a single model definition across all backends, both server and local settings.
    We kept the old flow around for a while, but it is a good time to make the transition. All the documents are updated to point to the new flow. We also created a backup branch https://github.com/mlc-ai/mlc-llm/tree/backup-before-old-flow-deprecation for people who would like to check out some of the old flow references.
  * Remove deprecated prebuilts

* [Serving] Register the StableLM3B conversation template (#1920)
Update conversation_template.py

* Remove deprecated build.py

* [Fix] KVCache creation with call_pure_packed (#1930)
With https://github.com/apache/tvm/pull/16684 merged in, the KV cache creation fails when compiling models. This PR fixes the problem by using `call_pure_packed`.

* [KVCache] Update FlashInfer PackedFunc names (#1931)
This PR updates the FlashInfer names given that https://github.com/apache/tvm/pull/16692 has been merged.

* [REFACTOR] remove tests/legacy-python (#1933)
This PR removes the folder tests/legacy-python as a follow-up cleanup step of the old flow. Some of the files, like the lib comparison utilities, are useful and we should recover them later in the mlc_llm.testing.DebugChat flow.

* [REFACTOR] rename mlc_chat => mlc_llm (#1932)
This PR renames the mlc_chat package to the mlc_llm package now that this is the new official flow. We also update the necessary locations that might touch the package.

* Auto updated submodule references

* [Docs] Deprecating CUDA 11.7/11.8 support (#1939)
We have deprecated the wheel support for CUDA 11.7/11.8 due to TVM thrust compatibility issues with old CUDA versions.

* [Fix] Fix KV cache call in mistral (#1938)
The latest TVM introduces a well-formedness check of the IR. The Mistral model definition broke well-formedness due to purity. This PR fixes the issue.

* [ChatModule] Remove eos_token_ids (#1940)
This PR removes the eos_token_ids from the ChatModule given it is not actually used anywhere.

* [SLM] Weight conversion with generator (#1916)
This PR enhances weight conversion so that it passes a generator to `tvmjs.dump_ndarray_cache`. This effectively reduces the CPU memory pressure when converting weights, especially when the total converted weight size is close to or larger than the CPU memory size.

* [Serve] Introducing GPU sampler for CUDA (#1934)
This PR introduces the GPU sampler for CUDA only. The GPU sampler makes use of the GPU sampling ops introduced in apache/tvm#16575. We will follow up to benchmark the performance of the GPU sampler against the CPU sampler.

* [Serve] Constrain KV cache capacity on Metal (#1943)
This PR constrains the KV cache capacity for Metal devices to 32768, in order to avoid large tensors in the KV cache.
This is because right now the Metal runtime has a performance issue when running a kernel that takes a very large input buffer, even if only a small part of the large buffer is accessed in the kernel.

* [CI] Add windows ci (#1942)
This PR adds Windows CI.

* Auto updated submodule references

* [Fix] Fix embedding shape check in ChatModule (#1953)
This PR is a fix to address #1952.

* [Fix] Fetching the Git-LFS tokenizer files (#1954)
Prior to this PR, when running commands like
```shell
python3 -m mlc_chat chat HF://mlc-ai/gemma-7b-it-q4f16_2-MLC
```
only the binary weight files were downloaded among all the Git LFS files. For models like Gemma, whose tokenizer is large and also stored as a Git LFS file, the tokenizer files were not automatically downloaded. For example, the cloned Gemma `tokenizer.json` file had the content
```
version https://git-lfs.github.com/spec/v1
oid sha256:05e97791a5e007260de1db7e1692e53150e08cea481e2bf25435553380c147ee
size 17477929
```
and this pointer content was never realized into the actual tokenizer. This led to issue #1913. This PR fixes the issue by pulling all the Git LFS files that are not binary files.

* [LogitProcessor] Add max thread awareness to logit processing kernels (#1955)
Make the kernels in `AttachLogitProcessFunc` aware of the maximum number of threads, fixing https://github.com/mlc-ai/mlc-llm/issues/1951. Most of the code change is due to indentation; the main change is replacing `1024` with `tx`, where `tx` is
```
tx = 1024  # default
max_num_threads_per_block = get_max_num_threads_per_block(target)
if max_num_threads_per_block < tx:
    tx = max_num_threads_per_block
check_thread_limits(target, bdx=tx, bdy=1, bdz=1, gdz=1)
```

* [Model] Use static hidden size in mixtral scatter_output (#1959)

* Auto updated submodule references

* [CompilerFlag] Detect if FlashInfer is enabled from libinfo (#1941)
This PR detects whether FlashInfer was enabled when building TVM, so that FlashInfer won't be enabled in MLC when TVM is not built with FlashInfer.

* [Serving][Grammar] Add grammar termination as a stop condition (#1964)

* Unify schema for conversation template and embed into mlc-chat-config.json (#1965)

* [SLM] Small correction on Stablelm and Qwen2. (#1958)
  * small fix
  * small fix
  * Update stablelm_model.py

* [Serving][Fix] Fix JSON output check in test_server.py (#1966)
`test_server::is_json_or_json_prefix` is used to check whether the output is JSON or a prefix of JSON. It uses json.loads internally. However, json.loads (i.e. json.decode) is token-based instead of character-based: if half a token is left at the end of the string, it cannot be matched. This PR adds another check for the remaining "half token" if it exists.

* [Model] Migrate Mistral to use PagedKVCache (#1967)
This PR migrates the Mistral model to the PagedKVCache interface, which supports sliding window attention with a paged attention kernel written in TensorIR. We thereby introduce a `support_sliding_window` mode for the KV cache, which leaves room for supporting sliding windows for any model at runtime.
This PR tests Mistral with both chat and serve. The chat performance of Mistral 7B improves over before, benefiting from the paged attention implementation.
* Auto updated submodule references

* [REST] Update Rest API docs for the latest serve flow (#1972)
  * [Docs][Upd] Server launch, examples for endpoints for MLC Serve
  * remove v1/completions
  * add api docs to rest
  ---------
  Co-authored-by: Shrey Gupta

* [Conv] Add bos_token to llama and mistral in ConvTemplateRegistry (#1970)
Since we don't have the `add_bos` field in the new Conversation template, we should add the bos token to the system_prefix_token_ids, so that it will be added to the tokenized prompt.

* [Model][Serve] Add support for LLaVa model in serving engine (#1974)
This PR adds support for the LLaVa-v1.5 model in the serving engine. Use the HF weights and config from https://huggingface.co/llava-hf/llava-1.5-7b-hf. Passing image input is supported as a URL (reference: https://platform.openai.com/docs/guides/vision).
Example:
```python
data = {
    "model": "dist/llava-1.5-7b-hf-q4f16_1-MLC/params/",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": "https://llava-vl.github.io/static/images/view.jpg",
                },
                {"type": "text", "text": "What does this image represent?"},
            ],
        }
    ],
}
response = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=data)
print("Response body:", response.text)
```

* [Serve] Hot fix for the mixtral serving (#1975)
[Fix] hotfix for the mixtral serving
Co-authored-by: Yong Wu

* [REST] REST API Deprecated (#1973)
Deleted the old REST API:
- Removed rest.py
- Removed old interface/openai_api.py
- Updated ChatModule to use the new OpenAI API protocol
Co-authored-by: Kartik Khandelwal

* [Fix] Fix handling of non-numerical cuda arch (#1976)
On the latest GPUs, the CUDA arch may not be an integer, e.g. `sm_90a`. This fixes a few places that rely on integer parsing.

* [Serving][Grammar] Support specifying the main rule in grammar (#1982)
finish

* [Fix] Fix `MLC_MULTI_ARCH` with arch `sm_90a` (#1984)
This PR fixes the missing patch for targets with the `sm_90a` arch, as a follow-up PR of #1976.

* Fix Llama-2 and Mistral conversation template. Update ConvTemplateRegistry (#1981)
The current prompt format for Llama-2 and Mistral is not completely correct. This PR updates the code to strictly follow the official prompt format for the two models. It also adds missing conv templates to ConvTemplateRegistry.

* [SpecDecode] Fix sampler selection. (#1971)
This PR temporarily fixes the sampler selection logic for speculative decoding. As GPU sampler support for speculative decoding is not ready, speculative decoding will use the CPU sampler.

* [Serving][Grammar] Utility to convert json schema to EBNF grammar (#1983)
This PR adds a generic utility to convert a JSON schema, especially one generated from pydantic, to EBNF grammar. This helps grammar-guided generation when we provide a JSON schema as the restriction. The converter supports the standard JSON indent style in the output grammar. (A usage sketch follows later in this section.)
API:
```
def json_schema_to_ebnf(
    json_schema: str,
    *,
    indent: Optional[int] = None,
    separators: Optional[Tuple[str, str]] = None,
    strict_mode: bool = True,
) -> str:
    """Convert JSON schema string to EBNF grammar string.

    Parameters
    ----------
    json_schema : str
        The JSON schema string.

    indent : Optional[int]
        The number of spaces for each indent. If it is None, there will be no indent
        or newline. The indent and separators parameters follow the same convention
        as `json.dumps()`.

    separators : Optional[Tuple[str, str]]
        The separator between different elements in json. Examples include "," and ", ".

    strict_mode : bool
        Whether to use strict mode.
        In strict mode, the generated grammar will not allow unevaluatedProperties and
        unevaluatedItems, i.e. these will be set to false by default. This helps the LLM
        generate accurate output in grammar-guided generation with a JSON schema.
    """
    pass
```

* Auto updated submodule references

* [Fix] Fix serve model to adapt the latest Allocator signature (#1989)
PR apache/tvm#16738 updated the Allocator signature. This PR updates the caller side accordingly.

* [Model] Use optimized group gemm for Mixtral (#1988)

* [Attn] Fix the construction of attn result merge kernel (#1995)
This PR fixes the mistake of passing the wrong number of heads to the attention result merge kernel.

* [iOS][Android] Add validation of library file for iOS and Android build (#1993)
This PR adds validation of symbols in the iOS and Android builds. During a static library build, we need the right model_lib to point to the packaged model executables. Not doing so correctly results in a `vm_load_executable not found` error, which is not informative. In this PR we validate the compiled model lib by dumping the global symbols and ensuring the list of model libs matches. In the future we should perhaps lift the validation into the mlc_llm package.

* Auto updated submodule references

* [Serve] add allocator in Storage as the upstream change (#1997)
The changes in https://github.com/apache/tvm/pull/16750 modified the signature of Storage; this pull request updates the caller code in mlc-llm to accommodate the new Storage class signature. We ran into a build error without the change.

* [Compiler] Support IPC memory and customized all-reduce kernels (#1990)
This PR introduces the IPC memory and customized all-reduce kernel dispatches for tensor parallelism. We add a new compiler flag `--allreduce-strategy`, which supports `"ring"`, `"one-shot"` and `"two-shot"`. The flag defaults to `"ring"`, which means this PR makes no difference if people do not manually change the all-reduce strategy. As of now the IPC-memory-backed customized all-reduce kernels are only available on CUDA.
To enable all-reduce strategies other than "ring", here are some example compile commands:
```shell
python -m mlc_llm compile model/mlc-chat-config.json --device cuda --opt "allreduce-strategy=one-shot" -o model/lib.so
python -m mlc_llm compile model/mlc-chat-config.json --device cuda --opt "allreduce-strategy=two-shot" -o model/lib.so
```
Please be aware that you probably also need to specify other compiler flags as well, for example `--opt "cublas_gemm=1;allreduce-strategy=one-shot"`.

* Auto updated submodule references

* [Model] Fix the top-k TIR script for well-formedness (#2002)
This PR fixes the malformed MoE TIR scripts.

* Fix invalid use of dataflow var in sampler output (#2003)

* [Fix] Fix KV cache creation pass after nn.Module changes (#2011)
This PR corrects the assertion after the latest changes in apache/tvm that update some nn.Module behavior.

* [iOS] Fix typo in prepare_model_lib.py (#2013)
Fix a typo in prepare_model_lib.py: tar_list.append(valid_paths[ls0]) was introduced by mistake in https://github.com/mlc-ai/mlc-llm/pull/1993.

* Remove unstable assertion in KV cache creation dispatch (#2017)
This particular assertion has been unstable recently given the back-and-forth upstream TVM nn.Module exporter behavior.
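A possible way to drive the `json_schema_to_ebnf` utility from #1983 above with a pydantic (v2) model is sketched below; the converter's import path is not spelled out in this log, so it is only referenced through a comment:

```python
import json

from pydantic import BaseModel

# from mlc_llm.serve import json_schema_to_ebnf  # assumed import location


class Person(BaseModel):
    name: str
    age: int


schema_str = json.dumps(Person.model_json_schema())
# ebnf_grammar = json_schema_to_ebnf(schema_str, indent=2, strict_mode=True)
```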
* Auto updated submodule references

* [SLM] Qwen2 Multi-GPU support (#1985)
  * Update qwen2_model.py
  * fix lint issue
  * fix lint issue
  * fix lint issue

* more info for preshard (#2027)
  * When the pre-sharded version of a certain model is not available, the program will default back to the normal workflow without issuing any alert. Now, when someone attempts to convert to a pre-sharded model but cannot, the program will throw a warning message to inform users that it will revert to the standard model conversion process.
  * format fix.
  * black reformatted, i did not see any diff.
  * black reformatted..

* Register stablelm-2 conversation template (#2029)

* [Serving][Fix] Fix problems in PopenServer (#2032)
This PR fixes several problems in the PopenServer:
- Add a check for when the server has not started and the request returns a failure code, e.g. 502. Also change the retry time to 0.1s.
- Add `__enter__` and `__exit__` methods to PopenServer. When the program is interrupted, using a with clause (`__enter__` and `__exit__`) ensures the server always terminates. When using `start()` and `terminate()`, the server may still stay in the background even though the parent process ends.

* [Quantization] Skip MoE gate layer (#2012)
This PR skips quantizing the MoE gate layer.

* [Serving][Grammar] Integration of JSON schema generation (#2030)
The previous PR #1983 introduced a transformation from JSON schema to BNF grammar. This PR further integrates the grammar from the JSON schema into the generation pipeline, so that the engine now supports JSON schema output. GrammarStateInitContexts are stored in a cache, so they will not be created again for the same schema.
Interface:
- Python
```
@dataclass
class ResponseFormat:
    type: Literal["text", "json_object"] = "text"
    schema: Optional[str] = None
```
- Rest API
```
class RequestResponseFormat(BaseModel):
    type: Literal["text", "json_object"] = "text"
    json_schema: Optional[str] = Field(default=None, alias="schema")

class CompletionRequest(BaseModel):
    ...
    response_format: RequestResponseFormat = Field(default_factory=RequestResponseFormat)

class ChatCompletionRequest(BaseModel):
    ...
    response_format: RequestResponseFormat = Field(default_factory=RequestResponseFormat)
```
Performance: we only test single-batch performance now to show the overhead in latency.
- Model: `Llama-2-7b-chat-hf-q4f16_1`
- GPU: `NVIDIA GeForce RTX 3080`
- CPU: `AMD Ryzen 9 5900X 12-Core Processor`
```
JSON ON
Batch=1
Average prefill tokens: 651.0000 tok/req
Average decode tokens: 499.0000 tok/req
Single token prefill latency: 0.3140 ms/tok
Single token decode latency: 8.6831 ms/tok
Prefill token throughput: 3184.8002 tok/s
Decode token throughput: 116.6039 tok/s

JSON OFF
Batch=1
Average prefill tokens: 651.0000 tok/req
Average decode tokens: 499.0000 tok/req
Single token prefill latency: 0.3098 ms/tok
Single token decode latency: 8.6823 ms/tok
Prefill token throughput: 3227.8141 tok/s
Decode token throughput: 116.9251 tok/s
```
This PR also makes these bug fixes / changes:
- Changed the structure of the grammar converted from a schema to avoid a large number of uncertain tokens, which caused a performance degradation.

* [Compiler] Support AUTO mode for all-reduce strategy (#2034)
This PR supports the auto mode for the IPC all-reduce strategy. It renames the strategy from `allreduce-strategy` to `ipc-allreduce-strategy` in the compiler optimization flags. The default RING mode is renamed to NONE mode, which, when specified, uses nccl all-reduce without any IPC memory rewrite.
So right now the ideal way to enable IPC all-reduce is to pass `ipc-allreduce-strategy=auto`.

* [LLaVa] Follow-up for TODOs in LLaVa model (#2010)
LLaVa:
1. Added base64 image support (a sketch follows at the end of this section).
2. Merged as_prompt and as_prompt_list.
3. get_image_from_url uses config

* [Pipeline] Defer GPU IPC memory lowering (#2038)
This PR moves the position of the GPU IPC memory lowering pass in the pipeline, so that it applies after the CUDA graph rewrite, enabling CUDA graph with the customized all-reduce kernels.

* [Model] Add missing broadcast of logit_position for multigpu (#2040)
This commit adds the broadcasting of `logit_pos` in batch prefill for all models to avoid the logit position out-of-bound issue.

* [Preshard] apply presharding after quantization (#2039)
This changes the behavior of presharding by applying presharding after quantization, which makes the behavior consistent with and without presharding.

* [SLM] Baichuan Multi-GPU support (#2037)
This PR enables the TP function of the Baichuan2 model.

* Auto updated submodule references

* [Model] Skip TVMSynchronize when tracing is not enabled (#2041)
This PR removes the synchronization in `Model` when Chrome tracing is not enabled. It can help some logit processing kernels launch earlier.

* [Serving] Support NVTX for benchmarking (#2043)
This PR supports MLC serve with NVTX, which helps analyze benchmarking results.
**Note.** To enable NVTX, please add `set(USE_NVTX ON)` to the file `build/config.cmake`.

* Update huggingface_loader.py

* [Serve] Separate callback invocation to another thread in AsyncEngine (#2046)
This PR enhances the AsyncThreadedEngine by separating the callback invocation into another thread, in order to reduce the CPU time overhead of invoking the Python callback.

* [LLaVa] Fix random token output after first sentence (#2048)
Fix LLaVa emitting random tokens after the first '.' token.
Co-authored-by: Animesh Bohara

* Auto updated submodule references

* [Pass] Fix LiftGlobalBufferAlloc for proper GlobalVar struct info (#2053)
This PR fixes the GlobalVar struct info mismatch issue caused by the LiftGlobalBufferAlloc pass after a recent TVM commit.

* Auto updated submodule references

* [Serving] CLI Support for SERVE (#2014)
This PR adds CLI support for serve. Usage: `mlc_llm serve [Model]`; refer to `mlc_llm serve -h` for more options.
Comments:
- Supports JIT compilation of the model lib
- Added a context manager to the `ServerContext` class
Co-authored-by: Ruihang Lai
Co-authored-by: Shrey Gupta

* [Pipeline] Insert hints to enable cuda graph symbolic capture (#2050)
  * [Pipeline] Add pass to insert hints to enable cuda graph symbolic capture

* [Loader] Print message when multi-GPU loader is finished (#2051)
  * [Loader] Print message when multi-GPU loader is finished
  * Update multi_gpu_loader.cc
  * fix

* [KVCache] Support matching arbitrary element offset for aux data (#2057)
This PR enhances the TIR attention-related functions to support matching arbitrary element offsets. This makes room for the KV cache to allocate one large array for all the auxiliary data and slice into it. This PR should affect nothing in the current codebase, given all the element offsets are zero as of now.

* [Serving] Support copy stream in LogitProcessor and GPUSampler (#2058)
This PR introduces a copy stream to LogitProcessor and GPUSampler for CUDA, so that auxiliary data can be copied on a separate stream and overlap with the computation time.

* [SLM] Stablelm Multi-GPU support (#2052)
This PR enables the TP function of the Stablelm model.
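Building on the LLaVa URL example earlier in this log, a base64 image could plausibly be passed as a data URI in the same `image_url` slot; the exact accepted format is an assumption here, and the file name, model path, and server address are placeholders:

```python
import base64

import requests

with open("view.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "model": "dist/llava-1.5-7b-hf-q4f16_1-MLC/params/",
    "messages": [
        {
            "role": "user",
            "content": [
                # Data URI in place of an http URL (assumed to be accepted).
                {"type": "image_url", "image_url": f"data:image/jpeg;base64,{image_b64}"},
                {"type": "text", "text": "What does this image represent?"},
            ],
        }
    ],
}
response = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
print("Response body:", response.text)
```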
* [KVCache] Introducing single page copy func for KV cache fork (#2060)
This PR introduces the single page copy TIR function for the KV cache. This function is helpful for sequence fork at specified positions.
NOTE: this PR is a breaking change, so you will need to re-compile your model and update TVM or the MLC-AI pip package to the latest version.
Related PR: apache/tvm#16813
Co-authored-by: Yaxing Cai

* [Python] Implement testing.DebugChat for end-to-end model debugging (#2056)

* [Docs] Fix docs for python server and rest call (#2066)
This PR updates the MLC serve documentation for server launching.

* [CI] Enable submodule clone for WASM model compilation (#2068)
The incoming WASM runtime requires 3rdparty for builds. This PR enables the submodule clone for WASM model compilation in CI.

* [Serve] Fork sequence at specified positions (#2067)
With PagedKVCache supporting fork at a specified position, this PR updates the `Model` interface accordingly. The fork position defaults to -1, which means the last position.

* [SLM] Add support for RWKV6 model (#1977)
  * [SLM]: Support for rwkv tokenizer
  * [SLM] RWKV6 World Support

* [Quantization] Reorganize utils code in group_quantization (#2055)

* [Serving] Bugfix for empty stop string (#2070)
Add a check for empty stop strings; fix the Vanilla LM conversation template.

* [SLM] Internlm Multi-GPU support (#2072)
This PR enables tensor parallelism support for the InternLM model.

* [WebGPU] Add mlc wasm runtime, support grammar in web (#2061)
  * [WebGPU] Add mlc wasm runtime, support grammar in web
  * Make in web for wasm ci
  * Fix wasm ci
  * Fix wasm ci
  * Change export library arg name
  * Move macro to cc instead of makefile

* [Build] Use TVM_HOME environment variable (#2073)
Prior to this commit, the `CMakeLists.txt` file checked a cmake `TVM_HOME` variable, but did not check the usual `TVM_HOME` environment variable. If this variable is set, it should be used.

* [Serving] Support input chunking (#2069)
This PR supports input chunking with regard to the customized "prefill chunk size" (field `prefill_chunk_size` in `mlc-chat-config.json`). With this PR, we can now chunk a long input into multiple pieces when there is an upper limit on the prefill chunk size. Only `TokenData` is supported for now. (A toy illustration follows at the end of this section.)

* [Docs] API Code Completion Guide (#2054)

* Allow "mlc_llm --host" option to override host triple the model compi… (#2074)
Allow the "mlc_llm --host" option to override the host triple the model compiles to.

* [Web] Move prep emcc deps script to web folder (#2077)

* [SLM] Qwen Multi-GPU support (#2075)

* Fix mismatch of metadata func and global symbol (#2078)
  * Fix mismatch of metadata func and global symbol
  * Update estimate_memory_usage.py

* [Disco] Set worker CPU affinity with env variable (#2042)
This PR enables setting the CPU affinity of disco workers in MLC, following the support in apache/tvm#16807. The purpose is to try to reduce the CPU core switch overhead for disco workers, which may cause extra bubble time in disco workers before/during tasks.
We use a macro `MLC_DISCO_WORKER_CPU_BINDING` to specify the CPU affinities of workers. This is not used by default. To enable it, you can run a command like
```shell
MLC_DISCO_WORKER_CPU_BINDING=64,65,66,67 python some_mlc_app.py
```
to specify the four CPU core ids for the four workers.
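The input-chunking behavior above (#2069) amounts to splitting a long token sequence into pieces of at most `prefill_chunk_size` tokens; a toy illustration (not the engine's actual code) is:

```python
from typing import List


def chunk_tokens(token_ids: List[int], prefill_chunk_size: int) -> List[List[int]]:
    # Split token ids into consecutive chunks no longer than prefill_chunk_size.
    return [
        token_ids[i : i + prefill_chunk_size]
        for i in range(0, len(token_ids), prefill_chunk_size)
    ]


print(chunk_tokens(list(range(10)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```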
* [Quantization] Introduce PerTensor and F8 quantization (#2079)
  * [Quantization] Introduce PerTensor and F8 quantization
  * address comments

* [Serving][Refactor] Rename AsyncThreadedEngine to ThreadedEngine (#2081)
This PR renames the AsyncThreadedEngine to ThreadedEngine to prepare for follow-up refactors of the Python interface. Meanwhile, this PR exposes a creation function for AsyncThreadedEngine so that it can be further used by others, such as JSONFFIEngine.

* [Serving] Add cuda profiling in benchmark test (#2084)
  * [Serving] Add cuda profiling in benchmark test

* [Grammar] Fix broken grammar tests (#2083)
This PR fixes some grammar parser tests that were broken.

* [Serving][Fix] Fix chunked prefill condition (#2082)
This PR fixes a bug when trying to chunk an input and do prefill. The stats prior to this PR were wrong.

* [Conversation] Fix RedPajama conversation template (#2087)
As reported and discussed in #2086, this PR fixes the RedPajama template.

* [Serving][Refactor] Python interface refactor (#2085)
This PR is an initial major Python interface refactor of MLC Serve. With this PR, `mlc_llm.serve` in Python now exposes two engine classes: `AsyncEngine` and `Engine`. Both classes have two entrypoints, `chat_completion` and `completion`, which conform to the OpenAI Python API (reference: https://github.com/openai/openai-python). As the names suggest, `AsyncEngine` works asynchronously, and `Engine` works synchronously.
It is worth noting that the `Engine` since this PR is different from the `Engine` before it. The new `Engine` does not provide interfaces for batch generation. For robustness and correctness, the old `Engine` in Python is moved to `mlc_llm.serve.sync_engine.SyncEngine`. We do not directly expose this SyncEngine, and it now mainly serves testing and debugging purposes. It is useful for checking the correctness of new features because of its simplicity. It keeps the low-level interface to directly invoke the `step()` function of the engine, and also keeps the low-level batch generation interface.
Our REST API entry points defined under `mlc_llm/serve/entrypoints/` are also refactored accordingly to adapt to the latest Python API in MLC Serve. In short, most of the logic in the OpenAI API entry points is moved to the Python API, which simplifies the implementation of the entry points.
Please note that this is the first (and also the largest) planned refactor. We will follow up with some other refactors with smaller scopes. The planned refactors include:
* provide a submodule interface to align with the OpenAI Python package in https://github.com/openai/openai-python
* refactor the constructor interface of `Engine`/`AsyncEngine` to align with the MLC serve CLI interface.

* [Serving] Separating ThreadedEngine creation and initialization (#2090)
This PR separates the creation and initialization of ThreadedEngine for multi-threading use cases, so we can make sure that the ThreadedEngine instance is created before any other operations (such as initialization, running the background loop, etc.).

* [Serving] Enhance robustness with small KV capacity (#2091)
This PR enhances the robustness, which had issues when the KV capacity is small.

* [REST] Update REST API docs (#2092)
This updates the REST docs to use `mlc_llm serve` and also adds a quick start section.

* [DOCS] Clarify vulkan loader dependency (#2095)
This PR clarifies the Vulkan loader dependency. Some systems may not have the right Vulkan loader, and we need to install it via conda.
* [SLM] Add support for Chatglm3 architecture (#2096)
This PR enables the Chatglm3 model.

* [Quantization] Add OpenCL device (#2097)
This PR adds the OpenCL device for weight conversion.

* [Serving] Support stream=True for Python API (#2098)
The previous refactoring PR formalized the MLC serve Python API but did not respect the `stream` flag properly: no matter whether `stream` is True or False, the functions always worked in a streaming style. This PR supports the non-stream case.

* [Serving][Refactor] OpenAI API Python interface alignment (#2099)
This PR aligns the Python API of chat completions and completions in MLC serve with the OpenAI Python package https://github.com/openai/openai-python. Specifically, after we first create an engine or async engine, we can use the entry point `engine.chat.completions.create(...)` for chat completions. We will add more usage examples in the codebase after another few refactors. (An illustrative usage sketch follows at the end of this section.)

* [DOC] fix small python env install error (#2102)
Fixed a slight issue with the TVM install: the platform would require specifying python=3.11, otherwise a "python not found" error might be encountered.

* [JSONFFIEngine] Initial implementation of JSONFFIEngine (#2101)
This PR introduces initial support for the JSONFFIEngine. The request is supposed to be a JSON string in the [Chat completion request body format](https://platform.openai.com/docs/api-reference/chat/create). The output (the input to the provided callback function) is a list of JSON strings in the [Chat completion chunk object format](https://platform.openai.com/docs/api-reference/chat/streaming).
There is still functionality to be added, which will come in follow-up PRs:
1. Support for other input datatypes (image, etc.)
2. Applying the conversation template to the input
3. Function calling and tools support
4. Generation config parameters support
5. Independent text streamers for each request
6. logprobs support
Co-authored-by: Ruihang Lai

* [Model] Use tanh approximation of GeLU in Gemma MLP (#2106)
This is in line with the implementation in the [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L183) library, as well as the [gemma-1.1](https://huggingface.co/google/gemma-1.1-2b-it/blob/main/config.json#L10) model config.

* Auto updated submodule references

* [Quantization] Stricter checks for MoE gate (#2109)
This PR strengthens the MoE gate checks to include checking the number of experts, given that the real MoE gate router layer's output feature number is the number of experts and is usually very small. This PR comes from a regression: a layer in RWKV6 whose name ends with "gate" is not for MoE at all.

* Auto updated submodule references

* [LLaVa] Fix allowed text model value in config (#2062)
  * Llava support vicuna and mistral text models
  * Support f32 quantization
  * Lint fix
  * Use preset if transformers not installed
  * Rebase on main
  ---------
  Co-authored-by: Animesh Bohara

* Auto updated submodule references

* Revert "Allow "mlc_llm --host" option to override host triple the model compi…" (#2115)
This reverts commit 12ca8fdbe2a24f43bbc72241a76735dbad8c2026.
Co-authored-by: Mengshiun Yu

* Revert "Auto updated submodule references" (#2117)
This reverts commit c4169d8c8a4afedd06bc9d9b99c3aa65eee4a89e, which caused CI to break.

* [Metadata] Include picojson rather than forward declaring (#2118)
This PR fixes the picojson uses in MLC that conflict with the latest changes on the picojson side.
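For illustration, the OpenAI-aligned entry point described in #2099 could be used roughly as below; the constructor argument and response fields are assumptions based on the OpenAI Python package, not the exact MLC signature (the class is later renamed to `LLMEngine` in #2152):

```python
# Illustrative sketch only; argument names may differ from the real interface.
from mlc_llm.serve import Engine

engine = Engine("dist/Llama-2-7b-chat-hf-q4f16_1-MLC")  # model path is a placeholder
response = engine.chat.completions.create(
    messages=[{"role": "user", "content": "Hello!"}],
    stream=False,
)
print(response.choices[0].message.content)
```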
* Auto updated submodule references

* Auto updated submodule references

* [Serving][Grammar] Porting the json schema converter from python to C++ (#2112)
This PR ports the JSON schema converter from Python to C++. It defines the interface:
```
std::string JSONSchemaToEBNF(
    std::string schema,
    std::optional<int> indent = std::nullopt,
    std::optional<std::pair<std::string, std::string>> separators = std::nullopt,
    bool strict_mode = true);
```
and uses it in BNFGrammar::FromSchema. This helps cases where Python cannot be deployed.

* [Model] Use R.topk/cumsum for mixtral (#2107)

* Enable flashinfer when group_size == 6 (#2124)

* [SpecDecode] Support Eagle in speculative decoding (#2080)
1. Add Eagle-Llama-7b-chat model support.
2. Add speculative decoding support with Eagle.

* [Pass] Attach non-negative TIR var attributes (#2125)
This PR attaches the `tir.non_negative_var` attributes for memory planning.

* [Serving][Refactor] Engine constructor interface refactor (#2126)
This PR is a refactor of the engine's constructor interface and the serve CLI interface.
It introduces the "mode" argument for the engine, which has options "local", "interactive" and "server". The choice of mode affects the automatically inferred values of `max_batch_size`, `max_total_sequence_length` and `prefill_chunk_size` (only effective when the arguments are not specified; once an argument is specified, we will not override it). For a detailed specification of the mode, please check out the CLI help messages in `mlc_llm/help.py` or the engine constructor in `mlc_llm/serve/engine.py`.
No matter which mode is chosen, we print out the current mode and the values of these arguments so that people can understand the settings of the engine. We also provide hints on how to adjust the mode. For example,
```
[2024-04-12 16:12:26] INFO chat_module.py:379: Using model folder: /home/ruihang/Workspace/mlc-llm/dist/Llama-2-7b-chat-hf-q0f16-MLC
[2024-04-12 16:12:26] INFO chat_module.py:380: Using mlc chat config: /home/ruihang/Workspace/mlc-llm/dist/Llama-2-7b-chat-hf-q0f16-MLC/mlc-chat-config.json
[2024-04-12 16:12:26] INFO chat_module.py:529: Using library model: dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so
[2024-04-12 16:12:26] INFO chat_module.py:379: Using model folder: /home/ruihang/Workspace/mlc-llm/dist/Llama-2-7b-chat-hf-q4f16_1-MLC
[2024-04-12 16:12:26] INFO chat_module.py:380: Using mlc chat config: /home/ruihang/Workspace/mlc-llm/dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json
[2024-04-12 16:12:26] INFO chat_module.py:529: Using library model: dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so
[2024-04-12 16:12:29] INFO engine_base.py:382: Engine mode is "local". Max batch size is set to 4. Max KV cache token capacity is set to 4096. Prefill chunk size is set to 4096.
[2024-04-12 16:12:29] INFO engine_base.py:387: Estimated total single GPU memory usage: 21543.74 MB (Parameters: 16467.64 MB. KVCache: 4450.07 MB. Temporary buffer: 626.03 MB). The actual usage might be slightly larger than the estimated number.
[2024-04-12 16:12:29] INFO engine_base.py:398: Please switch to mode "server" if you want to use more GPU memory and support more concurrent requests.
```
After the refactor, we bring speculative decoding to the serve CLI, so that people can use multiple models and run speculative decoding with the server launched from the CLI (which was not doable before).
* [Serving] Revamp engine mode selection logging info (#2128)
This PR revamps the logging info for engine mode selection to provide more detailed information and the rationale for different modes.

* [SLM] Chatglm3 Multi-GPU support (#2123)
This PR enables TP for the Chatglm3 model.

* [Serving] Fix support of large `n` under low max batch size (#2136)
Prior to this PR, due to the improper prefill policy on `n` (parallel generation), the engine would loop forever when a request has `n` larger than the maximum batch size that the engine can support.
This PR fixes the issue by updating the prefill action, and with this PR, even the "interactive" engine mode can well support multiple parallel generations.
After this fix, it is possible that a request requires 10 parallel generations while the max batch size is 1. Given that the shapes of temporary NDArrays in the GPU sampler are determined by the max batch size, the GPU sampler does not natively support sampling 10 tokens at a time. To address this, this PR introduces chunking to the GPU sampler. In this particular case, the GPU sampler will have chunk size 1, and the 10 required samples will be processed by the GPU sampler one by one, in order. Chunking is the minimal change we can make to support large `n`.

* [Docs] Revamp landing page with Engine Python API and server (#2137)
This PR revamps the landing documentation page.
* The Python API panel is changed from showing ChatModule to showing Engine.
* A new panel "REST Server" is added to show a quick start example of launching the REST server and sending a request.
* A "what to do next" section is introduced at the bottom of the landing page.
Todo items for future PRs:
* add the page of the Python API with Engine.
* revamp the weight conversion page.
* revamp the model library compilation page.

* [Target] Update Target tags (#2141)
This commit updates the target tags, in order to identify different SoC hardware targets for further target-specific optimizations. Meanwhile, it updates the Vulkan support for int64.

* [Util] Support debug debug_compare (#2142)

* [Minor][SpecInfer] Fix Optional FC Bias for Mixtral Eagle Model (#2146)
  * Add optional fc bias for mixtral.
  * Fix lint.

* [Serving] fix hardcoded host and port in popen_server (#2147)

* [Docs] Introductory tutorial (#2145)
This PR updates the documentation with an introductory tutorial. The landing page now directs to the quick start page and the tutorial.

* [Serving] Support `DebugCallFuncOnAllAllWorker` and CUDA profiler (#2148)
This PR adds a new function `DebugCallFuncOnAllAllWorker` which calls a global function of signature `[] -> None` on all distributed workers when tensor parallelism is enabled (or on the local session itself if not enabled). As the name suggests, this function is only for debugging purposes, and we will not expose any public interface to invoke it.
This PR also introduces the global functions `"mlc.debug_cuda_profiler_start"` and `"mlc.debug_cuda_profiler_stop"`, which enable CUDA profiling when using PopenServer.

* [DOCS] Update introduction (#2151)
  * [DOCS] Update introduction
    Some minor tweaks on the introduction doc
  * Update docs/get_started/introduction.rst
    Co-authored-by: Ruihang Lai
  ---------
  Co-authored-by: Ruihang Lai

* [Serving][Python] Rename Engine to LLMEngine (#2152)
We rename the public Python serve interface from `Engine` to `LLMEngine` (and from `AsyncEngine` to `AsyncLLMEngine` accordingly) for better class name clarity.
This is because in cases people do wildcard import, in which case the name `Engine` itself does not convey enough meaning. * Auto updated submodule references * [Quantization] Add e4m3 mode and enable fp8 storage type (#2154) * [Quantization] Add e4m3 mode and enable fp8 storage type * add quantize linear flag * Revert "[Quantization] Add e4m3 mode and enable fp8 storage type" (#2158) Revert "[Quantization] Add e4m3 mode and enable fp8 storage type (#2154)" This reverts commit e9a4a0bf719a7c4fd42b438cf9e159a1e8d72590. * [Serving] EngineConfig refactor (#2159) This PR refactors EngineConfig for a cleaner interface of internal Engine constructor in MLC serve. This is a preparation step towards the engine reload/unload which will be introduced in follow-up PRs for JSONFFIEngine functionality on mobile and other platforms. * temporary hack for byoc --------- Co-authored-by: Ruihang Lai Co-authored-by: Bohan Hou Co-authored-by: Yixin Dong Co-authored-by: Git bot Co-authored-by: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Co-authored-by: Eric Lunderberg Co-authored-by: Shushi Hong <820958424@qq.com> Co-authored-by: Egor Churaev Co-authored-by: Siyuan Feng Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Co-authored-by: Tianqi Chen Co-authored-by: Kartik Khandelwal Co-authored-by: Shrey Gupta Co-authored-by: Diego Cao <50705298+DiegoCao@users.noreply.github.com> Co-authored-by: David Pissarra <61968959+davidpissarra@users.noreply.github.com> Co-authored-by: Wuwei Lin Co-authored-by: Ricardo Lu <37237570+gesanqiu@users.noreply.github.com> Co-authored-by: Hongyi Jin Co-authored-by: Bohan Hou Co-authored-by: tqchen Co-authored-by: Rick Zhou Co-authored-by: Animesh Bohara Co-authored-by: Yong Wu Co-authored-by: Yong Wu Co-authored-by: Shrey Gupta <51860471+shreygupta2809@users.noreply.github.com> Co-authored-by: Yaxing Cai Co-authored-by: ZCHNO Co-authored-by: Andrew Co-authored-by: na20215 <78482004+na20215@users.noreply.github.com> Co-authored-by: Animesh Bohara Co-authored-by: Yogesh Garg Co-authored-by: Linyu Wu <95223577+Celve@users.noreply.github.com> Co-authored-by: Yu Xuanchi Co-authored-by: Mengshiun Yu Co-authored-by: Jeethu Rao Co-authored-by: Xiyou Zhou --- ci/task/pylint.sh | 1 + cpp/json_ffi/json_ffi_engine.cc | 204 +++ cpp/json_ffi/json_ffi_engine.h | 56 + cpp/json_ffi/openai_api_protocol.cc | 224 +++ cpp/json_ffi/openai_api_protocol.h | 168 ++ cpp/llm_chat.cc | 2 - cpp/metadata/json_parser.h | 49 + cpp/metadata/model.h | 8 +- cpp/serve/config.cc | 161 +- cpp/serve/config.h | 103 +- cpp/serve/engine.cc | 182 +-- cpp/serve/engine.h | 52 +- cpp/serve/engine_actions/action.h | 60 +- cpp/serve/engine_actions/batch_decode.cc | 2 + cpp/serve/engine_actions/batch_verify.cc | 16 +- cpp/serve/engine_actions/eagle_batch_draft.cc | 230 +++ .../engine_actions/eagle_batch_verify.cc | 364 +++++ .../eagle_new_request_prefill.cc | 601 +++++++ .../engine_actions/new_request_prefill.cc | 151 +- cpp/serve/function_table.cc | 71 +- cpp/serve/function_table.h | 13 +- cpp/serve/grammar/grammar.cc | 51 +- cpp/serve/grammar/grammar.h | 28 +- cpp/serve/grammar/grammar_parser.cc | 10 +- cpp/serve/grammar/grammar_parser.h | 4 +- cpp/serve/grammar/grammar_serializer.cc | 4 +- cpp/serve/grammar/grammar_serializer.h | 6 +- cpp/serve/grammar/json_schema_converter.cc | 987 ++++++++++++ cpp/serve/grammar/json_schema_converter.h | 44 + cpp/serve/logit_processor.cc | 4 +- cpp/serve/model.cc | 483 +++++- cpp/serve/model.h | 121 +- cpp/serve/request_state.cc | 12 +- 
cpp/serve/request_state.h | 9 +- cpp/serve/sampler/cpu_sampler.cc | 31 +- cpp/serve/sampler/gpu_sampler.cc | 56 +- cpp/serve/threaded_engine.cc | 262 +++ cpp/serve/threaded_engine.h | 75 + cpp/support/utils.h | 24 + docs/_static/img/project-workflow.svg | 1173 ++++++++++++++ docs/community/faq.rst | 2 +- docs/compilation/compile_models.rst | 2 +- docs/compilation/convert_weights.rst | 2 +- docs/compilation/get-vicuna-weight.rst | 68 - docs/conf.py | 2 - docs/deploy/javascript.rst | 2 +- docs/deploy/mlc_chat_config.rst | 210 +++ .../{python.rst => python_chat_module.rst} | 18 +- docs/deploy/python_engine.rst | 15 + docs/deploy/rest.rst | 214 ++- docs/get_started/introduction.rst | 319 ++++ docs/get_started/project_overview.rst | 4 +- docs/get_started/quick_start.rst | 190 +++ docs/index.rst | 143 +- docs/install/mlc_llm.rst | 9 +- docs/install/tvm.rst | 3 +- docs/prebuilt_models.rst | 4 +- examples/python/sample_mlc_engine.py | 17 + python/mlc_llm/__init__.py | 3 + python/mlc_llm/base.py | 19 + python/mlc_llm/cli/serve.py | 32 +- .../mlc_llm/compiler_pass/attach_sampler.py | 17 +- .../compiler_pass/attach_support_info.py | 5 +- .../dispatch_kv_cache_creation.py | 2 +- .../compiler_pass/estimate_memory_usage.py | 8 +- .../fuse_dequantize_matmul_ewise.py | 2 +- .../compiler_pass/lift_global_buffer_alloc.py | 165 +- python/mlc_llm/conversation_template.py | 22 +- python/mlc_llm/help.py | 50 + python/mlc_llm/interface/compiler_flags.py | 8 +- python/mlc_llm/interface/gen_config.py | 1 + python/mlc_llm/interface/serve.py | 38 +- python/mlc_llm/loader/huggingface_loader.py | 1 + python/mlc_llm/model/chatglm3/__init__.py | 0 .../mlc_llm/model/chatglm3/chatglm3_loader.py | 63 + .../mlc_llm/model/chatglm3/chatglm3_model.py | 438 +++++ .../model/chatglm3/chatglm3_quantization.py | 53 + python/mlc_llm/model/eagle/__init__.py | 0 python/mlc_llm/model/eagle/eagle_loader.py | 172 ++ python/mlc_llm/model/eagle/eagle_model.py | 244 +++ .../mlc_llm/model/eagle/eagle_quantization.py | 70 + python/mlc_llm/model/gemma/gemma_model.py | 4 +- python/mlc_llm/model/llava/llava_model.py | 118 +- python/mlc_llm/model/mixtral/mixtral_model.py | 2 + .../model/mixtral/mixtral_quantization.py | 25 +- python/mlc_llm/model/model.py | 32 + python/mlc_llm/model/model_preset.py | 37 + python/mlc_llm/op/attention.py | 4 +- python/mlc_llm/op/cutlass.py | 11 +- python/mlc_llm/op/moe_matmul.py | 159 +- python/mlc_llm/op/moe_misc.py | 17 +- .../mlc_llm/protocol/conversation_protocol.py | 129 +- python/mlc_llm/protocol/error_protocol.py | 34 + python/mlc_llm/protocol/protocol_utils.py | 10 - python/mlc_llm/serve/data.py | 57 +- python/mlc_llm/serve/engine_base.py | 1414 +++++++++++++++++ python/mlc_llm/serve/engine_utils.py | 97 ++ python/mlc_llm/serve/grammar.py | 50 +- python/mlc_llm/serve/server/popen_server.py | 53 +- python/mlc_llm/serve/server/server_context.py | 38 +- python/mlc_llm/serve/sync_engine.py | 360 +++++ python/mlc_llm/support/auto_target.py | 45 +- python/mlc_llm/testing/debug_chat.py | 6 +- python/mlc_llm/testing/debug_compare.py | 249 +++ .../python/integration/test_model_compile.py | 5 +- tests/python/json_ffi/test_json_ffi_engine.py | 307 ++++ tests/python/serve/evaluate_engine.py | 22 +- tests/python/serve/server/test_server.py | 58 +- tests/python/serve/test_grammar_parser.py | 27 +- .../test_grammar_state_matcher_custom.py | 2 +- .../serve/test_json_schema_converter.py | 125 +- tests/python/serve/test_serve_async_engine.py | 234 ++- .../serve/test_serve_async_engine_spec.py | 35 +- 
tests/python/serve/test_serve_engine.py | 529 +++--- .../python/serve/test_serve_engine_grammar.py | 26 +- tests/python/serve/test_serve_engine_image.py | 35 +- tests/python/serve/test_serve_engine_spec.py | 382 ++++- tests/python/serve/test_serve_sync_engine.py | 396 +++++ 118 files changed, 12306 insertions(+), 1567 deletions(-) create mode 100644 cpp/json_ffi/json_ffi_engine.cc create mode 100644 cpp/json_ffi/json_ffi_engine.h create mode 100644 cpp/json_ffi/openai_api_protocol.cc create mode 100644 cpp/json_ffi/openai_api_protocol.h create mode 100644 cpp/serve/engine_actions/eagle_batch_draft.cc create mode 100644 cpp/serve/engine_actions/eagle_batch_verify.cc create mode 100644 cpp/serve/engine_actions/eagle_new_request_prefill.cc create mode 100644 cpp/serve/grammar/json_schema_converter.cc create mode 100644 cpp/serve/grammar/json_schema_converter.h create mode 100644 cpp/serve/threaded_engine.cc create mode 100644 cpp/serve/threaded_engine.h create mode 100644 cpp/support/utils.h create mode 100644 docs/_static/img/project-workflow.svg delete mode 100644 docs/compilation/get-vicuna-weight.rst create mode 100644 docs/deploy/mlc_chat_config.rst rename docs/deploy/{python.rst => python_chat_module.rst} (96%) create mode 100644 docs/deploy/python_engine.rst create mode 100644 docs/get_started/introduction.rst create mode 100644 docs/get_started/quick_start.rst create mode 100644 examples/python/sample_mlc_engine.py create mode 100644 python/mlc_llm/model/chatglm3/__init__.py create mode 100644 python/mlc_llm/model/chatglm3/chatglm3_loader.py create mode 100644 python/mlc_llm/model/chatglm3/chatglm3_model.py create mode 100644 python/mlc_llm/model/chatglm3/chatglm3_quantization.py create mode 100644 python/mlc_llm/model/eagle/__init__.py create mode 100644 python/mlc_llm/model/eagle/eagle_loader.py create mode 100644 python/mlc_llm/model/eagle/eagle_model.py create mode 100644 python/mlc_llm/model/eagle/eagle_quantization.py create mode 100644 python/mlc_llm/protocol/error_protocol.py create mode 100644 python/mlc_llm/serve/engine_base.py create mode 100644 python/mlc_llm/serve/engine_utils.py create mode 100644 python/mlc_llm/serve/sync_engine.py create mode 100644 python/mlc_llm/testing/debug_compare.py create mode 100644 tests/python/json_ffi/test_json_ffi_engine.py create mode 100644 tests/python/serve/test_serve_sync_engine.py diff --git a/ci/task/pylint.sh b/ci/task/pylint.sh index c4abb81d90..849efe628e 100755 --- a/ci/task/pylint.sh +++ b/ci/task/pylint.sh @@ -8,6 +8,7 @@ export PYTHONPATH="./python":${PYTHONPATH:-""} # TVM Unity is a dependency to this testing pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly +pip install --quiet --pre -U cuda-python pylint --jobs $NUM_THREADS ./python/ pylint --jobs $NUM_THREADS --recursive=y ./tests/python/ diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc new file mode 100644 index 0000000000..b02a28ca89 --- /dev/null +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -0,0 +1,204 @@ +#include "json_ffi_engine.h" + +#include +#include +#include + +namespace mlc { +namespace llm { +namespace json_ffi { + +using namespace tvm::runtime; + +JSONFFIEngine::JSONFFIEngine() { engine_ = serve::ThreadedEngine::Create(); } + +bool JSONFFIEngine::ChatCompletion(std::string request_json_str, std::string request_id) { + bool success = this->AddRequest(request_json_str, request_id); + if (!success) { + this->StreamBackError(request_id); + } + return success; +} + +void JSONFFIEngine::StreamBackError(std::string 
request_id) { + ChatCompletionMessage delta; + delta.content = std::vector>{ + {{"type", "text"}, {"text", this->err_}}}; + delta.role = Role::assistant; + + ChatCompletionStreamResponseChoice choice; + choice.finish_reason = FinishReason::error; + choice.index = 0; + choice.delta = delta; + + ChatCompletionStreamResponse response; + response.id = request_id; + response.choices = std::vector{choice}; + response.model = "json_ffi"; // TODO: Return model name from engine (or from args) + response.system_fingerprint = ""; + + this->request_stream_callback_(Array{picojson::value(response.ToJSON()).serialize()}); +} + +bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request_id) { + std::optional optional_request = + ChatCompletionRequest::FromJSON(request_json_str, &err_); + if (!optional_request.has_value()) { + return false; + } + ChatCompletionRequest request = optional_request.value(); + // Create Request + // TODO: Check if request_id is present already + + // inputs + // TODO: Apply conv template + Array inputs; + for (const auto& message : request.messages) { + if (message.content.has_value()) { + for (const auto& content : message.content.value()) { + if (content.find("type") == content.end()) { + err_ += "Content should have a type field"; + return false; + } + std::string type = content.at("type"); + if (type == "text") { + if (content.find("text") == content.end()) { + err_ += "Content should have a text field"; + return false; + } + std::string text = content.at("text"); + inputs.push_back(TextData(text)); + } else { + err_ += "Content type not supported"; + return false; + } + } + } + } + + // generation_cfg + Optional generation_cfg = GenerationConfig::FromJSON(request_json_str, &err_); + if (!generation_cfg.defined()) { + return false; + } + + Request engine_request(request_id, inputs, generation_cfg.value()); + this->engine_->AddRequest(engine_request); + + return true; +} + +bool JSONFFIEngine::Abort(std::string request_id) { + this->engine_->AbortRequest(request_id); + return true; +} + +std::string JSONFFIEngine::GetLastError() { return err_; } + +void JSONFFIEngine::ExitBackgroundLoop() { this->engine_->ExitBackgroundLoop(); } + +JSONFFIEngine::~JSONFFIEngine() { this->ExitBackgroundLoop(); } + +class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { + public: + TVM_MODULE_VTABLE_BEGIN("mlc.json_ffi"); + TVM_MODULE_VTABLE_ENTRY("init_background_engine", &JSONFFIEngineImpl::InitBackgroundEngine); + TVM_MODULE_VTABLE_ENTRY("chat_completion", &JSONFFIEngineImpl::ChatCompletion); + TVM_MODULE_VTABLE_ENTRY("abort", &JSONFFIEngineImpl::Abort); + TVM_MODULE_VTABLE_ENTRY("get_last_error", &JSONFFIEngineImpl::GetLastError); + TVM_MODULE_VTABLE_ENTRY("run_background_loop", &JSONFFIEngineImpl::RunBackgroundLoop); + TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", + &JSONFFIEngineImpl::RunBackgroundStreamBackLoop); + TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); + TVM_MODULE_VTABLE_END(); + + void InitBackgroundEngine(EngineConfig engine_config, + Optional request_stream_callback, + Optional trace_recorder) { + this->streamer_ = TextStreamer(Tokenizer::FromPath(engine_config->model)); + + CHECK(request_stream_callback.defined()) + << "JSONFFIEngine requires request stream callback function, but it is not given."; + this->request_stream_callback_ = request_stream_callback.value(); + + auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.size(), 1); + 
Array delta_outputs = args[0]; + Array responses = this->GetResponseFromStreamOutput(delta_outputs); + this->request_stream_callback_(responses); + }; + + request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); + this->engine_->InitBackgroundEngine( + std::move(engine_config), std::move(request_stream_callback), std::move(trace_recorder)); + } + + void RunBackgroundLoop() { this->engine_->RunBackgroundLoop(); } + + void RunBackgroundStreamBackLoop() { this->engine_->RunBackgroundStreamBackLoop(); } + + Array GetResponseFromStreamOutput(Array delta_outputs) { + std::unordered_map> response_map; + for (const auto& delta_output : delta_outputs) { + std::string request_id = delta_output->request_id; + if (response_map.find(request_id) == response_map.end()) { + response_map[request_id] = std::vector(); + } + ChatCompletionStreamResponseChoice choice; + + if (delta_output->group_finish_reason.size() != 1) { + // Only support n = 1 in ChatCompletionStreamResponse for now + this->err_ += "Group finish reason should have exactly one element"; + } + Optional finish_reason = delta_output->group_finish_reason[0]; + if (finish_reason.defined()) { + if (finish_reason.value() == "stop") { + choice.finish_reason = FinishReason::stop; + } else if (finish_reason.value() == "length") { + choice.finish_reason = FinishReason::length; + } else if (finish_reason.value() == "tool_calls") { + choice.finish_reason = FinishReason::tool_calls; + } else if (finish_reason.value() == "error") { + choice.finish_reason = FinishReason::error; + } + } else { + choice.finish_reason = std::nullopt; + } + + choice.index = response_map[request_id].size(); + + ChatCompletionMessage delta; + // Size of delta_output->group_delta_token_ids Array should be 1 + IntTuple delta_token_ids = delta_output->group_delta_token_ids[0]; + std::vector delta_token_ids_vec(delta_token_ids.begin(), delta_token_ids.end()); + delta.content = std::vector>(); + delta.content.value().push_back(std::unordered_map{ + {"type", "text"}, {"text", this->streamer_->Put(delta_token_ids_vec)}}); + + delta.role = Role::assistant; + + choice.delta = delta; + + response_map[request_id].push_back(choice); + } + + Array response_arr; + for (const auto& [request_id, choices] : response_map) { + ChatCompletionStreamResponse response; + response.id = request_id; + response.choices = choices; + response.model = "json_ffi"; // TODO: Return model name from engine (or from args) + response.system_fingerprint = ""; + response_arr.push_back(picojson::value(response.ToJSON()).serialize()); + } + return response_arr; + } +}; + +TVM_REGISTER_GLOBAL("mlc.json_ffi.CreateJSONFFIEngine").set_body_typed([]() { + return Module(make_object()); +}); + +} // namespace json_ffi +} // namespace llm +} // namespace mlc diff --git a/cpp/json_ffi/json_ffi_engine.h b/cpp/json_ffi/json_ffi_engine.h new file mode 100644 index 0000000000..83013b5876 --- /dev/null +++ b/cpp/json_ffi/json_ffi_engine.h @@ -0,0 +1,56 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file json_ffi/json_ffi_engine.h + * \brief The header of JSON FFI engine in MLC LLM. + */ +#ifndef MLC_LLM_JSON_FFI_JSON_FFI_ENGINE_H_ +#define MLC_LLM_JSON_FFI_JSON_FFI_ENGINE_H_ + +#include + +#include + +#include "../serve/threaded_engine.h" +#include "../streamer.h" +#include "openai_api_protocol.h" + +namespace mlc { +namespace llm { +namespace json_ffi { + +using namespace tvm::runtime; +using namespace mlc::llm::serve; + +/*! 
+ * \brief // Todo: document this class, fields and member functions + */ +class JSONFFIEngine { + public: + JSONFFIEngine(); + + ~JSONFFIEngine(); + + bool ChatCompletion(std::string request_json_str, std::string request_id); + + bool AddRequest(std::string request_json_str, std::string request_id); + + void StreamBackError(std::string request_id); + + bool Abort(std::string request_id); + + std::string GetLastError(); + + void ExitBackgroundLoop(); + + protected: + std::unique_ptr engine_; + std::string err_; + PackedFunc request_stream_callback_; + TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request +}; + +} // namespace json_ffi +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_JSON_FFI_JSON_FFI_ENGINE_H_ diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc new file mode 100644 index 0000000000..41378fc3e0 --- /dev/null +++ b/cpp/json_ffi/openai_api_protocol.cc @@ -0,0 +1,224 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file json_ffi/openai_api_protocol.cc + * \brief The implementation of OpenAI API Protocol in MLC LLM. + */ +#include "openai_api_protocol.h" + +#include "../metadata/json_parser.h" + +namespace mlc { +namespace llm { +namespace json_ffi { + +std::optional ChatCompletionMessage::FromJSON(const picojson::value& json, + std::string* err) { + if (!json.is()) { + *err += "Input is not a valid JSON object"; + return std::nullopt; + } + picojson::object json_obj = json.get(); + + ChatCompletionMessage message; + + // content + picojson::array content_arr; + if (!json::ParseJSONField(json_obj, "content", content_arr, err, true)) { + return std::nullopt; + } + std::vector > content; + for (const auto& item : content_arr) { + if (!item.is()) { + *err += "Content item is not an object"; + return std::nullopt; + } + std::unordered_map item_map; + picojson::object item_obj = item.get(); + for (picojson::value::object::const_iterator i = item_obj.begin(); i != item_obj.end(); ++i) { + item_map[i->first] = i->second.to_str(); + } + content.push_back(item_map); + } + message.content = content; + + // role + std::string role_str; + if (!json::ParseJSONField(json_obj, "role", role_str, err, true)) { + return std::nullopt; + } + if (role_str == "system") { + message.role = Role::system; + } else if (role_str == "user") { + message.role = Role::user; + } else if (role_str == "assistant") { + message.role = Role::assistant; + } else if (role_str == "tool") { + message.role = Role::tool; + } else { + *err += "Invalid role"; + return std::nullopt; + } + + // name + std::string name; + if (json::ParseJSONField(json_obj, "name", name, err, false)) { + message.name = name; + } + + // TODO: tool_calls and tool_call_id + + return message; +} + +std::optional ChatCompletionRequest::FromJSON( + const picojson::object& json_obj, std::string* err) { + ChatCompletionRequest request; + + // messages + picojson::array messages_arr; + if (!json::ParseJSONField(json_obj, "messages", messages_arr, err, true)) { + return std::nullopt; + } + std::vector messages; + for (const auto& item : messages_arr) { + std::optional message = ChatCompletionMessage::FromJSON(item, err); + if (!message.has_value()) { + return std::nullopt; + } + messages.push_back(message.value()); + } + request.messages = messages; + + // model + std::string model; + if (!json::ParseJSONField(json_obj, "model", model, err, true)) { + return std::nullopt; + } + request.model = model; + + // frequency_penalty + double frequency_penalty; + if 
(json::ParseJSONField(json_obj, "frequency_penalty", frequency_penalty, err, false)) { + request.frequency_penalty = frequency_penalty; + } + + // presence_penalty + double presence_penalty; + if (json::ParseJSONField(json_obj, "presence_penalty", presence_penalty, err, false)) { + request.presence_penalty = presence_penalty; + } + + // TODO: Other parameters + + return request; +} + +std::optional ChatCompletionRequest::FromJSON(const std::string& json_str, + std::string* err) { + std::optional json_obj = json::LoadJSONFromString(json_str, err); + if (!json_obj.has_value()) { + return std::nullopt; + } + return ChatCompletionRequest::FromJSON(json_obj.value(), err); +} + +picojson::object ChatCompletionMessage::ToJSON() { + picojson::object obj; + picojson::array content_arr; + for (const auto& item : this->content.value()) { + picojson::object item_obj; + for (const auto& pair : item) { + item_obj[pair.first] = picojson::value(pair.second); + } + content_arr.push_back(picojson::value(item_obj)); + } + obj["content"] = picojson::value(content_arr); + if (this->role == Role::system) { + obj["role"] = picojson::value("system"); + } else if (this->role == Role::user) { + obj["role"] = picojson::value("user"); + } else if (this->role == Role::assistant) { + obj["role"] = picojson::value("assistant"); + } else if (this->role == Role::tool) { + obj["role"] = picojson::value("tool"); + } + if (name.has_value()) { + obj["name"] = picojson::value(name.value()); + } + return obj; +} + +picojson::object ChatCompletionResponseChoice::ToJSON() { + picojson::object obj; + if (!this->finish_reason.has_value()) { + obj["finish_reason"] = picojson::value(); + } else { + if (this->finish_reason == FinishReason::stop) { + obj["finish_reason"] = picojson::value("stop"); + } else if (this->finish_reason == FinishReason::length) { + obj["finish_reason"] = picojson::value("length"); + } else if (this->finish_reason == FinishReason::tool_calls) { + obj["finish_reason"] = picojson::value("tool_calls"); + } else if (this->finish_reason == FinishReason::error) { + obj["finish_reason"] = picojson::value("error"); + } + } + obj["index"] = picojson::value((int64_t)this->index); + obj["message"] = picojson::value(this->message.ToJSON()); + return obj; +} + +picojson::object ChatCompletionStreamResponseChoice::ToJSON() { + picojson::object obj; + if (!this->finish_reason.has_value()) { + obj["finish_reason"] = picojson::value(); + } else { + if (this->finish_reason.value() == FinishReason::stop) { + obj["finish_reason"] = picojson::value("stop"); + } else if (this->finish_reason.value() == FinishReason::length) { + obj["finish_reason"] = picojson::value("length"); + } else if (this->finish_reason.value() == FinishReason::tool_calls) { + obj["finish_reason"] = picojson::value("tool_calls"); + } else if (this->finish_reason.value() == FinishReason::error) { + obj["finish_reason"] = picojson::value("error"); + } + } + + obj["index"] = picojson::value((int64_t)this->index); + obj["delta"] = picojson::value(this->delta.ToJSON()); + return obj; +} + +picojson::object ChatCompletionResponse::ToJSON() { + picojson::object obj; + obj["id"] = picojson::value(this->id); + picojson::array choices_arr; + for (auto& choice : this->choices) { + choices_arr.push_back(picojson::value(choice.ToJSON())); + } + obj["choices"] = picojson::value(choices_arr); + obj["created"] = picojson::value((int64_t)this->created); + obj["model"] = picojson::value(this->model); + obj["system_fingerprint"] = picojson::value(this->system_fingerprint); + 
obj["object"] = picojson::value(this->object); + return obj; +} + +picojson::object ChatCompletionStreamResponse::ToJSON() { + picojson::object obj; + obj["id"] = picojson::value(this->id); + picojson::array choices_arr; + for (auto& choice : this->choices) { + choices_arr.push_back(picojson::value(choice.ToJSON())); + } + obj["choices"] = picojson::value(choices_arr); + obj["created"] = picojson::value((int64_t)this->created); + obj["model"] = picojson::value(this->model); + obj["system_fingerprint"] = picojson::value(this->system_fingerprint); + obj["object"] = picojson::value(this->object); + return obj; +} + +} // namespace json_ffi +} // namespace llm +} // namespace mlc diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h new file mode 100644 index 0000000000..1579b5f337 --- /dev/null +++ b/cpp/json_ffi/openai_api_protocol.h @@ -0,0 +1,168 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file json_ffi/openai_api_protocol.h + * \brief The header of OpenAI API Protocol in MLC LLM. + */ +#ifndef MLC_LLM_JSON_FFI_OPENAI_API_PROTOCOL_H +#define MLC_LLM_JSON_FFI_OPENAI_API_PROTOCOL_H + +#include +#include +#include +#include +#include + +#include "picojson.h" + +namespace mlc { +namespace llm { +namespace json_ffi { + +enum class Role { system, user, assistant, tool }; +enum class Type { text, json_object, function }; +enum class FinishReason { stop, length, tool_calls, error }; + +// TODO: Implement the following class +class ChatFunction { + public: + std::optional description = std::nullopt; + std::string name; + std::unordered_map + parameters; // Assuming parameters are string key-value pairs + + static std::optional FromJSON(const picojson::value& json, std::string* err); +}; + +// TODO: Implement the following class +class ChatTool { + public: + Type type = Type::function; + ChatFunction function; + + static std::optional FromJSON(const picojson::value& json, std::string* err); +}; + +// TODO: Implement the following class +class ChatFunctionCall { + public: + std::string name; + std::optional> arguments = + std::nullopt; // Assuming arguments are string key-value pairs +}; + +// TODO: Implement the following class +class ChatToolCall { + public: + std::string id; // TODO: python code initializes this to an random string + Type type = Type::function; + ChatFunctionCall function; +}; + +class ChatCompletionMessage { + public: + std::optional>> content = + std::nullopt; // Assuming content is a list of string key-value pairs + Role role; + std::optional name = std::nullopt; + std::optional> tool_calls = std::nullopt; // TODO: Implement this + std::optional tool_call_id = std::nullopt; // TODO: Implement this + + static std::optional FromJSON(const picojson::value& json, + std::string* err); + picojson::object ToJSON(); +}; + +class RequestResponseFormat { + public: + Type type = Type::text; + std::optional json_schema = std::nullopt; +}; + +class ChatCompletionRequest { + public: + std::vector messages; + std::string model; + double frequency_penalty = 0.0; + double presence_penalty = 0.0; + bool logprobs = false; + int top_logprobs = 0; + std::optional> logit_bias = std::nullopt; + std::optional max_tokens = std::nullopt; + int n = 1; + std::optional seed = std::nullopt; + std::optional> stop = std::nullopt; + bool stream = false; + double temperature = 1.0; + double top_p = 1.0; + std::optional> tools = std::nullopt; + std::optional tool_choice = std::nullopt; + std::optional user = std::nullopt; + bool ignore_eos = false; + // 
RequestResponseFormat response_format; //TODO: implement this + + /*! + * \brief Create a ChatCompletionRequest instance from the given JSON object. + * When creation fails, errors are dumped to the input error string, and nullopt is returned. + */ + static std::optional FromJSON(const picojson::object& json_obj, + std::string* err); + /*! + * \brief Parse and create a ChatCompletionRequest instance from the given JSON string. + * When creation fails, errors are dumped to the input error string, and nullopt is returned. + */ + static std::optional FromJSON(const std::string& json_str, + std::string* err); + + // TODO: check_penalty_range, check_logit_bias, check_logprobs +}; + +class ChatCompletionResponseChoice { + public: + std::optional finish_reason; + int index = 0; + ChatCompletionMessage message; + // TODO: logprobs + + picojson::object ToJSON(); +}; + +class ChatCompletionStreamResponseChoice { + public: + std::optional finish_reason; + int index = 0; + ChatCompletionMessage delta; + // TODO: logprobs + + picojson::object ToJSON(); +}; + +class ChatCompletionResponse { + public: + std::string id; + std::vector choices; + int created = static_cast(std::time(nullptr)); + std::string model; + std::string system_fingerprint; + std::string object = "chat.completion"; + // TODO: usage_info + + picojson::object ToJSON(); +}; + +class ChatCompletionStreamResponse { + public: + std::string id; + std::vector choices; + int created = static_cast(std::time(nullptr)); + std::string model; + std::string system_fingerprint; + std::string object = "chat.completion.chunk"; + + picojson::object ToJSON(); +}; + +} // namespace json_ffi +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_JSON_FFI_OPENAI_API_PROTOCOL_H diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 8cadbe8df4..9485ccad02 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -1618,8 +1618,6 @@ class LLMChat { NDArray logits_on_cpu_{nullptr}; // pre-allocated ndarray for decode function's input tokens DRef input_tokens_decode_{nullptr}; - // KV cache config - serve::KVCacheConfig kv_cache_config_{nullptr}; }; /*! diff --git a/cpp/metadata/json_parser.h b/cpp/metadata/json_parser.h index 14f622f2c8..f6ff10e1ac 100644 --- a/cpp/metadata/json_parser.h +++ b/cpp/metadata/json_parser.h @@ -10,6 +10,8 @@ #include #include +#include + namespace mlc { namespace llm { namespace json { @@ -20,6 +22,53 @@ namespace json { * \return The parsed JSON object. */ picojson::object ParseToJsonObject(const std::string& json_str); + +// Todo(mlc-team): implement "Result" class for JSON parsing with error collection. +/*! + * \brief Parse input JSON string into JSON dict. + * Any error will be dumped to the input error string. + */ +inline std::optional LoadJSONFromString(const std::string& json_str, + std::string* err) { + ICHECK_NOTNULL(err); + picojson::value json; + *err = picojson::parse(json, json_str); + if (!json.is()) { + *err += "The input JSON string does not correspond to a JSON dict."; + return std::nullopt; + } + return json.get(); +} + +/*! + * \brief // Todo(mlc-team): document this function. 
+ * \tparam T + * \param json_obj + * \param field + * \param value + * \param err + * \param required + * \return + */ +template +inline bool ParseJSONField(const picojson::object& json_obj, const std::string& field, T& value, + std::string* err, bool required) { + // T can be int, double, bool, string, picojson::array + if (json_obj.count(field)) { + if (!json_obj.at(field).is()) { + *err += "Field " + field + " is not of type " + typeid(T).name() + "\n"; + return false; + } + value = json_obj.at(field).get(); + } else { + if (required) { + *err += "Field " + field + " is required\n"; + return false; + } + } + return true; +} + /*! * \brief Lookup a JSON object by a key, and convert it to a given type. * \param json The JSON object to look up. diff --git a/cpp/metadata/model.h b/cpp/metadata/model.h index 7a3224d28e..2472cb7d36 100644 --- a/cpp/metadata/model.h +++ b/cpp/metadata/model.h @@ -5,6 +5,7 @@ #ifndef MLC_LLM_CPP_MODEL_METADATA_H_ #define MLC_LLM_CPP_MODEL_METADATA_H_ +#include #include #include #include @@ -12,13 +13,6 @@ #include -// Forward declare picojson's value, object and array -namespace picojson { -class value; -using object = std::unordered_map; -using array = std::vector; -} // namespace picojson - namespace mlc { namespace llm { diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 3465de402e..5d647ec532 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -5,9 +5,12 @@ #include "config.h" #include +#include #include +#include "../json_ffi/openai_api_protocol.h" +#include "../metadata/json_parser.h" #include "data.h" namespace mlc { @@ -158,6 +161,24 @@ GenerationConfig::GenerationConfig(String config_json_str) { data_ = std::move(n); } +Optional GenerationConfig::FromJSON(const std::string& json_str, + std::string* err) { + std::optional json_obj = json::LoadJSONFromString(json_str, err); + if (!err->empty() || !json_obj.has_value()) { + return NullOpt; + } + ObjectPtr n = make_object(); + + // TODO(mlc-team): Pass the parameters from `json_obj` to `n`. 
+ + if (!err->empty()) { + return NullOpt; + } + GenerationConfig gen_config; + gen_config.data_ = std::move(n); + return gen_config; +} + String GenerationConfigNode::AsJSONString() const { picojson::object config; config["n"] = picojson::value(static_cast(this->n)); @@ -202,123 +223,43 @@ String GenerationConfigNode::AsJSONString() const { return picojson::value(config).serialize(true); } -/****************** KVCacheConfig ******************/ - -TVM_REGISTER_OBJECT_TYPE(KVCacheConfigNode); - -KVCacheConfig::KVCacheConfig(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size) { - ObjectPtr n = make_object(); - n->page_size = page_size; +/****************** EngineConfig ******************/ + +TVM_REGISTER_OBJECT_TYPE(EngineConfigNode); + +EngineConfig::EngineConfig(String model, String model_lib_path, Array additional_models, + Array additional_model_lib_paths, DLDevice device, + int kv_cache_page_size, int max_num_sequence, + int max_total_sequence_length, int max_single_sequence_length, + int prefill_chunk_size, SpeculativeMode speculative_mode, + int spec_draft_length) { + ObjectPtr n = make_object(); + n->model = std::move(model); + n->model_lib_path = std::move(model_lib_path); + n->additional_models = std::move(additional_models); + n->additional_model_lib_paths = std::move(additional_model_lib_paths); + n->device = device; + n->kv_cache_page_size = kv_cache_page_size; n->max_num_sequence = max_num_sequence; n->max_total_sequence_length = max_total_sequence_length; + n->max_single_sequence_length = max_single_sequence_length; n->prefill_chunk_size = prefill_chunk_size; - data_ = std::move(n); -} - -KVCacheConfig::KVCacheConfig(const std::string& config_str, int max_single_sequence_length) { - int page_size; - int max_total_sequence_length; - int max_num_sequence = -1; - int prefill_chunk_size; - - picojson::value config_json; - std::string err = picojson::parse(config_json, config_str); - if (!err.empty()) { - LOG(FATAL) << err; - } - - // Get json fields. 
- picojson::object config = config_json.get(); - if (config.count("page_size")) { - CHECK(config["page_size"].is()); - page_size = config["page_size"].get(); - CHECK_EQ(page_size, 16) << "KV cache page size other than 16 is not supported."; - } else { - LOG(FATAL) << "Key \"page_size\" not found."; - } - if (config.count("max_total_sequence_length")) { - CHECK(config["max_total_sequence_length"].is()); - max_total_sequence_length = config["max_total_sequence_length"].get(); - } else { - LOG(FATAL) << "Key \"max_total_sequence_length\" not found."; - } - if (config.count("prefill_chunk_size")) { - CHECK(config["prefill_chunk_size"].is()); - prefill_chunk_size = config["prefill_chunk_size"].get(); - } else { - LOG(FATAL) << "Key \"prefill_chunk_size\" not found."; - } - if (config.count("max_num_sequence")) { - CHECK(config["max_num_sequence"].is()); - max_num_sequence = config["max_num_sequence"].get(); - CHECK_GT(max_num_sequence, 0) << "Max number of sequence should be positive."; - } else { - LOG(FATAL) << "Key \"max_num_sequence\" not found."; - } - - ObjectPtr n = make_object(); - n->page_size = page_size; - n->max_num_sequence = max_num_sequence; - n->max_total_sequence_length = max_total_sequence_length; - n->prefill_chunk_size = prefill_chunk_size; - data_ = std::move(n); -} - -String KVCacheConfigNode::AsJSONString() const { - picojson::object config; - config["page_size"] = picojson::value(static_cast(this->page_size)); - config["max_num_sequence"] = picojson::value(static_cast(this->max_num_sequence)); - config["max_total_sequence_length"] = - picojson::value(static_cast(this->max_total_sequence_length)); - config["prefill_chunk_size"] = picojson::value(static_cast(this->prefill_chunk_size)); - return picojson::value(config).serialize(true); -} - -/****************** EngineMode ******************/ - -TVM_REGISTER_OBJECT_TYPE(EngineModeNode); - -EngineMode::EngineMode(bool enable_speculative, int spec_draft_length) { - ObjectPtr n = make_object(); - n->enable_speculative = enable_speculative; - n->spec_draft_length = spec_draft_length; - data_ = std::move(n); -} - -EngineMode::EngineMode(const std::string& config_str) { - bool enable_speculative = false; - int spec_draft_length = 4; - - picojson::value config_json; - std::string err = picojson::parse(config_json, config_str); - if (!err.empty()) { - LOG(FATAL) << err; - } - - // Get json fields. 
- picojson::object config = config_json.get(); - if (config.count("enable_speculative")) { - CHECK(config["enable_speculative"].is()); - enable_speculative = config["enable_speculative"].get(); - } - if (config.count("spec_draft_length")) { - CHECK(config["spec_draft_length"].is()); - spec_draft_length = config["spec_draft_length"].get(); - } - - ObjectPtr n = make_object(); - n->enable_speculative = enable_speculative; n->spec_draft_length = spec_draft_length; + n->speculative_mode = speculative_mode; data_ = std::move(n); } -String EngineModeNode::AsJSONString() const { - picojson::object config; - config["enable_speculative"] = picojson::value(static_cast(this->enable_speculative)); - config["spec_draft_length"] = picojson::value(static_cast(this->spec_draft_length)); - return picojson::value(config).serialize(true); -} +TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig") + .set_body_typed([](String model, String model_lib_path, Array additional_models, + Array additional_model_lib_paths, DLDevice device, + int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, + int max_single_sequence_length, int prefill_chunk_size, int speculative_mode, + int spec_draft_length) { + return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models), + std::move(additional_model_lib_paths), device, kv_cache_page_size, + max_num_sequence, max_total_sequence_length, max_single_sequence_length, + prefill_chunk_size, SpeculativeMode(speculative_mode), spec_draft_length); + }); } // namespace serve } // namespace llm diff --git a/cpp/serve/config.h b/cpp/serve/config.h index c406e55125..404566fe2c 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -9,6 +9,8 @@ #include #include +#include + namespace mlc { namespace llm { namespace serve { @@ -57,62 +59,89 @@ class GenerationConfig : public ObjectRef { public: explicit GenerationConfig(String config_json_str); + /*! + * \brief Parse the generation config from the given JSON string. + * When parsing fails, errors are dumped to the input error string, and NullOpt is returned. + */ + static Optional FromJSON(const std::string& json_str, std::string* err); + TVM_DEFINE_OBJECT_REF_METHODS(GenerationConfig, ObjectRef, GenerationConfigNode); }; -/****************** KV Cache config ******************/ +/****************** Engine config ******************/ -/*! \brief The configuration of paged KV cache. */ -class KVCacheConfigNode : public Object { +/*! \brief The speculative mode. */ +enum class SpeculativeMode : int { + /*! \brief Disable speculative decoding. */ + kDisable = 0, + /*! \brief The normal speculative decoding (small draft) mode. */ + kSmallDraft = 1, + /*! \brief The eagle-style speculative decoding. */ + kEagle = 2, +}; + +/*! \brief The configuration of engine execution config. */ +class EngineConfigNode : public Object { public: - int page_size; + /*************** Models ***************/ + + /*! \brief The path to the model directory. */ + String model; + /*! \brief The path to the model library. */ + String model_lib_path; + /*! \brief The path to the additional models' directories. */ + Array additional_models; + /*! \brief The path to the additional models' libraries. */ + Array additional_model_lib_paths; + + /*************** Device ***************/ + + /*! \brief The device where the models run. */ + DLDevice device; + + /*************** KV cache config and engine capacities ***************/ + + /*! \brief The number of consecutive tokens handled in each page in paged KV cache. 
*/ + int kv_cache_page_size; + /*! + * \brief The maximum number of sequences that are allowed to be + * processed by the KV cache at any time. + */ int max_num_sequence; + /*! \brief The maximum length allowed for a single sequence in the engine. */ int max_total_sequence_length; + /*! + * \brief The maximum total number of tokens whose KV data are allowed + * to exist in the KV cache at any time. + */ + int max_single_sequence_length; + /*! \brief The maximum total sequence length in a prefill. */ int prefill_chunk_size; - String AsJSONString() const; + /*************** Speculative decoding ***************/ - static constexpr const char* _type_key = "mlc.serve.KVCacheConfig"; - static constexpr const bool _type_has_method_sequal_reduce = false; - static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_BASE_OBJECT_INFO(KVCacheConfigNode, Object); -}; - -class KVCacheConfig : public ObjectRef { - public: - explicit KVCacheConfig(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size); - - explicit KVCacheConfig(const std::string& config_str, int max_single_sequence_length); - - TVM_DEFINE_OBJECT_REF_METHODS(KVCacheConfig, ObjectRef, KVCacheConfigNode); -}; - -/****************** Engine Mode ******************/ - -/*! \brief The configuration of engine execution mode. */ -class EngineModeNode : public Object { - public: - /* Whether the speculative decoding mode is enabled */ - bool enable_speculative; - /* The number of tokens to generate in speculative proposal (draft) */ - int spec_draft_length; + /*! \brief The speculative mode. */ + SpeculativeMode speculative_mode; + /*! \brief The number of tokens to generate in speculative proposal (draft). */ + int spec_draft_length = 4; String AsJSONString() const; - static constexpr const char* _type_key = "mlc.serve.EngineMode"; + static constexpr const char* _type_key = "mlc.serve.EngineConfig"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_BASE_OBJECT_INFO(EngineModeNode, Object); + TVM_DECLARE_BASE_OBJECT_INFO(EngineConfigNode, Object); }; -class EngineMode : public ObjectRef { +class EngineConfig : public ObjectRef { public: - explicit EngineMode(bool enable_speculative, int spec_draft_length); - - explicit EngineMode(const std::string& config_str); + explicit EngineConfig(String model, String model_lib_path, Array additional_models, + Array additional_model_lib_paths, DLDevice device, + int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, + int max_single_sequence_length, int prefill_chunk_size, + SpeculativeMode speculative_mode, int spec_draft_length); - TVM_DEFINE_OBJECT_REF_METHODS(EngineMode, ObjectRef, EngineModeNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode); }; } // namespace serve diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index abb5c7b6c7..85d1c66c2d 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -44,80 +44,101 @@ class EngineImpl : public Engine { public: /********************** Engine Management **********************/ - explicit EngineImpl(int max_single_sequence_length, const String& tokenizer_path, - const String& kv_cache_config_json_str, const String& engine_mode_json_str, - Optional request_stream_callback, - Optional trace_recorder, - const std::vector>& model_infos) { - CHECK_GE(model_infos.size(), 1) << "ValueError: No model is provided in the engine."; + explicit 
EngineImpl(EngineConfig engine_config, Optional request_stream_callback, + Optional trace_recorder) { // Step 1. Initialize metadata and singleton states inside the engine this->estate_->Reset(); // Being "-1" means there is no limit on single sequence length. - this->max_single_sequence_length_ = max_single_sequence_length != -1 - ? max_single_sequence_length - : std::numeric_limits::max(); - this->kv_cache_config_ = KVCacheConfig(kv_cache_config_json_str, max_single_sequence_length); - this->engine_mode_ = EngineMode(engine_mode_json_str); + if (engine_config->max_single_sequence_length == -1) { + engine_config->max_single_sequence_length = std::numeric_limits::max(); + } this->request_stream_callback_ = std::move(request_stream_callback); this->trace_recorder_ = trace_recorder; - this->tokenizer_ = Tokenizer::FromPath(tokenizer_path); + this->tokenizer_ = Tokenizer::FromPath(engine_config->model); this->token_table_ = tokenizer_->TokenTable(); this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_); // Step 2. Initialize each model independently. // Create the logit processor and sampler. this->models_.clear(); this->model_workspaces_.clear(); - for (const auto& model_info : model_infos) { - TVMArgValue model_lib = std::get<0>(model_info); - String model_path = std::get<1>(model_info); - DLDevice device = std::get<2>(model_info); - Model model = Model::Create(model_lib, std::move(model_path), device, - kv_cache_config_->max_num_sequence, + + auto f_create_model = [this, &engine_config, &trace_recorder](const String& model_path, + const String& model_lib_path) { + Model model = Model::Create(model_lib_path, std::move(model_path), engine_config->device, + engine_config->max_num_sequence, /*trace_enabled=*/trace_recorder.defined()); - model->CreateKVCache(this->kv_cache_config_); - CHECK_GE(model->GetMaxWindowSize(), this->max_single_sequence_length_) + model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence, + engine_config->max_total_sequence_length, + engine_config->prefill_chunk_size); + CHECK_GE(model->GetMaxWindowSize(), engine_config->max_single_sequence_length) << "The window size of the model, " << model->GetMaxWindowSize() << ", is smaller than the pre-defined max single sequence length, " - << this->max_single_sequence_length_; + << engine_config->max_single_sequence_length; this->models_.push_back(model); - this->model_workspaces_.push_back(ModelWorkspace{model->AllocEmbeddingTensor()}); + this->model_workspaces_.push_back( + ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); + }; + + f_create_model(engine_config->model, engine_config->model_lib_path); + CHECK_EQ(engine_config->additional_models.size(), + engine_config->additional_model_lib_paths.size()) + << "The additional model and lib path list has mismatched size."; + for (int i = 0; i < static_cast(engine_config->additional_models.size()); ++i) { + f_create_model(engine_config->additional_models[i], + engine_config->additional_model_lib_paths[i]); } - int max_num_tokens = kv_cache_config_->max_num_sequence; - if (engine_mode_->enable_speculative) { - max_num_tokens *= engine_mode_->spec_draft_length; + + int max_num_tokens = engine_config->max_num_sequence; + if (engine_config->speculative_mode != SpeculativeMode::kDisable) { + max_num_tokens *= engine_config->spec_draft_length; } LogitProcessor logit_processor = this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); Sampler sampler = 
this->models_[0]->CreateSampler( max_num_tokens, static_cast(this->models_.size()), trace_recorder); // Step 3. Initialize engine actions that represent state transitions. - if (this->engine_mode_->enable_speculative) { + if (engine_config->speculative_mode != SpeculativeMode::kDisable) { // Speculative decoding is only possible for more than one model. ICHECK_GT(this->models_.size(), 1U); - this->actions_ = { - EngineAction::NewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - this->kv_cache_config_, // - this->engine_mode_, // - this->trace_recorder_), - EngineAction::BatchDraft(this->models_, logit_processor, sampler, this->trace_recorder_, - this->engine_mode_->spec_draft_length), - EngineAction::BatchVerify(this->models_, logit_processor, sampler, this->kv_cache_config_, - this->trace_recorder_)}; + switch (engine_config->speculative_mode) { + case SpeculativeMode::kEagle: + this->actions_ = { + EngineAction::EagleNewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + engine_config, // + this->trace_recorder_), + EngineAction::EagleBatchDraft(this->models_, logit_processor, sampler, + this->model_workspaces_, this->trace_recorder_), + EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, + this->model_workspaces_, engine_config, + this->trace_recorder_)}; + break; + default: + this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + engine_config, // + this->trace_recorder_), + EngineAction::BatchDraft(this->models_, logit_processor, sampler, + this->trace_recorder_), + EngineAction::BatchVerify(this->models_, logit_processor, sampler, + engine_config, this->trace_recorder_)}; + } } else { this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // logit_processor, // sampler, // this->model_workspaces_, // - this->kv_cache_config_, // - this->engine_mode_, // + engine_config, // this->trace_recorder_), EngineAction::BatchDecode(this->models_, logit_processor, sampler, this->trace_recorder_)}; } // Step 4. Automatically set the threading backend max concurrency. + this->engine_config_ = engine_config; SetThreadMaxConcurrency(); } @@ -146,7 +167,7 @@ class EngineImpl : public Engine { request = Request::FromUntokenized(request, tokenizer_); ICHECK_NE(request->input_total_length, -1); - if (request->input_total_length >= max_single_sequence_length_) { + if (request->input_total_length >= engine_config_->max_single_sequence_length) { // If the request input length exceeds the maximum allowed single sequence length, // invoke callback and do not process the request. Array output{RequestStreamOutput( @@ -230,7 +251,8 @@ class EngineImpl : public Engine { Array processed_requests = action->Step(estate_); if (!processed_requests.empty()) { ActionStepPostProcess(processed_requests, estate_, models_, tokenizer_, - request_stream_callback_.value(), max_single_sequence_length_); + request_stream_callback_.value(), + engine_config_->max_single_sequence_length); return; } } @@ -239,6 +261,13 @@ class EngineImpl : public Engine { "action (e.g. prefill, decode, etc.) but it does not."; } + /************** Debug/Profile **************/ + + void DebugCallFuncOnAllAllWorker(const String& func_name) final { + CHECK(!models_.empty()) << "There is no model running in Engine."; + models_[0]->DebugCallFuncOnAllAllWorker(func_name); + } + private: /*! \brief Set the maximum threading backend concurrency. 
*/ void SetThreadMaxConcurrency() { @@ -247,8 +276,8 @@ class EngineImpl : public Engine { host_cpu_usage += model->EstimateHostCPURequirement(); } int max_concurrency = tvm::runtime::threading::MaxConcurrency(); - tvm::runtime::threading::SetMaxConcurrency(std::min( - std::max(max_concurrency - host_cpu_usage, 1), kv_cache_config_->max_num_sequence)); + tvm::runtime::threading::SetMaxConcurrency( + std::min(std::max(max_concurrency - host_cpu_usage, 1), engine_config_->max_num_sequence)); } /*! \brief Create a grammar init context according to the response format. If the response format @@ -268,9 +297,7 @@ class EngineImpl : public Engine { // Engine state, managing requests and request states. EngineState estate_; // Configurations and singletons - KVCacheConfig kv_cache_config_; - EngineMode engine_mode_; - int max_single_sequence_length_; + EngineConfig engine_config_; Tokenizer tokenizer_; std::vector token_table_; // Helper to get the grammar init context for requests. @@ -287,14 +314,11 @@ class EngineImpl : public Engine { Optional trace_recorder_; }; -std::unique_ptr Engine::Create( - int max_single_sequence_length, const String& tokenizer_path, - const String& kv_cache_config_json_str, const String& engine_mode_json_str, - Optional request_stream_callback, Optional trace_recorder, - const std::vector>& model_infos) { - return std::make_unique( - max_single_sequence_length, tokenizer_path, kv_cache_config_json_str, engine_mode_json_str, - request_stream_callback, std::move(trace_recorder), model_infos); +std::unique_ptr Engine::Create(EngineConfig engine_config, + Optional request_stream_callback, + Optional trace_recorder) { + return std::make_unique(std::move(engine_config), std::move(request_stream_callback), + std::move(trace_recorder)); } /*! 
\brief Clear global memory manager */ @@ -305,48 +329,10 @@ void ClearGlobalMemoryManager() { (*f)(); } -std::unique_ptr CreateEnginePacked(TVMArgs args) { - ClearGlobalMemoryManager(); - const int num_non_model_args = 6; - const int num_model_args = 4; - int num_models = (args.size() - num_non_model_args) / num_model_args; - int max_single_sequence_length; - std::string tokenizer_path; - std::string kv_cache_config_json_str; - std::string engine_mode_json_str; - Optional request_stream_callback; - Optional trace_recorder; - std::vector> model_infos; - model_infos.reserve(num_models); - try { - CHECK_LE(num_models * num_model_args + num_non_model_args, args.size()) - << "Incorrect number of arguments."; - max_single_sequence_length = args.At(0); - tokenizer_path = args.At(1); - kv_cache_config_json_str = args.At(2); - engine_mode_json_str = args.At(3); - request_stream_callback = args.At>(4); - trace_recorder = args.At>(5); - for (int i = 0; i < num_models; ++i) { - TVMArgValue model_lib = args[i * num_model_args + num_non_model_args]; - std::string model_path = args.At(i * num_model_args + num_non_model_args + 1); - DLDeviceType device_type = - static_cast(args.At(i * num_model_args + num_non_model_args + 2)); - int device_id = args.At(i * num_model_args + num_non_model_args + 3); - model_infos.emplace_back(model_lib, model_path, DLDevice{device_type, device_id}); - } - } catch (const dmlc::Error& e) { - LOG(FATAL) << "ValueError: " << e.what() << kEngineCreationErrorMessage; - } - return Engine::Create(max_single_sequence_length, tokenizer_path, kv_cache_config_json_str, - engine_mode_json_str, request_stream_callback, std::move(trace_recorder), - model_infos); -} - class EngineModule : public ModuleNode { public: TVM_MODULE_VTABLE_BEGIN("mlc.serve.engine"); - TVM_MODULE_VTABLE_ENTRY_PACKED("init", &EngineModule::InitPacked); + TVM_MODULE_VTABLE_ENTRY("init", &EngineModule::Init); TVM_MODULE_VTABLE_ENTRY("add_request", &EngineModule::AddRequest); TVM_MODULE_VTABLE_ENTRY("abort_request", &EngineModule::Abort); TVM_MODULE_VTABLE_ENTRY("step", &EngineModule::Step); @@ -356,8 +342,12 @@ class EngineModule : public ModuleNode { TVM_MODULE_VTABLE_ENTRY("set_request_stream_callback", &EngineModule::SetRequestStreamCallback); TVM_MODULE_VTABLE_END(); - void InitPacked(TVMArgs args, TVMRetValue* rv) { this->engine_ = CreateEnginePacked(args); } - + /*! \brief Initialize the engine with config and other fields. */ + void Init(EngineConfig engine_config, Optional request_stream_callback, + Optional trace_recorder) { + this->engine_ = Engine::Create(std::move(engine_config), std::move(request_stream_callback), + std::move(trace_recorder)); + } /*! \brief Construct an EngineModule. */ static tvm::runtime::Module Create() { return Module(make_object()); } /*! \brief Redirection to `Engine::AddRequest`. */ diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index 9ff38bdc42..fc5e4205ae 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -11,6 +11,7 @@ #include "data.h" #include "event_trace_recorder.h" #include "request.h" +#include "request_state.h" namespace mlc { namespace llm { @@ -49,26 +50,14 @@ class Engine { /*! * \brief Create an engine in unique pointer. - * \param max_single_sequence_length The maximum allowed single - * sequence length supported by the engine. - * \param tokenizer_path The tokenizer path on disk. - * \param kv_cache_config_json_str The KV cache config in JSON string. - * \param engine_mode_json_str The Engine execution mode in JSON string. 
- * \param request_stream_callback The request stream callback function to - * stream back generated output for requests. + * \param engine_config The engine config. + * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. - * \param model_infos The model info tuples. Each tuple contains - * - the model library, which might be a path to the binary file or - * an executable module that is pre-loaded, - * - the path to the model weight parameters, - * - the device to run the model on. * \return The created Engine in pointer. */ - static std::unique_ptr Create( - int max_single_sequence_length, const String& tokenizer_path, - const String& kv_cache_config_json_str, const String& engine_mode_json_str, - Optional request_stream_callback, Optional trace_recorder, - const std::vector>& model_infos); + static std::unique_ptr Create(EngineConfig engine_config, + Optional request_stream_callback, + Optional trace_recorder); /*! \brief Reset the engine, clean up all running data and statistics. */ virtual void Reset() = 0; @@ -106,31 +95,12 @@ class Engine { * generation results for those finished requests. */ virtual void Step() = 0; -}; -/*! - * \brief Create an Engine from packed arguments in TVMArgs. - * \param args The arguments of engine construction. - * \return The constructed engine in unique pointer. - */ -std::unique_ptr CreateEnginePacked(TVMArgs args); - -constexpr const char* kEngineCreationErrorMessage = - "With `n` models, engine initialization " - "takes (6 + 4 * n) arguments. The first 6 arguments should be: " - "1) (int) maximum length of a sequence, which must be equal or smaller than the context " - "window size of each model; " - "2) (string) path to tokenizer configuration files, which in MLC LLM, usually in a model " - "weights directory; " - "3) (string) JSON configuration for the KVCache; " - "4) (string) JSON mode for Engine;" - "5) (packed function, optional) global request stream callback function. " - "6) (EventTraceRecorder, optional) the event trace recorder for requests." - "The following (4 * n) arguments, 4 for each model, should be: " - "1) (tvm.runtime.Module) The model library loaded into TVM's RelaxVM; " - "2) (string) Model path which includes weights and mlc-chat-config.json; " - "3) (int, enum DLDeviceType) Device type, e.g. CUDA, ROCm, etc; " - "4) (int) Device id, i.e. the ordinal index of the device that exists locally."; + /************** Debug/Profile **************/ + + /*! \brief Call the given global function on all workers. Only for debug purpose. */ + virtual void DebugCallFuncOnAllAllWorker(const String& func_name) = 0; +}; } // namespace serve } // namespace llm diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index e355168365..79359c5741 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -56,16 +56,31 @@ class EngineAction : public ObjectRef { * \param logit_processor The logit processor. * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. - * \param kv_cache_config The KV cache config to help decide prefill is doable. - * \param engine_mode The engine operation mode. + * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. 
*/ static EngineAction NewRequestPrefill(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineMode engine_mode, + EngineConfig engine_config, Optional trace_recorder); + /*! + * \brief Create the action that prefills requests in the `waiting_queue` + * of the engine state. + * \param models The models to run prefill in. + * \param logit_processor The logit processor. + * \param sampler The sampler to sample new tokens. + * \param model_workspaces The workspace of each model. + * \param engine_config The engine config. + * \param trace_recorder The event trace recorder for requests. + * \return The created action object. + */ + static EngineAction EagleNewRequestPrefill(Array models, LogitProcessor logit_processor, + Sampler sampler, + std::vector model_workspaces, + EngineConfig engine_config, + Optional trace_recorder); /*! * \brief Create the action that runs one-step decode for requests in the * `running_queue` of engine state. Preempt low-priority requests @@ -97,6 +112,23 @@ class EngineAction : public ObjectRef { Sampler sampler, Optional trace_recorder, int draft_length = 4); + /*! + * \brief Create the action that runs one-step speculative draft proposal for + * requests in the `running_queue` of engine state. Preempt low-priority requests + * accordingly when it is impossible to decode all the running requests. + * \param models The model to run decode in. When there are multiple + * models, the `Step` function of the created action will not take effect. + * \param sampler The sampler to sample new tokens. + * \param model_workspaces The workspace of each model. + * \param trace_recorder The event trace recorder for requests. + * \param draft_length The number of draft proposal rounds. + * \return The created action object. + */ + static EngineAction EagleBatchDraft(Array models, LogitProcessor logit_processor, + Sampler sampler, std::vector model_workspaces, + Optional trace_recorder, + int draft_length = 4); + /*! * \brief Create the action that runs one-step speculative verification for requests in the * `running_queue` of engine state. Preempt low-priority requests @@ -104,14 +136,32 @@ class EngineAction : public ObjectRef { * \param models The model to run decode in. When there are multiple * models, the `Step` function of the created action will not take effect. * \param sampler The sampler to sample new tokens. - * \param kv_cache_config The KV cache config to help decide verify is doable. + * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ static EngineAction BatchVerify(Array models, LogitProcessor logit_processor, - Sampler sampler, KVCacheConfig kv_cache_config, + Sampler sampler, EngineConfig engine_config, Optional trace_recorder); + /*! + * \brief Create the action that runs one-step speculative verification for requests in the + * `running_queue` of engine state. Preempt low-priority requests + * accordingly when it is impossible to decode all the running requests. + * \param models The model to run decode in. When there are multiple + * models, the `Step` function of the created action will not take effect. + * \param sampler The sampler to sample new tokens. + * \param model_workspaces The workspace of each model. + * \param engine_config The engine config. + * \param trace_recorder The event trace recorder for requests. + * \return The created action object. 
+ */ + static EngineAction EagleBatchVerify(Array models, LogitProcessor logit_processor, + Sampler sampler, + std::vector model_workspaces, + EngineConfig engine_config, + Optional trace_recorder); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineAction, ObjectRef, EngineActionObj); }; diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index fc830a21ee..94e441279a 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -59,6 +59,8 @@ class BatchDecodeActionObj : public EngineActionObj { // NOTE: Right now we only support decode all the running request states at a time. int num_rsentries = running_rsentries.size(); + ICHECK_GT(num_rsentries, 0) + << "There should be at least one request state entry that can run decode"; // Collect // - the last committed token, // - the request id, diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 9270b6d284..6f38292ba3 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -27,12 +27,12 @@ namespace serve { class BatchVerifyActionObj : public EngineActionObj { public: explicit BatchVerifyActionObj(Array models, LogitProcessor logit_processor, - Sampler sampler, KVCacheConfig kv_cache_config, + Sampler sampler, EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), - kv_cache_config_(std::move(kv_cache_config)), + engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)), rng_(RandomGenerator::GetInstance()) {} @@ -182,8 +182,8 @@ class BatchVerifyActionObj : public EngineActionObj { num_page_requirement.reserve(running_rsentries.size()); for (const RequestStateEntry& rsentry : running_rsentries) { int draft_length = rsentry->mstates[draft_model_id_]->draft_output_tokens.size(); - int num_require_pages = - (draft_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + int num_require_pages = (draft_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; draft_lengths.push_back(draft_length); num_page_requirement.push_back(num_require_pages); total_draft_length += draft_length; @@ -218,8 +218,8 @@ class BatchVerifyActionObj : public EngineActionObj { LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; - /*! \brief The kv cache config. */ - KVCacheConfig kv_cache_config_; + /*! \brief The engine config. */ + EngineConfig engine_config_; /*! \brief Event trace recorder. */ Optional trace_recorder_; /*! \brief Random number generator. */ @@ -231,10 +231,10 @@ class BatchVerifyActionObj : public EngineActionObj { }; EngineAction EngineAction::BatchVerify(Array models, LogitProcessor logit_processor, - Sampler sampler, KVCacheConfig kv_cache_config, + Sampler sampler, EngineConfig engine_config, Optional trace_recorder) { return EngineAction(make_object( - std::move(models), std::move(logit_processor), std::move(sampler), std::move(kv_cache_config), + std::move(models), std::move(logit_processor), std::move(sampler), std::move(engine_config), std::move(trace_recorder))); } diff --git a/cpp/serve/engine_actions/eagle_batch_draft.cc b/cpp/serve/engine_actions/eagle_batch_draft.cc new file mode 100644 index 0000000000..50393c38a2 --- /dev/null +++ b/cpp/serve/engine_actions/eagle_batch_draft.cc @@ -0,0 +1,230 @@ +/*! 
+ * Copyright (c) 2023 by Contributors + * \file serve/engine_actions/eagle_batch_draft.cc + */ + +#include + +#include "../config.h" +#include "../model.h" +#include "../sampler/sampler.h" +#include "action.h" +#include "action_commons.h" + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief The action that runs draft proposal for requests in the + * `running_queue` of engine state. Preempt low-priority requests + * accordingly when it is impossible to decode all the running requests. + */ +class EagleBatchDraftActionObj : public EngineActionObj { + public: + explicit EagleBatchDraftActionObj(Array models, LogitProcessor logit_processor, + Sampler sampler, std::vector model_workspaces, + Optional trace_recorder, int draft_length) + : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), + sampler_(std::move(sampler)), + model_workspaces_(std::move(model_workspaces)), + trace_recorder_(std::move(trace_recorder)), + draft_length_(draft_length) { + ICHECK_GT(draft_length_, 0); + } + + Array Step(EngineState estate) final { + // - Only run spec decode when there are two models (llm+ssm) and >=1 running requests. + if (models_.size() != 2 || estate->running_queue.empty()) { + return {}; + } + + // Preempt request state entries when decode cannot apply. + std::vector running_rsentries = GetRunningRequestStateEntries(estate); + while (!CanDecode(running_rsentries.size())) { + RequestStateEntry preempted = + PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + if (preempted.same_as(running_rsentries.back())) { + running_rsentries.pop_back(); + } + } + + auto tstart = std::chrono::high_resolution_clock::now(); + + int num_rsentries = running_rsentries.size(); + Array request_ids; + std::vector request_internal_ids; + Array generation_cfg; + std::vector rngs; + request_ids.reserve(num_rsentries); + request_internal_ids.reserve(num_rsentries); + generation_cfg.reserve(num_rsentries); + for (const RequestStateEntry& rsentry : running_rsentries) { + request_ids.push_back(rsentry->request->id); + request_internal_ids.push_back(rsentry->mstates[0]->internal_id); + generation_cfg.push_back(rsentry->request->generation_cfg); + rngs.push_back(&rsentry->rng); + } + + // The first model doesn't get involved in draft proposal. + for (int model_id = 1; model_id < static_cast(models_.size()); ++model_id) { + // Collect + // - the last committed token, + // - the request model state + // of each request. + std::vector input_tokens; + Array mstates; + input_tokens.reserve(num_rsentries); + mstates.reserve(num_rsentries); + for (const RequestStateEntry& rsentry : running_rsentries) { + mstates.push_back(rsentry->mstates[model_id]); + } + // draft_length_ rounds of draft proposal. 
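An aside on the preemption loop near the top of this `Step` method: draft rounds only start once every remaining running entry can claim a free KV-cache page on each draft model. A minimal sketch of that shrink-until-it-fits check with plain integers (names are illustrative, not the engine API):

```
#include <iostream>
#include <vector>

// Sketch: shrink the running batch until each remaining entry can take one more
// decode step, i.e. batch size <= free pages on every draft model.
// `free_pages_per_draft_model` stands in for models_[1..n]->GetNumAvailablePages().
int BatchSizeAfterPreemption(int num_running_entries,
                             const std::vector<int>& free_pages_per_draft_model) {
  auto can_decode = [&](int batch) {
    for (int free_pages : free_pages_per_draft_model) {
      if (batch > free_pages) return false;  // one new page needed per entry
    }
    return true;
  };
  while (num_running_entries > 0 && !can_decode(num_running_entries)) {
    --num_running_entries;  // stands in for preempting the last running entry
  }
  return num_running_entries;
}

int main() {
  // Two draft models with 5 and 3 free pages: only 3 entries can draft this step.
  std::cout << BatchSizeAfterPreemption(6, {5, 3}) << "\n";  // prints 3
}
```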
+ NDArray hidden_states_nd{nullptr}; + ObjectRef last_hidden_states{nullptr}; + ObjectRef hidden_states = model_workspaces_[model_id].hidden_states; + // Concat last hidden_states + std::vector previous_hidden_on_device; + for (int i = 0; i < num_rsentries; ++i) { + previous_hidden_on_device.push_back(mstates[i]->draft_last_hidden_on_device.back()); + } + hidden_states_nd = + models_[model_id]->ConcatLastHidden(previous_hidden_on_device, &hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 2); + ICHECK_EQ(hidden_states_nd->shape[0], num_rsentries); + hidden_states_nd = hidden_states_nd.CreateView( + {hidden_states_nd->shape[0], 1, hidden_states_nd->shape[1]}, hidden_states_nd->dtype); + last_hidden_states = hidden_states_nd; + // The first draft token has been generated in prefill/verify stage + for (int draft_id = 1; draft_id < draft_length_; ++draft_id) { + // prepare new input tokens + input_tokens.clear(); + for (int i = 0; i < num_rsentries; ++i) { + ICHECK(!mstates[i]->draft_output_tokens.empty()); + input_tokens.push_back(mstates[i]->draft_output_tokens.back().sampled_token_id.first); + } + + // - Compute embeddings. + RECORD_EVENT(trace_recorder_, request_ids, "start proposal embedding"); + ObjectRef embeddings = + models_[model_id]->TokenEmbed({IntTuple{input_tokens.begin(), input_tokens.end()}}); + RECORD_EVENT(trace_recorder_, request_ids, "finish proposal embedding"); + + // - Invoke model decode. + RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); + ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden( + embeddings, last_hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + hidden_states_nd = + models_[model_id]->BatchDecodeToLastHidden(fused_hidden_states, request_internal_ids); + last_hidden_states = hidden_states_nd; + NDArray logits; + if (models_[model_id]->CanGetLogits()) { + logits = models_[model_id]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, + /*seq_len*/ 1); + } else { + // - Use base model's head. + logits = + models_[0]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + } + RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); + ICHECK_EQ(logits->ndim, 3); + ICHECK_EQ(logits->shape[0], num_rsentries); + ICHECK_EQ(logits->shape[1], 1); + + // - Update logits. + logits = logits.CreateView({num_rsentries, logits->shape[2]}, logits->dtype); + logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); + + // - Compute probability distributions. + NDArray probs_on_device = + logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids); + + // - Sample tokens. + // Fill range [0, num_rsentries) into `sample_indices`. + std::vector sample_indices(num_rsentries); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + std::vector prob_dist; + std::vector sample_results = sampler_->BatchSampleTokens( + probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + ICHECK_EQ(sample_results.size(), num_rsentries); + + // - Add draft token to the state. 
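The loop that follows slices one token's hidden vector out of the batched hidden-state tensor before recording each draft token. A rough, self-contained sketch of that row copy, assuming a contiguous fp16 buffer modelled as `uint16_t` with shape [num_tokens, hidden_size] (an assumption for illustration, not the NDArray layout used above):

```
#include <cstdint>
#include <vector>

// Sketch: copy the hidden vector of token `token_pos` out of a contiguous
// [num_tokens, hidden_size] fp16 buffer (each element modelled as uint16_t).
std::vector<uint16_t> SliceTokenHidden(const std::vector<uint16_t>& hidden_states,
                                       int64_t hidden_size, int64_t token_pos) {
  const uint16_t* row = hidden_states.data() + token_pos * hidden_size;
  return std::vector<uint16_t>(row, row + hidden_size);
}

int main() {
  std::vector<uint16_t> h = {1, 2, 3, 4, 5, 6};            // 3 tokens, hidden_size = 2
  std::vector<uint16_t> last = SliceTokenHidden(h, 2, 2);  // {5, 6}
  return last.size() == 2 ? 0 : 1;
}
```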
+ for (int i = 0; i < num_rsentries; ++i) { + // - Slice hidden_states_for_sample + NDArray last_hidden_on_device = GetTokenHidden(hidden_states_nd, i); + CHECK(i < static_cast(prob_dist.size())); + CHECK(prob_dist[i].defined()); + mstates[i]->AddDraftToken(sample_results[i], prob_dist[i], last_hidden_on_device); + estate->stats.total_draft_length += 1; + } + } + } + + auto tend = std::chrono::high_resolution_clock::now(); + estate->stats.engine_total_decode_time += static_cast((tend - tstart).count()) / 1e9; + + return {}; + } + + private: + /*! \brief Check if the input requests can be decoded under conditions. */ + bool CanDecode(int num_rsentries) { + // The first model is not involved in draft proposal. + for (int model_id = 1; model_id < static_cast(models_.size()); ++model_id) { + // Check if the model has enough available pages. + int num_available_pages = models_[model_id]->GetNumAvailablePages(); + if (num_rsentries > num_available_pages) { + return false; + } + } + return true; + } + + /*! + * \brief Get one item from a hidden_states array, which corresponds to the last token. + * \param hidden_states The hidden_states of all the tokens. + * \param token_pos The desired token position in the sequence. + * \return The desired token's hidden_states + */ + NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { + ICHECK_EQ(hidden_states->ndim, 3); + NDArray last_hidden_on_device = + NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); + + int64_t ndata = hidden_states->shape[2]; + const int16_t* __restrict p_hidden = + static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + + (token_pos * ndata); + + last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); + return last_hidden_on_device; + } + + /*! \brief The model to run draft generation in speculative decoding. */ + Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; + /*! \brief The sampler to sample new tokens. */ + Sampler sampler_; + /*! \brief Workspace of each model. */ + std::vector model_workspaces_; + /*! \brief Event trace recorder. */ + Optional trace_recorder_; + /*! \brief Draft proposal length */ + int draft_length_; +}; + +EngineAction EngineAction::EagleBatchDraft(Array models, LogitProcessor logit_processor, + Sampler sampler, + std::vector model_workspaces, + Optional trace_recorder, + int draft_length) { + return EngineAction(make_object( + std::move(models), std::move(logit_processor), std::move(sampler), + std::move(model_workspaces), std::move(trace_recorder), draft_length)); +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc new file mode 100644 index 0000000000..043f68b9c2 --- /dev/null +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -0,0 +1,364 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_actions/eagle_batch_verify.cc + */ + +#include + +#include +#include +#include + +#include "../../random.h" +#include "../config.h" +#include "../model.h" +#include "../sampler/sampler.h" +#include "action.h" +#include "action_commons.h" + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief The action that runs verification for requests in the + * `running_queue` of engine state. Preempt low-priority requests + * accordingly when it is impossible to decode all the running requests. 
+ */ +class EagleBatchVerifyActionObj : public EngineActionObj { + public: + explicit EagleBatchVerifyActionObj(Array models, LogitProcessor logit_processor, + Sampler sampler, std::vector model_workspaces, + EngineConfig engine_config, + Optional trace_recorder) + : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), + sampler_(std::move(sampler)), + model_workspaces_(std::move(model_workspaces)), + engine_config_(std::move(engine_config)), + trace_recorder_(std::move(trace_recorder)), + rng_(RandomGenerator::GetInstance()) {} + + Array Step(EngineState estate) final { + // - Only run spec decode when there are two models (llm+ssm) and >=1 running requests. + if (models_.size() != 2 || estate->running_queue.empty()) { + return {}; + } + + const auto& [rsentries, draft_lengths, total_draft_length] = GetDraftsToVerify(estate); + ICHECK_EQ(rsentries.size(), draft_lengths.size()); + if (rsentries.empty()) { + return {}; + } + + int num_rsentries = rsentries.size(); + Array request_ids = + rsentries.Map([](const RequestStateEntry& rstate) { return rstate->request->id; }); + auto tstart = std::chrono::high_resolution_clock::now(); + + // - Get embedding and run verify. + std::vector request_internal_ids; + std::vector all_tokens_to_verify; + Array verify_request_mstates; + Array generation_cfg; + std::vector rngs; + std::vector> draft_output_tokens; + std::vector> draft_output_prob_dist; + request_internal_ids.reserve(num_rsentries); + all_tokens_to_verify.reserve(total_draft_length); + verify_request_mstates.reserve(num_rsentries); + rngs.reserve(num_rsentries); + generation_cfg.reserve(num_rsentries); + draft_output_tokens.reserve(num_rsentries); + draft_output_prob_dist.reserve(num_rsentries); + + for (int i = 0; i < num_rsentries; ++i) { + RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_]; + RequestModelState draft_mstate = rsentries[i]->mstates[draft_model_id_]; + request_internal_ids.push_back(verify_mstate->internal_id); + ICHECK(!draft_lengths.empty()); + ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_tokens.size()); + ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_prob_dist.size()); + // the last committed token + all the draft tokens but the last one. + all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().sampled_token_id.first); + for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()); ++j) { + all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].sampled_token_id.first); + } + verify_request_mstates.push_back(verify_mstate); + generation_cfg.push_back(rsentries[i]->request->generation_cfg); + rngs.push_back(&rsentries[i]->rng); + draft_output_tokens.push_back(draft_mstate->draft_output_tokens); + CHECK(draft_mstate->draft_output_prob_dist[0]->device.device_type == kDLCPU); + draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist); + } + + std::vector cum_verify_lengths = {0}; + cum_verify_lengths.reserve(num_rsentries + 1); + std::vector verify_lengths; + for (int i = 0; i < num_rsentries; ++i) { + // Add one committed token. 
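A small editorial sketch of how the verify batch assembled in this hunk is laid out, with plain `std::vector` standing in for the engine's request-state types (the helper below is hypothetical): each request contributes its last committed token followed by all of its draft tokens, so its verify length is its draft length plus one, and per-request slices are addressed through the prefix sum built in the next lines.

```
#include <vector>

struct VerifyBatch {
  std::vector<int> tokens;              // flattened tokens to verify, all requests
  std::vector<int> verify_lengths;      // per-request verify length
  std::vector<int> cum_verify_lengths;  // prefix sum; request i owns [cum[i], cum[i+1])
};

// Sketch: committed.back() plus every draft token for each request.
VerifyBatch BuildVerifyBatch(const std::vector<int>& last_committed,
                             const std::vector<std::vector<int>>& draft_tokens) {
  VerifyBatch batch;
  batch.cum_verify_lengths.push_back(0);
  for (int i = 0; i < static_cast<int>(last_committed.size()); ++i) {
    batch.tokens.push_back(last_committed[i]);
    batch.tokens.insert(batch.tokens.end(), draft_tokens[i].begin(), draft_tokens[i].end());
    batch.verify_lengths.push_back(static_cast<int>(draft_tokens[i].size()) + 1);
    batch.cum_verify_lengths.push_back(batch.cum_verify_lengths.back() +
                                       batch.verify_lengths.back());
  }
  return batch;
}
```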
+ verify_lengths.push_back(draft_lengths[i] + 1); + cum_verify_lengths.push_back(cum_verify_lengths.back() + verify_lengths.back()); + } + + RECORD_EVENT(trace_recorder_, request_ids, "start verify embedding"); + ObjectRef embeddings = models_[verify_model_id_]->TokenEmbed( + {IntTuple{all_tokens_to_verify.begin(), all_tokens_to_verify.end()}}); + RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding"); + + RECORD_EVENT(trace_recorder_, request_ids, "start verify"); + ObjectRef fused_hidden_states = models_[verify_model_id_]->FuseEmbedHidden( + embeddings, NDArray(), 1, cum_verify_lengths[num_rsentries]); + NDArray hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden( + fused_hidden_states, request_internal_ids, verify_lengths); + ICHECK_EQ(hidden_states->ndim, 3); + ICHECK_EQ(hidden_states->shape[0], 1); + NDArray logits = + models_[verify_model_id_]->GetLogits(hidden_states, 1, cum_verify_lengths[num_rsentries]); + RECORD_EVENT(trace_recorder_, request_ids, "finish verify"); + ICHECK_EQ(logits->ndim, 3); + ICHECK_EQ(logits->shape[0], 1); + ICHECK_EQ(logits->shape[1], cum_verify_lengths[num_rsentries]); + + // - Update logits. + logits = + logits.CreateView({cum_verify_lengths[num_rsentries], logits->shape[2]}, logits->dtype); + logit_processor_->InplaceUpdateLogits(logits, generation_cfg, verify_request_mstates, + request_ids, &cum_verify_lengths, &draft_output_tokens); + + // - Compute probability distributions. + NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits( + logits, generation_cfg, request_ids, &cum_verify_lengths); + + std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokens( + probs_on_device, request_ids, cum_verify_lengths, generation_cfg, rngs, draft_output_tokens, + draft_output_prob_dist); + ICHECK_EQ(sample_results_arr.size(), num_rsentries); + + std::vector last_hidden_states; + for (int i = 0; i < num_rsentries; ++i) { + const std::vector& sample_results = sample_results_arr[i]; + int accept_length = sample_results.size(); + ICHECK_GE(accept_length, 1); + for (SampleResult sample_result : sample_results) { + rsentries[i]->mstates[verify_model_id_]->CommitToken(sample_result); + rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result); + } + estate->stats.total_accepted_length += accept_length - 1; + // - Minus one because the last draft token has no kv cache entry + // - Take max with 0 in case of all accepted. + int rollback_length = + std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0); + // rollback kv cache + // NOTE: when number of small models is more than 1 (in the future), + // it is possible to re-compute prefill for the small models. + if (rollback_length > 0) { + models_[verify_model_id_]->PopNFromKVCache( + rsentries[i]->mstates[verify_model_id_]->internal_id, rollback_length); + // Draft model rollback minus one because verify uses one more token. 
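The rollback arithmetic applied in this block is easy to get off by one, so here is a self-contained sketch of it under the same convention (illustrative names only): the verify model pops the rejected suffix from its KV cache, and the draft model pops one entry fewer because verification consumed one extra committed token that the draft model never cached.

```
#include <algorithm>
#include <cassert>

struct Rollback {
  int verify_model;  // entries to pop from the verify model's KV cache
  int draft_model;   // entries to pop from the draft model's KV cache
};

// Sketch: verify_length = draft_length + 1, and accept_length >= 1 always holds
// because at least one (re)sampled token is committed.
Rollback ComputeRollback(int verify_length, int accept_length) {
  int verify_rollback = std::max(verify_length - accept_length, 0);
  int draft_rollback = std::max(verify_rollback - 1, 0);
  return {verify_rollback, draft_rollback};
}

int main() {
  assert(ComputeRollback(/*verify_length=*/5, /*accept_length=*/5).verify_model == 0);
  assert(ComputeRollback(/*verify_length=*/5, /*accept_length=*/2).verify_model == 3);
  assert(ComputeRollback(/*verify_length=*/5, /*accept_length=*/2).draft_model == 2);
}
```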
+ models_[draft_model_id_]->PopNFromKVCache( + rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length - 1); + } + // clear the draft model state entries + rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(); + // - Slice hidden_states_for_sample + NDArray last_hidden_on_device = + GetTokenHidden(hidden_states, (cum_verify_lengths[i] + accept_length - 1)); + last_hidden_states.push_back(last_hidden_on_device); + } + + { + // One step draft for the following steps + NDArray hidden_states_nd{nullptr}; + ObjectRef next_hidden_states = model_workspaces_[draft_model_id_].hidden_states; + // Concat last hidden_states + hidden_states_nd = + models_[draft_model_id_]->ConcatLastHidden(last_hidden_states, &next_hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 2); + ICHECK_EQ(hidden_states_nd->shape[0], num_rsentries); + hidden_states_nd = hidden_states_nd.CreateView( + {hidden_states_nd->shape[0], 1, hidden_states_nd->shape[1]}, hidden_states_nd->dtype); + + std::vector input_tokens; + Array mstates; + input_tokens.reserve(num_rsentries); + mstates.reserve(num_rsentries); + for (const RequestStateEntry& rsentry : rsentries) { + mstates.push_back(rsentry->mstates[draft_model_id_]); + } + for (int i = 0; i < num_rsentries; ++i) { + ICHECK(!mstates[i]->committed_tokens.empty()); + input_tokens.push_back(mstates[i]->committed_tokens.back().sampled_token_id.first); + } + + // - Compute embeddings. + RECORD_EVENT(trace_recorder_, request_ids, "start proposal embedding"); + embeddings = models_[draft_model_id_]->TokenEmbed( + {IntTuple{input_tokens.begin(), input_tokens.end()}}); + RECORD_EVENT(trace_recorder_, request_ids, "finish proposal embedding"); + + // - Invoke model decode. + RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); + ObjectRef fused_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( + embeddings, hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + hidden_states_nd = models_[draft_model_id_]->BatchDecodeToLastHidden(fused_hidden_states, + request_internal_ids); + + if (models_[draft_model_id_]->CanGetLogits()) { + logits = models_[draft_model_id_]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, + /*seq_len*/ 1); + } else { + // - Use base model's head. + logits = + models_[0]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + } + RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); + ICHECK_EQ(logits->ndim, 3); + ICHECK_EQ(logits->shape[0], num_rsentries); + ICHECK_EQ(logits->shape[1], 1); + + // - Update logits. + logits = logits.CreateView({num_rsentries, logits->shape[2]}, logits->dtype); + logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); + + // - Compute probability distributions. + probs_on_device = + logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids); + + // - Sample tokens. + // Fill range [0, num_rsentries) into `sample_indices`. + std::vector sample_indices(num_rsentries); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + std::vector prob_dist; + std::vector sample_results = sampler_->BatchSampleTokens( + probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + ICHECK_EQ(sample_results.size(), num_rsentries); + + // - Add draft token to the state. 
+ for (int i = 0; i < num_rsentries; ++i) { + // - Slice hidden_states_for_sample + NDArray last_hidden_on_device = GetTokenHidden(hidden_states_nd, i); + CHECK(i < static_cast(prob_dist.size())); + CHECK(prob_dist[i].defined()); + mstates[i]->AddDraftToken(sample_results[i], prob_dist[i], last_hidden_on_device); + estate->stats.total_draft_length += 1; + } + } + + auto tend = std::chrono::high_resolution_clock::now(); + estate->stats.engine_total_decode_time += static_cast((tend - tstart).count()) / 1e9; + + return estate->running_queue; + } + + private: + struct DraftRequestStateEntries { + /*! \brief The request state entries to verify. */ + Array draft_rsentries; + /*! \brief The draft length of each request state. */ + std::vector draft_lengths; + /*! \brief The total draft length. */ + int total_draft_length; + }; + + /*! + * \brief Decide whether to run verify for the draft of each request. + * \param estate The engine state. + * \return The drafts to verify, together with their respective + * state and input length. + */ + DraftRequestStateEntries GetDraftsToVerify(EngineState estate) { + std::vector draft_lengths; + int total_draft_length = 0; + int total_required_pages = 0; + int num_available_pages = models_[verify_model_id_]->GetNumAvailablePages(); + + // Preempt the request state entries that cannot fit the large model for verification. + std::vector running_rsentries = GetRunningRequestStateEntries(estate); + std::vector num_page_requirement; + num_page_requirement.reserve(running_rsentries.size()); + for (const RequestStateEntry& rsentry : running_rsentries) { + int draft_length = rsentry->mstates[draft_model_id_]->draft_output_tokens.size(); + int num_require_pages = (draft_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; + draft_lengths.push_back(draft_length); + num_page_requirement.push_back(num_require_pages); + total_draft_length += draft_length; + total_required_pages += num_require_pages; + } + while (!CanVerify(total_required_pages)) { + RequestStateEntry preempted = + PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + if (preempted.same_as(running_rsentries.back())) { + total_draft_length -= draft_lengths.back(); + total_required_pages -= num_page_requirement.back(); + draft_lengths.pop_back(); + num_page_requirement.pop_back(); + running_rsentries.pop_back(); + } + } + + return {running_rsentries, draft_lengths, total_draft_length}; + } + + bool CanVerify(int num_required_pages) { + int num_available_pages = models_[0]->GetNumAvailablePages(); + return num_required_pages <= num_available_pages; + } + + /*! + * \brief Get one item from a hidden_states array, which corresponds to the last token. + * \param hidden_states The hidden_states of all the tokens. + * \param token_pos The desired token position in the sequence. + * \return The desired token's hidden_states + */ + NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { + ICHECK_EQ(hidden_states->ndim, 3); + NDArray last_hidden_on_device = + NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); + + int64_t ndata = hidden_states->shape[2]; + const int16_t* __restrict p_hidden = + static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + + (token_pos * ndata); + + last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); + return last_hidden_on_device; + } + + /*! + * \brief The model to run decode in. 
When there are multiple + * models, the `Step` function of the created action will not take effect. + */ + Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; + /*! \brief The sampler to sample new tokens. */ + Sampler sampler_; + /*! \brief Workspace of each model. */ + std::vector model_workspaces_; + /*! \brief The engine config. */ + EngineConfig engine_config_; + /*! \brief Event trace recorder. */ + Optional trace_recorder_; + /*! \brief Random number generator. */ + RandomGenerator& rng_; + /*! \brief The ids of verify/draft models. */ + const int verify_model_id_ = 0; + const int draft_model_id_ = 1; + const float eps_ = 1e-5; +}; + +EngineAction EngineAction::EagleBatchVerify(Array models, LogitProcessor logit_processor, + Sampler sampler, + std::vector model_workspaces, + EngineConfig engine_config, + Optional trace_recorder) { + return EngineAction(make_object( + std::move(models), std::move(logit_processor), std::move(sampler), + std::move(model_workspaces), std::move(engine_config), std::move(trace_recorder))); +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc new file mode 100644 index 0000000000..133c23e8a1 --- /dev/null +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -0,0 +1,601 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_actions/eagle_new_request_prefill.cc + */ + +#include + +#include "../config.h" +#include "../model.h" +#include "../sampler/sampler.h" +#include "action.h" +#include "action_commons.h" + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief The action that prefills requests in the `waiting_queue` of + * the engine state. + */ +class EagleNewRequestPrefillActionObj : public EngineActionObj { + public: + explicit EagleNewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, + Sampler sampler, + std::vector model_workspaces, + EngineConfig engine_config, + Optional trace_recorder) + : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), + sampler_(std::move(sampler)), + model_workspaces_(std::move(model_workspaces)), + engine_config_(std::move(engine_config)), + trace_recorder_(std::move(trace_recorder)) {} + + Array Step(EngineState estate) final { + // - Find the requests in `waiting_queue` that can prefill in this step. + std::vector prefill_inputs; + { + NVTXScopedRange nvtx_scope("NewRequestPrefill getting requests"); + prefill_inputs = GetRequestStateEntriesToPrefill(estate); + if (prefill_inputs.empty()) { + return {}; + } + } + + int num_rsentries = prefill_inputs.size(); + auto tstart = std::chrono::high_resolution_clock::now(); + + // - Update status of request states from pending to alive. 
+ Array request_ids; + std::vector rstates_of_entries; + std::vector status_before_prefill; + request_ids.reserve(num_rsentries); + rstates_of_entries.reserve(num_rsentries); + status_before_prefill.reserve(num_rsentries); + for (const PrefillInput& prefill_input : prefill_inputs) { + const RequestStateEntry& rsentry = prefill_input.rsentry; + const Request& request = rsentry->request; + RequestState request_rstate = estate->GetRequestState(request); + request_ids.push_back(request->id); + status_before_prefill.push_back(rsentry->status); + rsentry->status = RequestStateStatus::kAlive; + + if (status_before_prefill.back() == RequestStateStatus::kPending) { + // - Add the request to running queue if the request state + // status was pending and all its request states were pending. + bool alive_state_existed = false; + for (const RequestStateEntry& rsentry_ : request_rstate->entries) { + if (rsentry_->status == RequestStateStatus::kAlive && !rsentry_.same_as(rsentry)) { + alive_state_existed = true; + } + } + if (!alive_state_existed) { + estate->running_queue.push_back(request); + } + } + rstates_of_entries.push_back(std::move(request_rstate)); + } + + // - Get embedding and run prefill for each model. + std::vector prefill_lengths; + prefill_lengths.resize(/*size=*/num_rsentries, /*value=*/-1); + NDArray hidden_states_for_input{nullptr}; + NDArray hidden_states_for_sample{nullptr}; + NDArray logits_for_sample{nullptr}; + // A map used to record the entry and child_idx pair needed to fork sequence. + // The base model (id 0) should record all the pairs and all the small models + // fork sequences according to this map. + std::unordered_map> fork_rsentry_child_map; + for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { + std::vector request_internal_ids; + request_internal_ids.reserve(num_rsentries); + ObjectRef embeddings = model_workspaces_[model_id].embeddings; + int cum_prefill_length = 0; + bool single_input = + num_rsentries == 1 && prefill_inputs[0].rsentry->mstates[model_id]->inputs.size() == 1; + for (int i = 0; i < num_rsentries; ++i) { + const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; + RequestModelState mstate = rsentry->mstates[model_id]; + auto [input_data, input_length] = + ChunkPrefillInputData(mstate, prefill_inputs[i].max_prefill_length); + if (prefill_lengths[i] == -1) { + prefill_lengths[i] = input_length; + } else { + ICHECK_EQ(prefill_lengths[i], input_length); + } + + ICHECK(mstate->draft_output_tokens.empty()); + ICHECK(mstate->draft_output_prob_dist.empty()); + if (status_before_prefill[i] == RequestStateStatus::kPending) { + // Add the sequence to the model, or fork the sequence from its parent. + if (rsentry->parent_idx == -1) { + models_[model_id]->AddNewSequence(mstate->internal_id); + } else { + models_[model_id]->ForkSequence( + rstates_of_entries[i]->entries[rsentry->parent_idx]->mstates[model_id]->internal_id, + mstate->internal_id); + } + // Enable sliding window for the sequence if it is not a parent. + if (rsentry->child_indices.empty()) { + models_[model_id]->EnableSlidingWindowForSeq(mstate->internal_id); + } + } + request_internal_ids.push_back(mstate->internal_id); + RECORD_EVENT(trace_recorder_, prefill_inputs[i].rsentry->request->id, "start embedding"); + // Speculative models shift left the input tokens by 1 when base model has committed tokens. + // Note: for n > 1 cases Eagle doesn't work because parent entry doesn't shift input tokens. 
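As a gloss on the shift described in the two comments above, here is a small self-contained sketch with plain vectors (illustrative only, not the engine API): once the base model has committed a token, the draft model embeds the base model's input shifted left by one position and then appends the newest committed token, keeping both models aligned on sequence positions.

```
#include <vector>

// Sketch: build the draft (speculative) model's prefill tokens from the base
// model's prefill tokens. `last_committed` is the base model's newest committed
// token; a negative value means nothing has been committed yet (first prefill).
std::vector<int> ShiftForDraftModel(const std::vector<int>& base_tokens, int last_committed) {
  if (last_committed < 0) {
    return base_tokens;  // embed_offset == 0: no committed token, no shift
  }
  // embed_offset == 1: drop the first token, then append the committed one.
  std::vector<int> draft_tokens;
  if (!base_tokens.empty()) {
    draft_tokens.assign(base_tokens.begin() + 1, base_tokens.end());
  }
  draft_tokens.push_back(last_committed);
  return draft_tokens;
}
```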
+ int embed_offset = + prefill_inputs[i].rsentry->mstates[model_id]->committed_tokens.empty() ? 0 : 1; + for (int j = 0; j < static_cast(input_data.size()); ++j) { + if (j == static_cast(input_data.size()) - 1) { + std::vector tail_tokens; + TokenData tk_data = Downcast(input_data[j]); + CHECK(tk_data.defined()); + for (int k = embed_offset; k < static_cast(tk_data->token_ids.size()); ++k) { + tail_tokens.push_back(tk_data->token_ids[k]); + } + embeddings = models_[model_id]->TokenEmbed( + {tail_tokens.begin(), tail_tokens.end()}, + /*dst=*/!single_input ? &model_workspaces_[model_id].embeddings : nullptr, + /*offset=*/cum_prefill_length); + cum_prefill_length += input_data[j]->GetLength(); + cum_prefill_length -= embed_offset; + } else { + embeddings = input_data[i]->GetEmbedding( + models_[model_id], + /*dst=*/!single_input ? &model_workspaces_[model_id].embeddings : nullptr, + /*offset=*/cum_prefill_length); + cum_prefill_length += input_data[j]->GetLength(); + } + } + if (embed_offset > 0) { + std::vector new_tokens = {prefill_inputs[i] + .rsentry->mstates[model_id] + ->committed_tokens.back() + .sampled_token_id.first}; + embeddings = + models_[model_id]->TokenEmbed({new_tokens.begin(), new_tokens.end()}, + /*dst=*/&model_workspaces_[model_id].embeddings, + /*offset=*/cum_prefill_length); + cum_prefill_length += new_tokens.size(); + } + RECORD_EVENT(trace_recorder_, rsentry->request->id, "finish embedding"); + } + + RECORD_EVENT(trace_recorder_, request_ids, "start prefill"); + ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden( + embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length); + NDArray hidden_states = models_[model_id]->BatchPrefillToLastHidden( + fused_hidden_states, request_internal_ids, prefill_lengths); + RECORD_EVENT(trace_recorder_, request_ids, "finish prefill"); + ICHECK_EQ(hidden_states->ndim, 3); + ICHECK_EQ(hidden_states->shape[0], 1); + ICHECK_EQ(hidden_states->shape[1], cum_prefill_length); + + if (model_id == 0) { + // We only need to sample for model 0 in prefill. + hidden_states_for_input = hidden_states; + } + + // Whether to use base model to get logits. + int sample_model_id = !models_[model_id]->CanGetLogits() ? 0 : model_id; + hidden_states_for_sample = models_[sample_model_id]->BatchSelectLastHidden( + hidden_states, request_internal_ids, prefill_lengths); + logits_for_sample = + models_[sample_model_id]->GetLogits(hidden_states_for_sample, 1, num_rsentries); + ICHECK_EQ(hidden_states_for_sample->ndim, 3); + ICHECK_EQ(hidden_states_for_sample->shape[0], 1); + ICHECK_EQ(hidden_states_for_sample->shape[1], num_rsentries); + + // - Update logits. + ICHECK(logits_for_sample.defined()); + Array generation_cfg; + Array mstates_for_logitproc; + generation_cfg.reserve(num_rsentries); + mstates_for_logitproc.reserve(num_rsentries); + for (int i = 0; i < num_rsentries; ++i) { + generation_cfg.push_back(prefill_inputs[i].rsentry->request->generation_cfg); + mstates_for_logitproc.push_back(prefill_inputs[i].rsentry->mstates[sample_model_id]); + } + logits_for_sample = logits_for_sample.CreateView({num_rsentries, logits_for_sample->shape[2]}, + logits_for_sample->dtype); + logit_processor_->InplaceUpdateLogits(logits_for_sample, generation_cfg, + mstates_for_logitproc, request_ids); + + // - Compute probability distributions. + NDArray probs_on_device = + logit_processor_->ComputeProbsFromLogits(logits_for_sample, generation_cfg, request_ids); + + // - Sample tokens. 
+ // For prefill_inputs which have children, sample + // one token for each rstate that is depending. + // Otherwise, sample a token for the current rstate. + std::vector sample_indices; + std::vector rsentries_for_sample; + std::vector rngs; + std::vector rsentry_activated; + sample_indices.reserve(num_rsentries); + rsentries_for_sample.reserve(num_rsentries); + rngs.reserve(num_rsentries); + rsentry_activated.reserve(num_rsentries); + request_ids.clear(); + generation_cfg.clear(); + for (int i = 0; i < num_rsentries; ++i) { + const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; + int remaining_num_child_to_activate = prefill_inputs[i].num_child_to_activate; + for (int child_idx : rsentry->child_indices) { + // Only use base model to judge if we need to add child entries. + if (rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending && + (rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty() || + fork_rsentry_child_map[i].count(child_idx))) { + // If rstates_of_entries[i]->entries[child_idx] has no committed token, + // the prefill of the current rsentry will unblock + // rstates_of_entries[i]->entries[child_idx], + // and thus we want to sample a token for rstates_of_entries[i]->entries[child_idx]. + fork_rsentry_child_map[i].insert(child_idx); + sample_indices.push_back(i); + rsentries_for_sample.push_back(rstates_of_entries[i]->entries[child_idx]); + request_ids.push_back(rsentry->request->id); + generation_cfg.push_back(rsentry->request->generation_cfg); + rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng); + + // We only fork the first `num_child_to_activate` children. + // The children not being forked will be forked via later prefills. + // Usually `num_child_to_activate` is the same as the number of children. + // But it can be fewer subject to the KV cache max num sequence limit. + if (remaining_num_child_to_activate == 0) { + rsentry_activated.push_back(false); + continue; + } + rsentry_activated.push_back(true); + --remaining_num_child_to_activate; + if (model_id == 0) { + ICHECK(rstates_of_entries[i]->entries[child_idx]->status == + RequestStateStatus::kPending); + rstates_of_entries[i]->entries[child_idx]->status = RequestStateStatus::kAlive; + } + int64_t child_internal_id = + rstates_of_entries[i]->entries[child_idx]->mstates[model_id]->internal_id; + models_[model_id]->ForkSequence(rsentry->mstates[model_id]->internal_id, + child_internal_id); + // Enable sliding window for the child sequence if the child is not a parent. + if (rstates_of_entries[i]->entries[child_idx]->child_indices.empty()) { + models_[model_id]->EnableSlidingWindowForSeq(child_internal_id); + } + } + } + if (rsentry->child_indices.empty()) { + // If rsentry has no child, we sample a token for itself. + sample_indices.push_back(i); + rsentries_for_sample.push_back(rsentry); + request_ids.push_back(rsentry->request->id); + generation_cfg.push_back(rsentry->request->generation_cfg); + rngs.push_back(&rsentry->rng); + rsentry_activated.push_back(true); + } + } + std::vector prob_dist; + std::vector sample_results = sampler_->BatchSampleTokens( + probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); + + // - Update the committed tokens of states. + // - If a request is first-time prefilled, set the prefill finish time. 
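A rough sketch of the child-activation bookkeeping in the sampling loop above, with hypothetical simplified types: at most `num_child_to_activate` children are forked and marked alive now, and a deferred child instead keeps the freshly sampled token queued as a pending input so a later prefill step can activate it.

```
#include <vector>

struct ChildPlan {
  std::vector<bool> activated;                   // child i forked and marked alive now?
  std::vector<std::vector<int>> pending_inputs;  // tokens queued for deferred children
};

// Sketch: distribute one sampled token per child; activate at most `budget`
// children immediately, defer the rest by queueing the token as their input.
ChildPlan PlanChildren(int num_children, int budget, int sampled_token) {
  ChildPlan plan;
  plan.activated.resize(num_children, false);
  plan.pending_inputs.resize(num_children);
  for (int child = 0; child < num_children; ++child) {
    if (child < budget) {
      plan.activated[child] = true;  // fork the sequence, status -> alive
    } else {
      plan.pending_inputs[child].push_back(sampled_token);  // prefilled later
    }
  }
  return plan;
}
```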
+ auto tnow = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { + if (model_id == 0) { + for (int mid = 0; mid < static_cast(models_.size()); ++mid) { + rsentries_for_sample[i]->mstates[mid]->CommitToken(sample_results[i]); + if (!rsentry_activated[i]) { + // When the child rsentry is not activated, + // add the sampled token as an input of the mstate for prefill. + rsentries_for_sample[i]->mstates[mid]->inputs.push_back( + TokenData(std::vector{sample_results[i].sampled_token_id.first})); + } + } + // Only base model trigger timing records. + if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { + rsentries_for_sample[i]->tprefill_finish = tnow; + } + } else { + // - Slice hidden_states_for_sample + NDArray last_hidden_on_device = GetTokenHidden(hidden_states_for_sample, i); + CHECK(i < static_cast(prob_dist.size())); + CHECK(prob_dist[i].defined()); + rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(sample_results[i], prob_dist[i], + last_hidden_on_device); + estate->stats.total_draft_length += 1; + } + } + } + + auto tend = std::chrono::high_resolution_clock::now(); + estate->stats.engine_total_prefill_time += static_cast((tend - tstart).count()) / 1e9; + + // - Remove the request from waiting queue if all its request states + // are now alive and have no remaining chunked inputs. + std::vector processed_requests; + { + processed_requests.reserve(num_rsentries); + std::unordered_set dedup_map; + for (int i = 0; i < num_rsentries; ++i) { + const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; + if (dedup_map.find(rsentry->request.get()) != dedup_map.end()) { + continue; + } + dedup_map.insert(rsentry->request.get()); + processed_requests.push_back(rsentry->request); + + bool pending_state_exists = false; + for (const RequestStateEntry& rsentry_ : rstates_of_entries[i]->entries) { + if (rsentry_->status == RequestStateStatus::kPending || + !rsentry_->mstates[0]->inputs.empty()) { + pending_state_exists = true; + break; + } + } + if (!pending_state_exists) { + auto it = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), + rsentry->request); + ICHECK(it != estate->waiting_queue.end()); + estate->waiting_queue.erase(it); + } + } + } + return processed_requests; + } + + private: + /*! \brief The class of request state entry and its maximum allowed length for prefill. */ + struct PrefillInput { + RequestStateEntry rsentry; + int max_prefill_length = 0; + int num_child_to_activate = 0; + }; + + /*! + * \brief Find one or multiple request state entries to run prefill. + * \param estate The engine state. + * \return The request entries to prefill, together with their input lengths. + */ + std::vector GetRequestStateEntriesToPrefill(EngineState estate) { + if (estate->waiting_queue.empty()) { + // No request to prefill. + return {}; + } + + std::vector prefill_inputs; + + // - Try to prefill pending requests. 
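The queue walk that follows only considers an entry that has inputs and whose parent, if any, is already alive with no remaining input. A tiny sketch of that eligibility predicate over a simplified entry record (illustrative only, not the engine's request-state type):

```
#include <vector>

// Simplified stand-in for a request state entry.
struct Entry {
  bool has_inputs = false;  // mstates[0]->inputs non-empty
  bool alive = false;       // status == kAlive
  int parent_idx = -1;      // -1 means no parent
};

// Sketch: an entry can be prefilled only when it has inputs, and either has no
// parent or its parent is alive and has nothing left to prefill itself.
bool CanConsiderForPrefill(const std::vector<Entry>& entries, int idx) {
  const Entry& e = entries[idx];
  if (!e.has_inputs) return false;
  if (e.parent_idx == -1) return true;
  const Entry& parent = entries[e.parent_idx];
  return parent.alive && !parent.has_inputs;
}
```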
+ int total_input_length = 0; + int total_required_pages = 0; + int num_available_pages = models_[0]->GetNumAvailablePages(); + int num_running_rsentries = GetRunningRequestStateEntries(estate).size(); + int current_total_seq_len = models_[0]->GetCurrentTotalSequenceLength(); + + int num_prefill_rsentries = 0; + for (const Request& request : estate->waiting_queue) { + RequestState rstate = estate->GetRequestState(request); + bool prefill_stops = false; + for (const RequestStateEntry& rsentry : rstate->entries) { + // A request state entry can be prefilled only when: + // - it has inputs, and + // - it has no parent or its parent is alive and has no remaining input. + if (rsentry->mstates[0]->inputs.empty() || + (rsentry->parent_idx != -1 && + (rstate->entries[rsentry->parent_idx]->status == RequestStateStatus::kPending || + !rstate->entries[rsentry->parent_idx]->mstates[0]->inputs.empty()))) { + continue; + } + + int input_length = rsentry->mstates[0]->GetInputLength(); + int num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; + total_input_length += input_length; + total_required_pages += num_require_pages; + // - Attempt 1. Check if the entire request state entry can fit for prefill. + bool can_prefill = false; + for (int num_child_to_activate = rsentry->child_indices.size(); num_child_to_activate >= 0; + --num_child_to_activate) { + if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate, + total_input_length, total_required_pages, num_available_pages, + current_total_seq_len, num_running_rsentries)) { + prefill_inputs.push_back({rsentry, input_length, num_child_to_activate}); + num_prefill_rsentries += 1 + num_child_to_activate; + can_prefill = true; + break; + } + } + if (can_prefill) { + continue; + } + total_input_length -= input_length; + total_required_pages -= num_require_pages; + + // - Attempt 2. Check if the request state entry can partially fit by input chunking. + ICHECK_LE(total_input_length, engine_config_->prefill_chunk_size); + if (engine_config_->prefill_chunk_size - total_input_length >= input_length || + engine_config_->prefill_chunk_size == total_input_length) { + // 1. If the input length can fit the remaining prefill chunk size, + // it means the failure of attempt 1 is not because of the input + // length being too long, and thus chunking does not help. + // 2. If the total input length already reaches the prefill chunk size, + // the current request state entry will not be able to be processed. + // So we can safely return in either case. + prefill_stops = true; + break; + } + input_length = engine_config_->prefill_chunk_size - total_input_length; + num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; + total_input_length += input_length; + total_required_pages += num_require_pages; + if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages, + num_available_pages, current_total_seq_len, num_running_rsentries)) { + prefill_inputs.push_back({rsentry, input_length, 0}); + num_prefill_rsentries += 1; + } + + // - Prefill stops here. + prefill_stops = true; + break; + } + if (prefill_stops) { + break; + } + } + + return prefill_inputs; + } + + /*! \brief Check if the input requests can be prefilled under conditions. 
*/ + bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, + int num_required_pages, int num_available_pages, int current_total_seq_len, + int num_running_rsentries) { + ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); + + // No exceeding of the maximum allowed requests that can + // run simultaneously. + int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable + ? engine_config_->spec_draft_length + : 1; + if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > + std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { + return false; + } + + // NOTE: The conditions are heuristic and can be revised. + // Cond 1: total input length <= prefill chunk size. + // Cond 2: at least one decode can be performed after prefill. + // Cond 3: number of total tokens after 8 times of decode does not + // exceed the limit, where 8 is a watermark number can + // be configured and adjusted in the future. + int new_batch_size = num_running_rsentries + num_prefill_rsentries; + return total_input_length <= engine_config_->prefill_chunk_size && + num_required_pages + new_batch_size <= num_available_pages && + current_total_seq_len + total_input_length + 8 * new_batch_size <= + engine_config_->max_total_sequence_length; + } + + /*! + * \brief Chunk the input of the given RequestModelState for prefill + * with regard to the provided maximum allowed prefill length. + * Return the list of input for prefill and the total prefill length. + * The `inputs` field of the given `mstate` will be mutated to exclude + * the returned input. + * \param mstate The RequestModelState whose input data is to be chunked. + * \param max_prefill_length The maximum allowed prefill length for the mstate. + * \return The list of input for prefill and the total prefill length. + */ + std::pair, int> ChunkPrefillInputData(const RequestModelState& mstate, + int max_prefill_length) { + if (mstate->inputs.empty()) { + } + ICHECK(!mstate->inputs.empty()); + std::vector inputs; + int cum_input_length = 0; + inputs.reserve(mstate->inputs.size()); + for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { + inputs.push_back(mstate->inputs[i]); + int input_length = mstate->inputs[i]->GetLength(); + cum_input_length += input_length; + // Case 0. the cumulative input length does not reach the maximum prefill length. + if (cum_input_length < max_prefill_length) { + continue; + } + + // Case 1. the cumulative input length equals the maximum prefill length. + if (cum_input_length == max_prefill_length) { + if (i == static_cast(mstate->inputs.size()) - 1) { + // - If `i` is the last input, we just copy and reset `mstate->inputs`. + mstate->inputs.clear(); + } else { + // - Otherwise, set the new input array. + mstate->inputs = Array{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; + } + return {inputs, cum_input_length}; + } + + // Case 2. cum_input_length > max_prefill_length + // The input `i` itself needs chunking if it is TokenData, + // or otherwise it cannot be chunked. + Data input = mstate->inputs[i]; + inputs.pop_back(); + cum_input_length -= input_length; + const auto* token_input = input.as(); + if (token_input == nullptr) { + // Cannot chunk the input. + if (i != 0) { + mstate->inputs = Array{mstate->inputs.begin() + i, mstate->inputs.end()}; + } + return {inputs, cum_input_length}; + } + + // Split the token data into two parts. + // Return the first part for prefill, and keep the second part. 
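Right after this comment the hunk splits one TokenData at the chunk boundary. A minimal, self-contained sketch of that split, with `std::vector<int>` standing in for TokenData and a made-up helper name: the first `room_left` tokens go into this prefill and the remainder stays queued for a later step.

```
#include <utility>
#include <vector>

// Sketch: split one token input at the prefill-chunk boundary. `room_left` is
// max_prefill_length minus the length already scheduled (chunked_input_length
// in the hunk); as in the hunk, the caller guarantees 0 < room_left < size.
// Returns {tokens prefilled now, tokens kept for a later step}.
std::pair<std::vector<int>, std::vector<int>> SplitAtChunkBoundary(
    const std::vector<int>& token_ids, int room_left) {
  std::vector<int> now(token_ids.begin(), token_ids.begin() + room_left);
  std::vector<int> later(token_ids.begin() + room_left, token_ids.end());
  return {now, later};
}

int main() {
  auto [now, later] = SplitAtChunkBoundary({10, 11, 12, 13, 14}, /*room_left=*/2);
  // now == {10, 11}; later == {12, 13, 14}
  return (now.size() == 2 && later.size() == 3) ? 0 : 1;
}
```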
+ int chunked_input_length = max_prefill_length - cum_input_length; + ICHECK_GT(input_length, chunked_input_length); + TokenData chunked_input(IntTuple{token_input->token_ids.begin(), + token_input->token_ids.begin() + chunked_input_length}); + TokenData remaining_input(IntTuple{token_input->token_ids.begin() + chunked_input_length, + token_input->token_ids.end()}); + inputs.push_back(chunked_input); + cum_input_length += chunked_input_length; + std::vector remaining_inputs{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; + remaining_inputs.insert(remaining_inputs.begin(), remaining_input); + mstate->inputs = remaining_inputs; + return {inputs, cum_input_length}; + } + + ICHECK(false) << "Cannot reach here"; + } + + /*! + * \brief Get one item from a hidden_states array, which corresponds to the last token. + * \param hidden_states The hidden_states of all the tokens. + * \param token_pos The desired token position in the sequence. + * \return The desired token's hidden_states + */ + NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { + ICHECK_EQ(hidden_states->ndim, 3); + NDArray last_hidden_on_device = + NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); + + int64_t ndata = hidden_states->shape[2]; + const int16_t* __restrict p_hidden = + static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + + (token_pos * ndata); + + last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); + return last_hidden_on_device; + } + + /*! \brief The models to run prefill in. */ + Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; + /*! \brief The sampler to sample new tokens. */ + Sampler sampler_; + /*! \brief Workspace of each model. */ + std::vector model_workspaces_; + /*! \brief The engine config. */ + EngineConfig engine_config_; + /*! \brief Event trace recorder. 
*/ + Optional trace_recorder_; +}; + +EngineAction EngineAction::EagleNewRequestPrefill(Array models, + LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + EngineConfig engine_config, + Optional trace_recorder) { + return EngineAction(make_object( + std::move(models), std::move(logit_processor), std::move(sampler), + std::move(model_workspaces), std::move(engine_config), std::move(trace_recorder))); +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index f93fbc2ded..c3f7491960 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -23,14 +23,13 @@ class NewRequestPrefillActionObj : public EngineActionObj { public: explicit NewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineMode engine_mode, + EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), - kv_cache_config_(std::move(kv_cache_config)), - engine_mode_(std::move(engine_mode)), + engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)) {} Array Step(EngineState estate) final { @@ -167,9 +166,11 @@ class NewRequestPrefillActionObj : public EngineActionObj { std::vector sample_indices; std::vector rsentries_for_sample; std::vector rngs; + std::vector rsentry_activated; sample_indices.reserve(num_rsentries); rsentries_for_sample.reserve(num_rsentries); rngs.reserve(num_rsentries); + rsentry_activated.reserve(num_rsentries); request_ids.clear(); generation_cfg.clear(); for (int i = 0; i < num_rsentries; ++i) { @@ -179,29 +180,42 @@ class NewRequestPrefillActionObj : public EngineActionObj { continue; } + int remaining_num_child_to_activate = prefill_inputs[i].num_child_to_activate; for (int child_idx : rsentry->child_indices) { - if (rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty()) { - // If rstates_of_entries[i]->entries[child_idx] has no committed token, - // the prefill of the current rsentry will unblock - // rstates_of_entries[i]->entries[child_idx], - // and thus we want to sample a token for rstates_of_entries[i]->entries[child_idx]. - sample_indices.push_back(i); - rsentries_for_sample.push_back(rstates_of_entries[i]->entries[child_idx]); - request_ids.push_back(rsentry->request->id); - generation_cfg.push_back(rsentry->request->generation_cfg); - rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng); - - ICHECK(rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending); - rstates_of_entries[i]->entries[child_idx]->status = RequestStateStatus::kAlive; - for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { - int64_t child_internal_id = - rstates_of_entries[i]->entries[child_idx]->mstates[model_id]->internal_id; - models_[model_id]->ForkSequence(rsentry->mstates[model_id]->internal_id, - child_internal_id); - // Enable sliding window for the child sequence if the child is not a parent. 
- if (rstates_of_entries[i]->entries[child_idx]->child_indices.empty()) { - models_[model_id]->EnableSlidingWindowForSeq(child_internal_id); - } + // If rstates_of_entries[i]->entries[child_idx] has no committed token, + // the prefill of the current rsentry will unblock + // rstates_of_entries[i]->entries[child_idx], + // and thus we want to sample a token for rstates_of_entries[i]->entries[child_idx]. + if (rstates_of_entries[i]->entries[child_idx]->status != RequestStateStatus::kPending || + !rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty()) { + continue; + } + sample_indices.push_back(i); + rsentries_for_sample.push_back(rstates_of_entries[i]->entries[child_idx]); + request_ids.push_back(rsentry->request->id); + generation_cfg.push_back(rsentry->request->generation_cfg); + rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng); + + ICHECK(rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending); + // We only fork the first `num_child_to_activate` children. + // The children not being forked will be forked via later prefills. + // Usually `num_child_to_activate` is the same as the number of children. + // But it can be fewer subject to the KV cache max num sequence limit. + if (remaining_num_child_to_activate == 0) { + rsentry_activated.push_back(false); + continue; + } + rsentry_activated.push_back(true); + --remaining_num_child_to_activate; + rstates_of_entries[i]->entries[child_idx]->status = RequestStateStatus::kAlive; + for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { + int64_t child_internal_id = + rstates_of_entries[i]->entries[child_idx]->mstates[model_id]->internal_id; + models_[model_id]->ForkSequence(rsentry->mstates[model_id]->internal_id, + child_internal_id); + // Enable sliding window for the child sequence if the child is not a parent. + if (rstates_of_entries[i]->entries[child_idx]->child_indices.empty()) { + models_[model_id]->EnableSlidingWindowForSeq(child_internal_id); } } } @@ -212,6 +226,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { request_ids.push_back(rsentry->request->id); generation_cfg.push_back(rsentry->request->generation_cfg); rngs.push_back(&rsentry->rng); + rsentry_activated.push_back(true); } } std::vector sample_results = sampler_->BatchSampleTokens( @@ -224,6 +239,12 @@ class NewRequestPrefillActionObj : public EngineActionObj { for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) { mstate->CommitToken(sample_results[i]); + if (!rsentry_activated[i]) { + // When the child rsentry is not activated, + // add the sampled token as an input of the mstate for prefill. + mstate->inputs.push_back( + TokenData(std::vector{sample_results[i].sampled_token_id.first})); + } } if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { rsentries_for_sample[i]->tprefill_finish = tnow; @@ -270,7 +291,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { /*! \brief The class of request state entry and its maximum allowed length for prefill. */ struct PrefillInput { RequestStateEntry rsentry; - int max_prefill_length; + int max_prefill_length = 0; + int num_child_to_activate = 0; }; /*! 
@@ -309,33 +331,51 @@ class NewRequestPrefillActionObj : public EngineActionObj { } int input_length = rsentry->mstates[0]->GetInputLength(); - int num_require_pages = - (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + int num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; total_input_length += input_length; total_required_pages += num_require_pages; // - Attempt 1. Check if the entire request state entry can fit for prefill. - if (CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(), - total_input_length, total_required_pages, num_available_pages, - current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length}); - num_prefill_rsentries += 1 + rsentry->child_indices.size(); + bool can_prefill = false; + for (int num_child_to_activate = rsentry->child_indices.size(); num_child_to_activate >= 0; + --num_child_to_activate) { + if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate, + total_input_length, total_required_pages, num_available_pages, + current_total_seq_len, num_running_rsentries)) { + prefill_inputs.push_back({rsentry, input_length, num_child_to_activate}); + num_prefill_rsentries += 1 + num_child_to_activate; + can_prefill = true; + break; + } + } + if (can_prefill) { continue; } total_input_length -= input_length; total_required_pages -= num_require_pages; // - Attempt 2. Check if the request state entry can partially fit by input chunking. - ICHECK_LE(total_input_length, kv_cache_config_->prefill_chunk_size); - input_length = - std::min(input_length, kv_cache_config_->prefill_chunk_size - total_input_length); - num_require_pages = - (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; - if (input_length > 0 && - CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(), - total_input_length, total_required_pages, num_available_pages, - current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length}); - num_prefill_rsentries += 1 + rsentry->child_indices.size(); + ICHECK_LE(total_input_length, engine_config_->prefill_chunk_size); + if (engine_config_->prefill_chunk_size - total_input_length >= input_length || + engine_config_->prefill_chunk_size == total_input_length) { + // 1. If the input length can fit the remaining prefill chunk size, + // it means the failure of attempt 1 is not because of the input + // length being too long, and thus chunking does not help. + // 2. If the total input length already reaches the prefill chunk size, + // the current request state entry will not be able to be processed. + // So we can safely return in either case. + prefill_stops = true; + break; + } + input_length = engine_config_->prefill_chunk_size - total_input_length; + num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; + total_input_length += input_length; + total_required_pages += num_require_pages; + if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages, + num_available_pages, current_total_seq_len, num_running_rsentries)) { + prefill_inputs.push_back({rsentry, input_length, 0}); + num_prefill_rsentries += 1; } // - Prefill stops here. 
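Aside on the hunk above: the chunked-prefill accounting reduces to two pieces of integer arithmetic, a ceiling division that turns token counts into KV-cache pages, and a clamp of the next entry's input to whatever remains of the prefill chunk budget. The standalone sketch below is illustrative only; the `page_size`, `prefill_chunk_size`, and token counts are made-up values, not values taken from this patch.

```cpp
// Illustrative sketch of the two-attempt prefill budgeting in the hunk above.
// All constants are hypothetical; in the engine they come from EngineConfig.
#include <cstdio>

int main() {
  const int page_size = 16;            // assumed kv_cache_page_size
  const int prefill_chunk_size = 2048; // assumed prefill_chunk_size
  int total_input_length = 1500;       // tokens already admitted in this prefill round
  int input_length = 900;              // tokens requested by the next request state entry

  // Attempt 1: ceiling division gives the pages needed for the full input.
  int num_require_pages = (input_length + page_size - 1) / page_size;
  std::printf("full prefill: %d tokens -> %d pages\n", input_length, num_require_pages);

  // Attempt 2: chunk the input so it exactly fills the remaining chunk budget.
  // Chunking only helps when the remaining budget is smaller than the input
  // and the budget is not already exhausted.
  if (prefill_chunk_size - total_input_length < input_length &&
      prefill_chunk_size != total_input_length) {
    int chunked_length = prefill_chunk_size - total_input_length;       // 548 tokens here
    int chunked_pages = (chunked_length + page_size - 1) / page_size;   // 35 pages here
    std::printf("chunked prefill: %d tokens -> %d pages\n", chunked_length, chunked_pages);
  }
  return 0;
}
```

As in the hunk, chunking is only attempted when a full prefill does not fit and the remaining chunk budget is both non-zero and smaller than the entry's input; otherwise prefill simply stops at this entry.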
@@ -354,13 +394,15 @@ class NewRequestPrefillActionObj : public EngineActionObj { bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, int num_required_pages, int num_available_pages, int current_total_seq_len, int num_running_rsentries) { - ICHECK_LE(num_running_rsentries, kv_cache_config_->max_num_sequence); + ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); // No exceeding of the maximum allowed requests that can // run simultaneously. - int spec_factor = engine_mode_->enable_speculative ? engine_mode_->spec_draft_length : 1; + int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable + ? engine_config_->spec_draft_length + : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > - std::min(kv_cache_config_->max_num_sequence, kv_cache_config_->prefill_chunk_size)) { + std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { return false; } @@ -371,10 +413,10 @@ class NewRequestPrefillActionObj : public EngineActionObj { // exceed the limit, where 8 is a watermark number can // be configured and adjusted in the future. int new_batch_size = num_running_rsentries + num_prefill_rsentries; - return total_input_length <= kv_cache_config_->prefill_chunk_size && + return total_input_length <= engine_config_->prefill_chunk_size && num_required_pages + new_batch_size <= num_available_pages && current_total_seq_len + total_input_length + 8 * new_batch_size <= - kv_cache_config_->max_total_sequence_length; + engine_config_->max_total_sequence_length; } /*! @@ -458,10 +500,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; - /*! \brief The KV cache config to help decide prefill is doable. */ - KVCacheConfig kv_cache_config_; - /*! \brief The engine operation mode. */ - EngineMode engine_mode_; + /*! \brief The engine config. */ + EngineConfig engine_config_; /*! \brief Event trace recorder. 
*/ Optional trace_recorder_; }; @@ -469,12 +509,11 @@ class NewRequestPrefillActionObj : public EngineActionObj { EngineAction EngineAction::NewRequestPrefill(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineMode engine_mode, + EngineConfig engine_config, Optional trace_recorder) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(kv_cache_config), std::move(engine_mode), - std::move(trace_recorder))); + std::move(model_workspaces), std::move(engine_config), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index f4466c875b..fa24828399 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -13,17 +13,44 @@ #include #include +#include #include #include #include #include "../support/load_bytes_from_file.h" +#include "../support/utils.h" #include "sampler/sampler.h" namespace mlc { namespace llm { namespace serve { +Optional GetDiscoWorkerCPUBinding(int num_workers) { + const char* raw_cpu_binding = std::getenv("MLC_DISCO_WORKER_CPU_BINDING"); + if (raw_cpu_binding == nullptr) { + return NullOpt; + } + + std::string cpu_binding_str(raw_cpu_binding); + std::vector cpu_ids_str = Split(cpu_binding_str, ','); + std::vector cpu_ids; + for (const std::string& cpu_id_str : cpu_ids_str) { + try { + cpu_ids.push_back(std::stol(cpu_id_str)); + } catch (std::invalid_argument const& ex) { + LOG(FATAL) << "Invalid MLC_DISCO_WORKER_CPU_BINDING \"" << cpu_binding_str << "\""; + } + } + if (static_cast(cpu_ids.size()) < num_workers) { + LOG(FATAL) << "Insufficient number of specified CPU workers in MLC_DISCO_WORKER_CPU_BINDING, " + "expecting at least " + << num_workers << "CPU ids but only " << cpu_ids.size() << " are given."; + } + + return IntTuple{cpu_ids}; +} + PackedFunc FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name) { return PackedFunc([sess, func = std::move(sess_func), name = std::move(name)]( TVMArgs args, TVMRetValue* rv) -> void { @@ -42,7 +69,7 @@ PackedFunc FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func, }); } -void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object model_config) { +void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config) { local_gpu_device = device; Device null_device{DLDeviceType(0), 0}; int num_shards; @@ -58,15 +85,6 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object this->cached_buffers = Map(); if (num_shards > 1) { - String lib_path{nullptr}; - try { - lib_path = reload_lib.operator String(); - } catch (...) { - LOG(FATAL) - << "ValueError: In multi-GPU inference, we expect the first argument to Reload to be a " - "string path to the model library (.so on Linux or .dll on Windows), but got: " - << ArgTypeCode2Str(reload_lib.type_code()); - } constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool"; if (Registry::Get(f_create_process_pool) == nullptr) { LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. 
" @@ -89,7 +107,7 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object this->sess = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); this->sess->InitCCL(ccl, ShapeTuple(device_ids)); this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"), - lib_path, null_device); + std::move(reload_lib_path), null_device); this->mod_get_func = [this, fmodule_get_function = sess->GetGlobalFunc("runtime.ModuleGetFunction")]( const std::string& name) -> PackedFunc { @@ -100,6 +118,10 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object } return SessionFuncAsPackedFunc(sess, func, name); }; + if (Optional cpu_ids = GetDiscoWorkerCPUBinding(/*num_workers=*/num_shards)) { + IntTuple cpu_ids_value = cpu_ids.value(); + sess->CallPacked(sess->GetGlobalFunc("runtime.disco.bind_worker_to_cpu_core"), cpu_ids_value); + } this->get_global_func = [this](const std::string& name) -> PackedFunc { return SessionFuncAsPackedFunc(sess, sess->GetGlobalFunc(name), name); }; @@ -108,11 +130,10 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object this->_InitFunctions(); } else { Module executable{nullptr}; - if (reload_lib.type_code() == kTVMModuleHandle) { - executable = reload_lib.operator Module(); + if (false) { + // Todo(mlc-team): system lib reload // reload_lib_path starts with "system://" } else { - String lib_path = reload_lib.operator String(); - executable = tvm::runtime::Module::LoadFromFile(lib_path); + executable = tvm::runtime::Module::LoadFromFile(reload_lib_path); } this->use_disco = false; auto fload_exec = executable->GetFunction("vm_load_executable"); @@ -197,7 +218,16 @@ void FunctionTable::_InitFunctions() { this->prefill_func_ = mod_get_func("batch_prefill"); this->decode_func_ = mod_get_func("batch_decode"); this->verify_func_ = mod_get_func("batch_verify"); + this->single_batch_prefill_to_last_hidden_func_ = mod_get_func("prefill_to_last_hidden_states"); + this->single_batch_decode_to_last_hidden_func_ = mod_get_func("decode_to_last_hidden_states"); + this->prefill_to_last_hidden_func_ = mod_get_func("batch_prefill_to_last_hidden_states"); + this->decode_to_last_hidden_func_ = mod_get_func("batch_decode_to_last_hidden_states"); + this->verify_to_last_hidden_func_ = mod_get_func("batch_verify_to_last_hidden_states"); + this->fuse_embed_hidden_func_ = mod_get_func("fuse_embed_hidden_states"); Module mod = this->use_disco ? 
this->disco_mod->DebugGetFromRemote(0) : this->local_vm; + this->get_logits_func_ = mod->GetFunction("get_logits", true); + this->batch_get_logits_func_ = mod->GetFunction("batch_get_logits", true); + this->batch_select_last_hidden_func_ = mod->GetFunction("batch_select_last_hidden_states", true); this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true); this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true); this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true); @@ -245,7 +275,6 @@ ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) ObjectRef FunctionTable::CopyToWorker0(const NDArray& host_array, String buffer_cache_key, ShapeTuple max_reserved_shape) { - ICHECK(host_array->device.device_type == DLDeviceType::kDLCPU); if (this->use_disco) { Device null_device{DLDeviceType(0), 0}; DRef buffer(nullptr); @@ -276,6 +305,16 @@ ObjectRef FunctionTable::CopyToWorker0(const NDArray& host_array, String buffer_ } } +void FunctionTable::DebugCallFuncOnAllAllWorker(const String& func_name) const { + if (this->use_disco) { + sess->CallPacked(sess->GetGlobalFunc(func_name)); + } else { + const PackedFunc* func = Registry::Get(func_name); + CHECK(func != nullptr) << "Global function name \"" << func_name << "\" is not found"; + (*func)(); + } +} + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index 29d9d82fbc..f6a156b8a3 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -41,7 +41,7 @@ using namespace tvm::runtime; struct FunctionTable { static PackedFunc SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name); - void Init(TVMArgValue reload_lib, Device device, picojson::object model_config); + void Init(String reload_lib_path, Device device, picojson::object model_config); ObjectRef LoadParams(const std::string& model_path, Device device); @@ -52,6 +52,8 @@ struct FunctionTable { ObjectRef CopyToWorker0(const NDArray& host_array, String buffer_cache_key, ShapeTuple max_reserved_shape); + void DebugCallFuncOnAllAllWorker(const String& func_name) const; + bool use_disco = false; Device local_gpu_device; Session sess{nullptr}; @@ -72,6 +74,15 @@ struct FunctionTable { PackedFunc prefill_func_; PackedFunc decode_func_; PackedFunc verify_func_; + PackedFunc single_batch_prefill_to_last_hidden_func_; + PackedFunc single_batch_decode_to_last_hidden_func_; + PackedFunc prefill_to_last_hidden_func_; + PackedFunc decode_to_last_hidden_func_; + PackedFunc verify_to_last_hidden_func_; + PackedFunc fuse_embed_hidden_func_; + PackedFunc get_logits_func_; + PackedFunc batch_get_logits_func_; + PackedFunc batch_select_last_hidden_func_; PackedFunc softmax_func_; PackedFunc apply_logit_bias_func_; PackedFunc apply_penalty_func_; diff --git a/cpp/serve/grammar/grammar.cc b/cpp/serve/grammar/grammar.cc index c4d6445c7e..c8d760538c 100644 --- a/cpp/serve/grammar/grammar.cc +++ b/cpp/serve/grammar/grammar.cc @@ -8,6 +8,7 @@ #include "grammar_parser.h" #include "grammar_serializer.h" #include "grammar_simplifier.h" +#include "json_schema_converter.h" namespace mlc { namespace llm { @@ -20,7 +21,7 @@ std::ostream& operator<<(std::ostream& os, const BNFGrammar& grammar) { return os; } -BNFGrammar BNFGrammar::FromEBNFString(const String& ebnf_string, const String& main_rule, +BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string, const std::string& main_rule, bool normalize, bool simplify) { auto grammar 
= EBNFParser::Parse(ebnf_string, main_rule); if (normalize) { @@ -34,7 +35,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromEBNFString") return BNFGrammar::FromEBNFString(ebnf_string, main_rule, normalize, simplify); }); -BNFGrammar BNFGrammar::FromJSON(const String& json_string) { +BNFGrammar BNFGrammar::FromJSON(const std::string& json_string) { return BNFJSONParser::Parse(json_string); } @@ -42,33 +43,31 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromJSON").set_body_typed([](String jso return BNFGrammar::FromJSON(json_string); }); -BNFGrammar BNFGrammar::FromSchema(const String& schema, int indent, - Optional> separators, bool strict_mode) { - static const PackedFunc* json_schema_to_ebnf = Registry::Get("mlc.serve.json_schema_to_ebnf"); - CHECK(json_schema_to_ebnf != nullptr) << "mlc.serve.json_schema_to_ebnf is not registered."; - - String ebnf_string; - - // Convert the indent parameter to NullOpt for sending it to the PackedFunc. - if (indent == -1) { - // The conversion from TVMRetValue to String is ambiguous, so we call the conversion function - // explicitly - ebnf_string = - ((*json_schema_to_ebnf)(schema, Optional(NullOpt), separators, strict_mode) - . - operator String()); +BNFGrammar BNFGrammar::FromSchema(const std::string& schema, std::optional indent, + std::optional> separators, + bool strict_mode) { + return FromEBNFString(JSONSchemaToEBNF(schema, indent, separators, strict_mode)); +} + +TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromSchema").set_body([](TVMArgs args, TVMRetValue* rv) { + std::optional indent; + if (args[1].type_code() != kTVMNullptr) { + indent = args[1]; } else { - ebnf_string = (*json_schema_to_ebnf)(schema, indent, separators, strict_mode).operator String(); - ; + indent = std::nullopt; } - return FromEBNFString(ebnf_string); -} -TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromSchema") - .set_body_typed([](const String& schema, int indent, Optional> separators, - bool strict_mode) { - return BNFGrammar::FromSchema(schema, indent, separators, strict_mode); - }); + std::optional> separators; + if (args[2].type_code() != kTVMNullptr) { + Array separators_arr = args[2]; + CHECK(separators_arr.size() == 2); + separators = std::make_pair(separators_arr[0], separators_arr[1]); + } else { + separators = std::nullopt; + } + + *rv = BNFGrammar::FromSchema(args[0], indent, separators, args[3]); +}); const std::string kJSONGrammarString = R"( main ::= ( diff --git a/cpp/serve/grammar/grammar.h b/cpp/serve/grammar/grammar.h index 545a4e08a0..ba15e58af3 100644 --- a/cpp/serve/grammar/grammar.h +++ b/cpp/serve/grammar/grammar.h @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -183,33 +184,38 @@ class BNFGrammar : public ObjectRef { * \param simplify Whether to simplify the grammar to make matching more efficient. Default: true. * Not implemented yet. */ - static BNFGrammar FromEBNFString(const String& ebnf_string, const String& main_rule = "main", - bool normalize = true, bool simplify = true); + static BNFGrammar FromEBNFString(const std::string& ebnf_string, + const std::string& main_rule = "main", bool normalize = true, + bool simplify = true); /*! * \brief Construct a BNF grammar from the dumped JSON string. * \param json_string The JSON-formatted string. This string should have the same format as * the result of BNFGrammarJSONSerializer::ToString. */ - static BNFGrammar FromJSON(const String& json_string); + static BNFGrammar FromJSON(const std::string& json_string); /*! * \brief Construct a BNF grammar from the json schema string. 
The schema string should be in the
 * format of the schema of a JSON file. We will parse the schema and generate a BNF grammar.
 * \param schema The schema string.
- * \param indent The number of spaces for indentation. If -1, the output will be in one line.
- * Default: -1.
+ * \param indent The number of spaces for indentation. If set to std::nullopt, the output will be
+ * in one line. Default: std::nullopt.
 * \param separators Two separators used in the schema: comma and colon. Examples: {",", ":"},
- * {", ", ": "}. If NullOpt, the default separators will be used: {",", ": "} when the indent
- * is not -1, and {", ", ": "} otherwise. Default: NullOpt.
+ * {", ", ": "}. If std::nullopt, the default separators will be used: {",", ": "} when the
+ * indent is specified, and {", ", ": "} otherwise. This follows the convention in python
+ * json.dumps(). Default: std::nullopt.
 * \param strict_mode Whether to use strict mode. In strict mode, the generated grammar will not
- * allow unevaluatedProperties and unevaluatedItems, i.e. these will be set to false by default.
+ * allow properties and items that are not specified in the schema. This is equivalent to
+ * setting unevaluatedProperties and unevaluatedItems to false.
+ *
 * This helps LLM to generate accurate output in the grammar-guided generation with JSON
 * schema. Default: true.
 */
- static BNFGrammar FromSchema(const String& schema, int indent = -1,
- Optional> separators = NullOpt,
- bool strict_mode = true);
+ static BNFGrammar FromSchema(
+ const std::string& schema, std::optional indent = std::nullopt,
+ std::optional> separators = std::nullopt,
+ bool strict_mode = true);

 /*!
 * \brief Get the grammar of standard JSON format. We have built-in support for JSON.
diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc
index ba9ac80135..1ece99099e 100644
--- a/cpp/serve/grammar/grammar_parser.cc
+++ b/cpp/serve/grammar/grammar_parser.cc
@@ -16,7 +16,7 @@ namespace serve {
 class EBNFParserImpl {
 public:
 /*! \brief The logic of parsing the grammar string.
*/ - BNFGrammar DoParse(String ebnf_string, String main_rule); + BNFGrammar DoParse(std::string ebnf_string, std::string main_rule); private: using Rule = BNFGrammarNode::Rule; @@ -192,7 +192,7 @@ int32_t EBNFParserImpl::ParseString() { std::vector character_classes; while (Peek() && Peek() != '\"') { if (Peek() == '\r' || Peek() == '\n') { - ThrowParseError("String should not contain newline"); + ThrowParseError("There should be no newline character in a string literal"); } auto [codepoint, len] = Utf8OrEscapeToCodepoint(cur_); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { @@ -391,7 +391,7 @@ void EBNFParserImpl::ResetStringIterator(const char* cur) { in_parentheses_ = false; } -BNFGrammar EBNFParserImpl::DoParse(String ebnf_string, String main_rule) { +BNFGrammar EBNFParserImpl::DoParse(std::string ebnf_string, std::string main_rule) { ResetStringIterator(ebnf_string.c_str()); BuildRuleNameToId(); @@ -412,12 +412,12 @@ BNFGrammar EBNFParserImpl::DoParse(String ebnf_string, String main_rule) { return builder_.Get(main_rule); } -BNFGrammar EBNFParser::Parse(String ebnf_string, String main_rule) { +BNFGrammar EBNFParser::Parse(std::string ebnf_string, std::string main_rule) { EBNFParserImpl parser; return parser.DoParse(ebnf_string, main_rule); } -BNFGrammar BNFJSONParser::Parse(String json_string) { +BNFGrammar BNFJSONParser::Parse(std::string json_string) { auto node = make_object(); auto grammar_json = json::ParseToJsonObject(json_string); auto rules_json = json::Lookup(grammar_json, "rules"); diff --git a/cpp/serve/grammar/grammar_parser.h b/cpp/serve/grammar/grammar_parser.h index be36f40459..4d10e8eb0d 100644 --- a/cpp/serve/grammar/grammar_parser.h +++ b/cpp/serve/grammar/grammar_parser.h @@ -37,7 +37,7 @@ class EBNFParser { * \param main_rule The name of the main rule. Default is "main". * \return The parsed grammar. */ - static BNFGrammar Parse(String ebnf_string, String main_rule = "main"); + static BNFGrammar Parse(std::string ebnf_string, std::string main_rule = "main"); /*! * \brief The exception thrown when parsing fails. @@ -58,7 +58,7 @@ class BNFJSONParser { * \param json_string The JSON string. * \return The parsed BNF grammar. */ - static BNFGrammar Parse(String json_string); + static BNFGrammar Parse(std::string json_string); }; } // namespace serve diff --git a/cpp/serve/grammar/grammar_serializer.cc b/cpp/serve/grammar/grammar_serializer.cc index a057921f61..fd41517863 100644 --- a/cpp/serve/grammar/grammar_serializer.cc +++ b/cpp/serve/grammar/grammar_serializer.cc @@ -107,7 +107,7 @@ std::string BNFGrammarPrinter::PrintCharacterClassStar(const RuleExpr& rule_expr return PrintRuleExpr(rule_expr[0]) + "*"; } -String BNFGrammarPrinter::ToString() { +std::string BNFGrammarPrinter::ToString() { std::string result; auto num_rules = grammar_->NumRules(); for (auto i = 0; i < num_rules; ++i) { @@ -120,7 +120,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarToString").set_body_typed([](const BNFG return BNFGrammarPrinter(grammar).ToString(); }); -String BNFGrammarJSONSerializer::ToString() { +std::string BNFGrammarJSONSerializer::ToString() { picojson::object grammar_json; picojson::array rules_json; diff --git a/cpp/serve/grammar/grammar_serializer.h b/cpp/serve/grammar/grammar_serializer.h index 5837ce2bf6..8746b1f6ae 100644 --- a/cpp/serve/grammar/grammar_serializer.h +++ b/cpp/serve/grammar/grammar_serializer.h @@ -27,7 +27,7 @@ class BNFGrammarSerializer { explicit BNFGrammarSerializer(const BNFGrammar& grammar) : grammar_(grammar) {} /*! 
\brief Serialize the grammar to string. */ - virtual String ToString() = 0; + virtual std::string ToString() = 0; protected: const BNFGrammar& grammar_; @@ -50,7 +50,7 @@ class BNFGrammarPrinter : public BNFGrammarSerializer { explicit BNFGrammarPrinter(const BNFGrammar& grammar) : BNFGrammarSerializer(grammar) {} /*! \brief Print the complete grammar. */ - String ToString() final; + std::string ToString() final; /*! \brief Print a rule. */ std::string PrintRule(const Rule& rule); @@ -102,7 +102,7 @@ class BNFGrammarJSONSerializer : public BNFGrammarSerializer { * \brief Dump the raw representation of the AST to a JSON file. * \param prettify Whether to format the JSON string. If false, all whitespaces will be removed. */ - String ToString() final; + std::string ToString() final; private: bool prettify_; diff --git a/cpp/serve/grammar/json_schema_converter.cc b/cpp/serve/grammar/json_schema_converter.cc new file mode 100644 index 0000000000..93d693f3c6 --- /dev/null +++ b/cpp/serve/grammar/json_schema_converter.cc @@ -0,0 +1,987 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/grammar/json_schema_converter.cc + */ +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! + * \brief Manage the indent and separator for the generation of EBNF grammar. + * \param indent The number of spaces for each indent. If it is std::nullopt, there will be no + * indent or newline. + * \param separator The separator between different elements in json. Examples include "," and ", ". + */ +class IndentManager { + public: + IndentManager(std::optional indent, const std::string& separator) + : enable_newline_(indent.has_value()), + indent_(indent.value_or(0)), + separator_(separator), + total_indent_(0), + is_first_({true}) {} + + /*! \brief Enter a new indent level. */ + void StartIndent() { + total_indent_ += indent_; + is_first_.push_back(true); + } + + /*! \brief Exit the current indent level. */ + void EndIndent() { + total_indent_ -= indent_; + is_first_.pop_back(); + } + + /*! + * \brief Get the next separator in the current level. When first called in the current + * level, the starting separator will be returned. When called again, the middle separator will be + * returned. When called with `is_end=True`, the ending separator will be returned. + * \param is_end Get the separator for the end of the current level. + * \example + * \code + * IndentManager indent_manager(2, ", "); + * indent_manager.StartIndent(); + * indent_manager.GetSep(); // get the start separator: "\"\n \"" + * indent_manager.GetSep(); // get the middle separator: "\",\n \"" + * indent_manager.GetSep(true); // get the end separator: "\"\n\"" + * \endcode + */ + std::string NextSeparator(bool is_end = false); + + /*! \brief Get the separator itself. 
*/ + std::string GetBareSeparator() { return separator_; } + + private: + bool enable_newline_; + int indent_; + std::string separator_; + int total_indent_; + std::vector is_first_; + friend class JSONSchemaToEBNFConverter; +}; + +std::string IndentManager::NextSeparator(bool is_end) { + std::string res = ""; + if (!is_first_.back() && !is_end) { + res += separator_; + } + is_first_.back() = false; + + if (enable_newline_) { + res += "\\n"; + } + + if (!is_end) { + res += std::string(total_indent_, ' '); + } else { + res += std::string(total_indent_ - indent_, ' '); + } + + return "\"" + res + "\""; +} + +/*! + * \brief Convert JSON schema string to EBNF grammar string. The parameters follow + * JSONSchemaToEBNF(). + * + * \note About the representation of json schema in this converter. JSON schema could be two types: + * bool (true or false) or dict (a json dict) containing attributes. We use picojson::value to + * represent the json schema. + */ +class JSONSchemaToEBNFConverter { + public: + JSONSchemaToEBNFConverter( + const picojson::value& json_schema, std::optional indent = std::nullopt, + std::optional> separators = std::nullopt, + bool strict_mode = false); + + /*! \brief The main method. Convert the JSON schema to EBNF grammar string. */ + std::string Convert(); + + private: + // The name of the basic rules + inline static const std::string kBasicAny = "basic_any"; + inline static const std::string kBasicInteger = "basic_integer"; + inline static const std::string kBasicNumber = "basic_number"; + inline static const std::string kBasicString = "basic_string"; + inline static const std::string kBasicBoolean = "basic_boolean"; + inline static const std::string kBasicNull = "basic_null"; + inline static const std::string kBasicArray = "basic_array"; + inline static const std::string kBasicObject = "basic_object"; + + // The name of the helper rules to construct basic rules + inline static const std::string kBasicEscape = "basic_escape"; + inline static const std::string kBasicStringSub = "basic_string_sub"; + + /*! \brief Add the basic rules to the rules list and the basic_rules_cache. */ + void AddBasicRules(); + + /*! \brief Add helper rules for the basic rules. */ + void AddHelperRules(); + + /*! \brief Create a rule for the given schema and name, and add it to the basic_rules_cache. */ + void CreateBasicRule(const picojson::value& schema, const std::string& name); + + /*! \brief Get the index for the schema in the cache. Keys that do not effect the validation + * will be ignored when finding the corresponding cache rule. */ + std::string GetSchemaCacheIndex(const picojson::value& schema); + + /*! + * \brief Create a rule with the given schema and rule name hint. + * \returns The name of the rule will be returned. That is not necessarily the same as the + * rule_name_hint due to the caching mechanism. + */ + std::string CreateRuleFromSchema(const picojson::value& schema, + const std::string& rule_name_hint); + + /*! \brief Get the next separator in the current level from the indent manager. */ + std::string NextSeparator(bool is_end = false); + + /*! \brief Warn if any keyword is existing in the schema but not supported. */ + static void WarnUnsupportedKeywords(const picojson::value& schema, + const std::vector& keywords); + + /*! \brief Warn if any keyword is existing in the object but not supported. */ + static void WarnUnsupportedKeywords(const picojson::object& schema, + const std::vector& keywords); + + /*! 
\brief Visit the schema and return the rule body for later constructing the rule. */ + std::string VisitSchema(const picojson::value& schema, const std::string& rule_name); + + /*! \brief Visit a reference schema. */ + std::string VisitRef(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Get the schema from the URI. */ + picojson::value URIToSchema(const picojson::value& uri); + + /*! \brief Visit a const schema. */ + std::string VisitConst(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Visit an enum schema. */ + std::string VisitEnum(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Convert the JSON string to a printable string that can be shown in BNF. */ + std::string JSONStrToPrintableStr(const std::string& json_str); + + /*! \brief Visit an anyOf schema. */ + std::string VisitAnyOf(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Visit a true schema that can match anything. */ + std::string VisitAny(const picojson::value& schema, const std::string& rule_name); + + /*! \brief Visit an integer schema. */ + std::string VisitInteger(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Visit a number schema. */ + std::string VisitNumber(const picojson::object& schema, const std::string& rule_name); + /*! \brief Visit a string schema. */ + std::string VisitString(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Visit a boolean schema. */ + std::string VisitBoolean(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Visit a null schema. */ + std::string VisitNull(const picojson::object& schema, const std::string& rule_name); + + /*! + * \brief Visit an array schema. + * \example + * Schema: + * \code + * { + * "type": "array", + * "prefixItems": [ + * {"type": "boolean"}, + * {"type": "integer"} + * ], + * "items": { + * "type": "string" + * } + * } + * \endcode + * Rule (not considering the indent): + * \code + * main ::= "[" basic_boolean ", " basic_integer (", " basic_string)* "]" + * \endcode + */ + std::string VisitArray(const picojson::object& schema, const std::string& rule_name); + + /*! + * \brief Visit an object schema. + * \example + * Schema: + * \code + * { + * "type": "object", + * "properties": { + * "a": {"type": "string"}, + * "b": {"type": "integer"} + * }, + * "required": ["a"], + * "additionalProperties": true + * } + * \endcode + * + * Rule (not considering the indent): + * \code + * main ::= "{" "a" ":" basic_string (", " "b" ":" basic_integer)* + * (", " basic_string ": " basic_any)* "}" + * \endcode + + * We need special handling when all properties are optional, since the handling of separators + * is tricky in this case. E.g. + + * Schema: + * \code + * { + * "type": "object", + * "properties": { + * "a": {"type": "string"}, + * "b": {"type": "integer"}, + * "c": {"type": "boolean"} + * }, + * "additionalProperties": true + * } + * \endcode + * + * Rule (indent=2): + * \code + * main ::= "{" ("\n " (a main_sub_1 | b main_sub_2 | c main_sub_3 | d main_sub_3) + * "\n" | "") "}" + * main_sub_1 ::= ",\n " b r2 | r2 + * main_sub_2 ::= ",\n " c r3 | r3 + * main_sub_3 ::= (",\n " d)* + * \endcode + */ + std::string VisitObject(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Get the pattern for a property in the object schema. 
*/ + std::string GetPropertyPattern(const std::string& prop_name, const picojson::value& prop_schema, + const std::string& rule_name, int idx); + + /*! \brief Get the pattern for the additional/unevaluated properties in the object schema. */ + std::string GetOtherPropertyPattern(const std::string& key_pattern, + const picojson::value& prop_schema, + const std::string& rule_name, + const std::string& rule_name_suffix); + + /*! \brief Get the partial rule for the properties when all properties are optional. See the + * example in VisitObject(). */ + std::string GetPartialRuleForPropertiesAllOptional( + const std::vector>& properties, + const picojson::value& additional, const std::string& rule_name, + const std::string& additional_suffix = ""); + + /*! + * \brief Get the partial rule for the properties when some properties are required. See the + * example in VisitObject(). + * + * The constructed rule should be: + * \code + * start_separator (optional_property separator)? (optional_property separator)? ... + * first_required_property (separator optional_property)? separator required_property ... + * end_separator + * \endcode + * + * i.e. Before the first required property, all properties are in the form + * (property separator) ; and after the first required property, all properties are in the form + * (separator property) . */ + std::string GetPartialRuleForPropertiesContainRequired( + const std::vector>& properties, + const std::unordered_set& required, const std::string& rule_name); + + // The indent manager to get separators + std::unique_ptr indentManager_; + // The root JSON schema + picojson::value json_schema_; + // Whether to use strict mode in conversion. See JSONSchemaToEBNF(). + bool strict_mode_; + // The colon separator + std::string colon_; + // The rules constructed + std::vector> rules_; + // The cache for basic rules. Mapping from the key of schema returned by GetSchemaCacheIndex() + // to the basic rule name. + std::map basic_rules_cache_; +}; + +JSONSchemaToEBNFConverter::JSONSchemaToEBNFConverter( + const picojson::value& json_schema, std::optional indent, + std::optional> separators, bool strict_mode) + : json_schema_(json_schema), strict_mode_(strict_mode) { + if (!separators.has_value()) { + separators = (indent == std::nullopt) ? 
std::make_pair(", ", ": ") : std::make_pair(",", ": "); + } + indentManager_ = std::make_unique(indent, separators->first); + colon_ = separators->second; + + AddBasicRules(); +} + +std::string JSONSchemaToEBNFConverter::Convert() { + CreateRuleFromSchema(json_schema_, "main"); + std::string res; + for (auto& rule : rules_) { + res += rule.first + " ::= " + rule.second + "\n"; + } + return res; +} + +void JSONSchemaToEBNFConverter::AddBasicRules() { + bool past_strict_mode = strict_mode_; + strict_mode_ = false; + + auto past_indent_manager = std::move(indentManager_); + indentManager_ = + std::make_unique(std::nullopt, past_indent_manager->GetBareSeparator()); + + AddHelperRules(); + CreateBasicRule(picojson::value(true), kBasicAny); + basic_rules_cache_[GetSchemaCacheIndex(picojson::value(picojson::object()))] = kBasicAny; + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("integer")}}), + kBasicInteger); + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("number")}}), + kBasicNumber); + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("string")}}), + kBasicString); + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("boolean")}}), + kBasicBoolean); + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("null")}}), kBasicNull); + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("array")}}), + kBasicArray); + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("object")}}), + kBasicObject); + + strict_mode_ = past_strict_mode; + indentManager_ = std::move(past_indent_manager); +} + +void JSONSchemaToEBNFConverter::AddHelperRules() { + rules_.push_back(std::make_pair( + kBasicEscape, "[\"\\\\/bfnrt] | \"u\" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]")); + rules_.push_back(std::make_pair(kBasicStringSub, "\"\" | [^\"\\\\\\r\\n] " + kBasicStringSub + + " | \"\\\\\" " + kBasicEscape + " " + + kBasicStringSub)); +} + +void JSONSchemaToEBNFConverter::CreateBasicRule(const picojson::value& schema, + const std::string& name) { + std::string rule_name = CreateRuleFromSchema(schema, name); + basic_rules_cache_[GetSchemaCacheIndex(schema)] = rule_name; +} + +std::string JSONSchemaToEBNFConverter::NextSeparator(bool is_end) { + return indentManager_->NextSeparator(is_end); +} + +void JSONSchemaToEBNFConverter::WarnUnsupportedKeywords(const picojson::value& schema, + const std::vector& keywords) { + if (schema.is()) { + return; + } + + ICHECK(schema.is()); + WarnUnsupportedKeywords(schema.get(), keywords); +} + +void JSONSchemaToEBNFConverter::WarnUnsupportedKeywords(const picojson::object& schema, + const std::vector& keywords) { + for (const auto& keyword : keywords) { + if (schema.find(keyword) != schema.end()) { + LOG(WARNING) << "Keyword " << keyword << " is not supported in schema " + << picojson::value(schema); + } + } +} + +std::string JSONSchemaToEBNFConverter::CreateRuleFromSchema(const picojson::value& schema, + const std::string& rule_name_hint) { + std::string idx = GetSchemaCacheIndex(schema); + if (basic_rules_cache_.count(idx)) { + return basic_rules_cache_[idx]; + } + + rules_.push_back(std::make_pair(rule_name_hint, VisitSchema(schema, rule_name_hint))); + return rule_name_hint; +} + +std::string JSONSchemaToEBNFConverter::GetSchemaCacheIndex(const picojson::value& schema) { + // Keys that do not effect the validation + static const std::unordered_set kSkippedKeys = { + "title", "default", "description", 
"examples", "deprecated", + "readOnly", "writeOnly", "$comment", "$schema", + }; + if (schema.is()) { + // remove skipped keys and sort key by lexicographical order + std::string result = "{"; + std::vector> sorted_kv; + for (const auto& kv : schema.get()) { + if (kSkippedKeys.count(kv.first) == 0) { + sorted_kv.push_back(kv); + } + } + std::sort(sorted_kv.begin(), sorted_kv.end(), + [](const auto& lhs, const auto& rhs) { return lhs.first < rhs.first; }); + int idx = 0; + for (const auto& [key, value] : sorted_kv) { + if (idx != 0) { + result += ","; + } + ++idx; + result += "\"" + key + "\":" + GetSchemaCacheIndex(value); + } + return result + "}"; + } else if (schema.is()) { + std::string result = "["; + int idx = 0; + for (const auto& item : schema.get()) { + if (idx != 0) { + result += ","; + } + ++idx; + result += GetSchemaCacheIndex(item); + } + return result + "]"; + } + // If the object is neither an array nor an object, return it directly + return schema.serialize(false); +} + +std::string JSONSchemaToEBNFConverter::VisitSchema(const picojson::value& schema, + const std::string& rule_name) { + if (schema.is()) { + ICHECK(schema.get()); + return VisitAny(schema, rule_name); + } + + WarnUnsupportedKeywords(schema, { + "allof", + "oneof", + "not", + "if", + "then", + "else", + "dependentRequired", + "dependentSchemas", + }); + + ICHECK(schema.is()); + + const auto& schema_obj = schema.get(); + + if (schema_obj.count("$ref")) { + return VisitRef(schema_obj, rule_name); + } else if (schema_obj.count("const")) { + return VisitConst(schema_obj, rule_name); + } else if (schema_obj.count("enum")) { + return VisitEnum(schema_obj, rule_name); + } else if (schema_obj.count("anyOf")) { + return VisitAnyOf(schema_obj, rule_name); + } else if (schema_obj.count("type")) { + const std::string& type = schema_obj.at("type").get(); + if (type == "integer") { + return VisitInteger(schema_obj, rule_name); + } else if (type == "number") { + return VisitNumber(schema_obj, rule_name); + } else if (type == "string") { + return VisitString(schema_obj, rule_name); + } else if (type == "boolean") { + return VisitBoolean(schema_obj, rule_name); + } else if (type == "null") { + return VisitNull(schema_obj, rule_name); + } else if (type == "array") { + return VisitArray(schema_obj, rule_name); + } else if (type == "object") { + return VisitObject(schema_obj, rule_name); + } else { + LOG(FATAL) << "Unsupported type " << type << " in schema " << schema; + } + } + + // If no above keyword is detected, we treat it as any + return VisitAny(schema, rule_name); +} + +std::string JSONSchemaToEBNFConverter::VisitRef(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("$ref")); + picojson::value new_schema = URIToSchema(schema.at("$ref")); + if (!new_schema.is()) { + picojson::object new_schema_obj = new_schema.get(); + for (const auto& [k, v] : schema) { + if (k != "$ref") { + new_schema_obj[k] = v; + } + } + new_schema = picojson::value(new_schema_obj); + } + return VisitSchema(new_schema, rule_name); +} + +picojson::value JSONSchemaToEBNFConverter::URIToSchema(const picojson::value& uri) { + if (uri.get().substr(0, 8) == "#/$defs/") { + return json_schema_.get("$defs").get(uri.get().substr(8)); + } + LOG(WARNING) << "Now only support URI starting with '#/$defs/' but got " << uri; + return picojson::value(true); +} + +std::string JSONSchemaToEBNFConverter::VisitConst(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("const")); + // TODO(yixin): 
Customize serialize to support indent logics + return "\"" + JSONStrToPrintableStr(schema.at("const").serialize()) + "\""; +} + +std::string JSONSchemaToEBNFConverter::VisitEnum(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("enum")); + std::string result = ""; + int idx = 0; + for (auto value : schema.at("enum").get()) { + if (idx != 0) { + result += " | "; + } + ++idx; + result += "(\"" + JSONStrToPrintableStr(value.serialize()) + "\")"; + } + return result; +} + +std::string JSONSchemaToEBNFConverter::JSONStrToPrintableStr(const std::string& json_str) { + static const std::vector> kReplaceMapping = {{"\\", "\\\\"}, + {"\"", "\\\""}}; + std::string result = json_str; + for (const auto& [k, v] : kReplaceMapping) { + size_t pos = 0; + while ((pos = result.find(k, pos)) != std::string::npos) { + result.replace(pos, k.length(), v); + pos += v.length(); + } + } + return result; +} + +std::string JSONSchemaToEBNFConverter::VisitAnyOf(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("anyOf")); + std::string result = ""; + int idx = 0; + for (auto anyof_schema : schema.at("anyOf").get()) { + if (idx != 0) { + result += " | "; + } + result += CreateRuleFromSchema(anyof_schema, rule_name + "_case_" + std::to_string(idx)); + ++idx; + } + return result; +} + +std::string JSONSchemaToEBNFConverter::VisitAny(const picojson::value& schema, + const std::string& rule_name) { + // Note integer is a subset of number, so we don't need to add integer here + return kBasicNumber + " | " + kBasicString + " | " + kBasicBoolean + " | " + kBasicNull + " | " + + kBasicArray + " | " + kBasicObject; +} + +std::string JSONSchemaToEBNFConverter::VisitInteger(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "integer"); + WarnUnsupportedKeywords(schema, { + "multipleOf", + "minimum", + "maximum", + "exclusiveMinimum", + "exclusiveMaximum", + }); + return "(\"0\" | \"-\"? [1-9] [0-9]*) \".0\"?"; +} + +std::string JSONSchemaToEBNFConverter::VisitNumber(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "number"); + WarnUnsupportedKeywords(schema, { + "multipleOf", + "minimum", + "maximum", + "exclusiveMinimum", + "exclusiveMaximum", + }); + return "(\"0\" | \"-\"? [1-9] [0-9]*) (\".\" [0-9]+)? ([eE] [+-]? 
[0-9]+)?"; +} + +std::string JSONSchemaToEBNFConverter::VisitString(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "string"); + WarnUnsupportedKeywords(schema, { + "minLength", + "maxLength", + "pattern", + "format", + }); + return "[\"] " + kBasicStringSub + " [\"]"; +} + +std::string JSONSchemaToEBNFConverter::VisitBoolean(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "boolean"); + return "\"true\" | \"false\""; +} + +std::string JSONSchemaToEBNFConverter::VisitNull(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "null"); + return "\"null\""; +} + +std::string JSONSchemaToEBNFConverter::VisitArray(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "array"); + WarnUnsupportedKeywords(schema, { + "uniqueItems", + "contains", + "minContains", + "maxContains", + "minItems", + "maxItems", + }); + + std::string result = "\"[\""; + + indentManager_->StartIndent(); + + // 1. Handle prefix items + if (schema.count("prefixItems")) { + const auto& prefix_items = schema.at("prefixItems").get(); + for (int i = 0; i < prefix_items.size(); ++i) { + ICHECK(prefix_items[i].is()); + result += " " + NextSeparator() + " "; + result += CreateRuleFromSchema(prefix_items[i], rule_name + "_item_" + std::to_string(i)); + } + } + + // 2. Find additional items + picojson::value additional_item = picojson::value(false); + std::string additional_suffix = ""; + + if (schema.count("items") && (!schema.at("items").is() || schema.at("items").get())) { + additional_item = schema.at("items"); + additional_suffix = "items"; + } + + // If items is specified in the schema, we don't need to consider unevaluatedItems + if (schema.count("items") == 0) { + picojson::value unevaluated = schema.count("unevaluatedItems") ? schema.at("unevaluatedItems") + : picojson::value(!strict_mode_); + if (!unevaluated.is() || unevaluated.get()) { + additional_item = unevaluated; + additional_suffix = "uneval"; + } + } + + // 3. 
Handle additional items and the end separator + bool could_be_empty = false; + if (additional_item.is() && !additional_item.get()) { + result += " " + NextSeparator(true); + } else { + std::string additional_pattern = + CreateRuleFromSchema(additional_item, rule_name + "_" + additional_suffix); + if (schema.count("prefixItems")) { + result += " (" + NextSeparator() + " " + additional_pattern + ")* "; + result += NextSeparator(true); + } else { + result += " " + NextSeparator() + " " + additional_pattern + " ("; + result += NextSeparator() + " " + additional_pattern + ")* "; + result += NextSeparator(true); + could_be_empty = true; + } + } + + indentManager_->EndIndent(); + + result += " \"]\""; + + if (could_be_empty) { + result = "(" + result + ") | \"[]\""; + } + + return result; +} + +std::string JSONSchemaToEBNFConverter::GetPropertyPattern(const std::string& prop_name, + const picojson::value& prop_schema, + const std::string& rule_name, int idx) { + // the outer quote is for the string in EBNF grammar, and the inner quote is for + // the string in JSON + std::string key = "\"\\\"" + prop_name + "\\\"\""; + std::string colon = "\"" + colon_ + "\""; + std::string value = CreateRuleFromSchema(prop_schema, rule_name + "_prop_" + std::to_string(idx)); + return key + " " + colon + " " + value; +} + +std::string JSONSchemaToEBNFConverter::GetOtherPropertyPattern( + const std::string& key_pattern, const picojson::value& prop_schema, + const std::string& rule_name, const std::string& rule_name_suffix) { + std::string colon = "\"" + colon_ + "\""; + std::string value = CreateRuleFromSchema(prop_schema, rule_name + "_" + rule_name_suffix); + return key_pattern + " " + colon + " " + value; +} + +std::string JSONSchemaToEBNFConverter::GetPartialRuleForPropertiesAllOptional( + const std::vector>& properties, + const picojson::value& additional, const std::string& rule_name, + const std::string& additional_suffix) { + ICHECK(properties.size() >= 1); + + std::string first_sep = NextSeparator(); + std::string mid_sep = NextSeparator(); + std::string last_sep = NextSeparator(true); + + std::string res = ""; + + std::vector prop_patterns; + int idx = 0; + for (const auto& [prop_name, prop_schema] : properties) { + prop_patterns.push_back(GetPropertyPattern(prop_name, prop_schema, rule_name, idx)); + ++idx; + } + + std::vector rule_names(properties.size(), ""); + + // construct the last rule + std::string additional_prop_pattern; + if (!additional.is() || additional.get()) { + additional_prop_pattern = + GetOtherPropertyPattern(kBasicString, additional, rule_name, additional_suffix); + std::string last_rule_body = "(" + mid_sep + " " + additional_prop_pattern + ")*"; + std::string last_rule_name = rule_name + "_part_" + std::to_string(properties.size() - 1); + rules_.push_back(std::make_pair(last_rule_name, last_rule_body)); + rule_names.back() = last_rule_name; + } else { + rule_names.back() = "\"\""; + } + + // construct 0~(len(properties) - 2) rules + for (int i = properties.size() - 2; i >= 0; --i) { + const std::string& prop_pattern = prop_patterns[i + 1]; + const std::string& last_rule_name = rule_names[i + 1]; + std::string cur_rule_body = + last_rule_name + " | " + mid_sep + " " + prop_pattern + " " + last_rule_name; + std::string cur_rule_name = rule_name + "_part_" + std::to_string(i); + rules_.push_back(std::make_pair(cur_rule_name, cur_rule_body)); + rule_names[i] = cur_rule_name; + } + + // construct the main rule + for (int i = 0; i < properties.size(); ++i) { + if (i != 0) { + res += " | 
"; + } + res += "(" + prop_patterns[i] + " " + rule_names[i] + ")"; + } + + if (!additional.is() || additional.get()) { + res += " | " + additional_prop_pattern + " " + rule_names.back(); + } + + // add separators and the empty string option + res = first_sep + " (" + res + ") " + last_sep; + return res; +} + +std::string JSONSchemaToEBNFConverter::GetPartialRuleForPropertiesContainRequired( + const std::vector>& properties, + const std::unordered_set& required, const std::string& rule_name) { + // Find the index of the first required property + int first_required_idx = properties.size(); + for (int i = 0; i < properties.size(); ++i) { + if (required.count(properties[i].first)) { + first_required_idx = i; + break; + } + } + ICHECK(first_required_idx < properties.size()); + + std::string res = NextSeparator(); + + // Handle the properties before the first required property + for (int i = 0; i < first_required_idx; ++i) { + const auto& [prop_name, prop_schema] = properties[i]; + ICHECK(!prop_schema.is() || prop_schema.get()); + std::string property_pattern = GetPropertyPattern(prop_name, prop_schema, rule_name, i); + res += " (" + property_pattern + " " + NextSeparator() + ")?"; + } + + // Handle the first required property + const auto& [prop_name, prop_schema] = properties[first_required_idx]; + std::string property_pattern = + GetPropertyPattern(prop_name, prop_schema, rule_name, first_required_idx); + res += " " + property_pattern; + + // Handle the properties after the first required property + for (int i = first_required_idx + 1; i < properties.size(); ++i) { + const auto& [prop_name, prop_schema] = properties[i]; + ICHECK(!prop_schema.is() || prop_schema.get()); + std::string property_pattern = GetPropertyPattern(prop_name, prop_schema, rule_name, i); + if (required.count(prop_name)) { + res += " " + NextSeparator() + " " + property_pattern; + } else { + res += " (" + NextSeparator() + " " + property_pattern + ")?"; + } + } + + return res; +} + +std::string JSONSchemaToEBNFConverter::VisitObject(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "object"); + WarnUnsupportedKeywords(schema, { + "patternProperties", + "minProperties", + "maxProperties", + "propertyNames", + }); + + std::string result = "\"{\""; + + // could_be_empty will be set to True when the rule could be "{}". We will handle this case at + // last, and handle non-empty cases before that. + bool could_be_empty = false; + + indentManager_->StartIndent(); + + // 1. Handle properties + std::vector> properties; + if (schema.count("properties")) { + auto properties_obj = schema.at("properties").get(); + for (const auto& key : properties_obj.ordered_keys()) { + properties.push_back({key, properties_obj.at(key)}); + } + } + + std::unordered_set required; + if (schema.count("required")) { + for (const auto& required_prop : schema.at("required").get()) { + required.insert(required_prop.get()); + } + } + + // 2. Find additional properties + picojson::value additional_property = picojson::value(false); + std::string additional_suffix = ""; + + if (schema.count("additionalProperties") && (!schema.at("additionalProperties").is() || + schema.at("additionalProperties").get())) { + additional_property = schema.at("additionalProperties"); + additional_suffix = "addl"; + } + + if (schema.count("additionalProperties") == 0) { + picojson::value unevaluated = schema.count("unevaluatedProperties") + ? 
schema.at("unevaluatedProperties") + : picojson::value(!strict_mode_); + if (!unevaluated.is() || unevaluated.get()) { + additional_property = unevaluated; + additional_suffix = "uneval"; + } + } + + bool is_all_properties_optional = + std::all_of(properties.begin(), properties.end(), + [&](const auto& prop) { return required.count(prop.first) == 0; }); + + if (is_all_properties_optional && properties.size() > 0) { + // 3.1 Case 1: properties are defined and all properties are optional + result += " " + GetPartialRuleForPropertiesAllOptional(properties, additional_property, + rule_name, additional_suffix); + could_be_empty = true; + } else if (properties.size() > 0) { + // 3.2 Case 2: properties are defined and some properties are required + result += " " + GetPartialRuleForPropertiesContainRequired(properties, required, rule_name); + if (!additional_property.is() || additional_property.get()) { + std::string other_property_pattern = + GetOtherPropertyPattern(kBasicString, additional_property, rule_name, additional_suffix); + result += " (" + NextSeparator() + " " + other_property_pattern + ")*"; + } + result += " " + NextSeparator(true); + } else if (!additional_property.is() || additional_property.get()) { + // 3.3 Case 3: no properties are defined and additional properties are allowed + std::string other_property_pattern = + GetOtherPropertyPattern(kBasicString, additional_property, rule_name, additional_suffix); + result += " " + NextSeparator() + " " + other_property_pattern + " ("; + result += NextSeparator() + " " + other_property_pattern + ")* "; + result += NextSeparator(true); + could_be_empty = true; + } + + indentManager_->EndIndent(); + + result += " \"}\""; + if (could_be_empty) { + result = "(" + result + ") | \"{}\""; + } + + return result; +}; + +std::string JSONSchemaToEBNF(std::string schema, std::optional indent, + std::optional> separators, + bool strict_mode) { + picojson::value schema_value; + std::string err = picojson::parse(schema_value, schema); + if (!err.empty()) { + LOG(FATAL) << "Failed to parse JSON: err. The JSON string is:" << schema; + } + JSONSchemaToEBNFConverter converter(schema_value, indent, separators, strict_mode); + return converter.Convert(); +} + +TVM_REGISTER_GLOBAL("mlc.serve.DebugJSONSchemaToEBNF").set_body([](TVMArgs args, TVMRetValue* rv) { + std::optional indent; + if (args[1].type_code() != kTVMNullptr) { + indent = args[1]; + } else { + indent = std::nullopt; + } + + std::optional> separators; + if (args[2].type_code() != kTVMNullptr) { + Array separators_arr = args[2]; + CHECK(separators_arr.size() == 2); + separators = std::make_pair(separators_arr[0], separators_arr[1]); + } else { + separators = std::nullopt; + } + + *rv = JSONSchemaToEBNF(args[0], indent, separators, args[3]); +}); + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/grammar/json_schema_converter.h b/cpp/serve/grammar/json_schema_converter.h new file mode 100644 index 0000000000..22c730aa41 --- /dev/null +++ b/cpp/serve/grammar/json_schema_converter.h @@ -0,0 +1,44 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/grammar/json_grammar_converter.h + * \brief The header for translating JSON schema to EBNF grammar. + */ + +#ifndef MLC_LLM_SERVE_GRAMMAR_JSON_SCHEMA_CONVERTER_H_ +#define MLC_LLM_SERVE_GRAMMAR_JSON_SCHEMA_CONVERTER_H_ + +#include +#include +#include + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief Convert JSON schema string to EBNF grammar string. 
+ * \param json_schema The JSON schema string. + * \param indent The number of spaces for indentation. If set to std::nullopt, the output will be + * in one line. Default: std::nullopt. + * \param separators Two separators used in the schema: comma and colon. Examples: {",", ":"}, + * {", ", ": "}. If std::nullopt, the default separators will be used: {",", ": "} when the + * indent is not -1, and {", ", ": "} otherwise. This follows the convention in python json.dumps(). + * Default: std::nullopt. + * \param strict_mode Whether to use strict mode. In strict mode, the generated grammar will not + * allow properties and items that is not specified in the schema. This is equivalent to + * setting unevaluatedProperties and unevaluatedItems to false. + * + * This helps LLM to generate accurate output in the grammar-guided generation with JSON + * schema. Default: true. + * \returns The EBNF grammar string. + */ +std::string JSONSchemaToEBNF( + std::string schema, std::optional indent = std::nullopt, + std::optional> separators = std::nullopt, + bool strict_mode = true); + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_GRAMMAR_JSON_SCHEMA_CONVERTER_H_ diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index 9dc4b1b9c5..f7190d50ac 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -289,7 +289,7 @@ class LogitProcessorImpl : public LogitProcessorObj { p_penalties[num_token_for_penalty * 3 + 2] = generation_cfg[i]->repetition_penalty; ++num_token_for_penalty; if (j > 0) { - mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray()); + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray(), NDArray()); } } if (num_token_to_process != 1) { @@ -368,7 +368,7 @@ class LogitProcessorImpl : public LogitProcessorObj { p_seq_ids[token_start_offset + j] = 1; } if (j > 0) { - mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray()); + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray(), NDArray()); } } if (token_number != 1) { diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 5ebf26a061..17121d8e28 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -25,10 +25,10 @@ class ModelImpl; TVM_REGISTER_OBJECT_TYPE(ModelObj); -Model Model::Create(TVMArgValue reload_lib, String model_path, DLDevice device, +Model Model::Create(String reload_lib_path, String model_path, DLDevice device, int max_num_sequence, bool trace_enabled) { return Model( - make_object(reload_lib, model_path, device, max_num_sequence, trace_enabled)); + make_object(reload_lib_path, model_path, device, max_num_sequence, trace_enabled)); } class ModelImpl : public ModelObj { @@ -37,7 +37,7 @@ class ModelImpl : public ModelObj { * \brief Constructor of ModelImpl. * \sa Model::Create */ - explicit ModelImpl(TVMArgValue reload_lib, String model_path, DLDevice device, + explicit ModelImpl(String reload_lib_path, String model_path, DLDevice device, int max_num_sequence, bool trace_enabled) : device_(device) { // Step 1. Process model config json string. @@ -53,7 +53,7 @@ class ModelImpl : public ModelObj { // Step 2. Initialize vm, we use the packed function mechanism // so there is no explicit abi dependency on these extra // classes other than basic tvm runtime. - this->ft_.Init(reload_lib, device_, model_config); + this->ft_.Init(reload_lib_path, device_, model_config); // Step 3. Load params in nd-array cache. this->params_ = ft_.LoadParams(model_path, device_); // Step 4. 
Set max_num_sequence @@ -116,6 +116,223 @@ class ModelImpl : public ModelObj { } } + bool CanGetLogits() final { + return ft_.get_logits_func_.defined() && ft_.batch_get_logits_func_.defined(); + } + + NDArray GetLogits(const ObjectRef& last_hidden_states, int batch_size, int seq_len) final { + NVTXScopedRange nvtx_scope("GetLogits"); + CHECK(ft_.get_logits_func_.defined()) << "`get_logits` function is not found in the model."; + + ObjectRef hidden_states_dref_or_nd; + CHECK(!last_hidden_states->IsInstance()); + // hidden_states: (b, s, h) + NDArray hidden_states = Downcast(last_hidden_states); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(hidden_states->ndim, 3); + ICHECK_EQ(hidden_states->shape[0], batch_size); + ICHECK_EQ(hidden_states->shape[1], seq_len); + ICHECK_EQ(hidden_states->shape[2], hidden_size_); + ICHECK_EQ(hidden_states->device.device_type, device_.device_type); + ICHECK_EQ(hidden_states->device.device_id, device_.device_id); + + hidden_states_dref_or_nd = + hidden_states.CreateView({batch_size * seq_len, hidden_size_}, hidden_states->dtype); + + ObjectRef ret = ft_.get_logits_func_(hidden_states_dref_or_nd, params_); + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + + NDArray logits; + logits = Downcast(ret); + CHECK(logits.defined()); + // logits: (b * s, v) + ICHECK_EQ(logits->ndim, 2); + ICHECK_EQ(logits->shape[0], batch_size * seq_len); + return logits.CreateView({batch_size, seq_len, logits->shape[1]}, logits->dtype); + } + + NDArray BatchGetLogits(const ObjectRef& last_hidden_states, const std::vector& seq_ids, + const std::vector& lengths) { + NVTXScopedRange nvtx_scope("BatchGetLogits"); + CHECK(!seq_ids.empty()); + CHECK_EQ(seq_ids.size(), lengths.size()); + int num_sequences = seq_ids.size(); + int total_length = 0; + + int* p_logit_pos = static_cast(logit_pos_arr_->data); + for (int i = 0; i < num_sequences; ++i) { + total_length += lengths[i]; + p_logit_pos[i] = total_length - 1; + } + NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); + ObjectRef logit_pos_dref_or_nd = + ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); + + CHECK(ft_.batch_get_logits_func_.defined()) + << "`batch_get_logits` function is not found in the model."; + + ObjectRef hidden_states_dref_or_nd; + CHECK(!last_hidden_states->IsInstance()); + // hidden_states: (b, s, h) + NDArray hidden_states = Downcast(last_hidden_states); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(hidden_states->ndim, 3); + ICHECK_EQ(hidden_states->shape[0], 1); + ICHECK_EQ(hidden_states->shape[1], total_length); + ICHECK_EQ(hidden_states->shape[2], hidden_size_); + ICHECK_EQ(hidden_states->device.device_type, device_.device_type); + ICHECK_EQ(hidden_states->device.device_id, device_.device_id); + + hidden_states_dref_or_nd = + hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); + + ObjectRef ret = + ft_.batch_get_logits_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_); + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + + NDArray logits; + logits = Downcast(ret); + CHECK(logits.defined()); + // logits: (b * s, v) + ICHECK_EQ(logits->ndim, 2); + ICHECK_EQ(logits->shape[0], num_sequences); + return logits.CreateView({1, num_sequences, logits->shape[1]}, logits->dtype); + } + + NDArray BatchSelectLastHidden(const ObjectRef& last_hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) { + NVTXScopedRange 
nvtx_scope("BatchSelectLastHidden"); + CHECK(!seq_ids.empty()); + CHECK_EQ(seq_ids.size(), lengths.size()); + int num_sequences = seq_ids.size(); + int total_length = 0; + + int* p_logit_pos = static_cast(logit_pos_arr_->data); + for (int i = 0; i < num_sequences; ++i) { + total_length += lengths[i]; + p_logit_pos[i] = total_length - 1; + } + NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); + ObjectRef logit_pos_dref_or_nd = + ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); + + CHECK(ft_.batch_select_last_hidden_func_.defined()) + << "`batch_select_last_hidden_states` function is not found in the model."; + + ObjectRef hidden_states_dref_or_nd; + CHECK(!last_hidden_states->IsInstance()); + // hidden_states: (b, s, h) + NDArray hidden_states = Downcast(last_hidden_states); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(hidden_states->ndim, 3); + ICHECK_EQ(hidden_states->shape[0], 1); + ICHECK_EQ(hidden_states->shape[1], total_length); + ICHECK_EQ(hidden_states->shape[2], hidden_size_); + ICHECK_EQ(hidden_states->device.device_type, device_.device_type); + ICHECK_EQ(hidden_states->device.device_id, device_.device_id); + + hidden_states_dref_or_nd = + hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); + + ObjectRef ret = + ft_.batch_select_last_hidden_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_); + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + + NDArray hidden; + hidden = Downcast(ret); + // hidden: (b * s, v) + ICHECK_EQ(hidden->ndim, 2); + ICHECK_EQ(hidden->shape[0], num_sequences); + return hidden.CreateView({1, num_sequences, hidden->shape[1]}, hidden->dtype); + } + + NDArray ConcatLastHidden(std::vector& hidden_states, ObjectRef* dst) final { + NVTXScopedRange nvtx_scope("ConcatLastHidden"); + + CHECK(dst->defined()); + + int cum_length = 0; + ICHECK_GE(hidden_states.size(), 1); + for (auto hidden : hidden_states) { + ICHECK_EQ(hidden->ndim, 1); + // No ICHECK_EQ(hidden->shape[0], hidden_size_) here to allow different hidden_sizes. 
+ hidden = hidden.CreateView({1, hidden_size_}, hidden->dtype); + // Reuse the copy embedding function + ft_.nd_copy_embedding_to_offset_func_(hidden, *dst, cum_length); + cum_length += 1; + } + NDArray ret = Downcast(*dst); + ret = ret.CreateView({cum_length, hidden_size_}, hidden_states[0]->dtype); + return ret; + } + + ObjectRef FuseEmbedHidden(const ObjectRef& embeddings, const ObjectRef& previous_hidden_states, + int batch_size, int seq_len) final { + NVTXScopedRange nvtx_scope("FuseEmbedHidden"); + + ObjectRef embeddings_dref_or_nd; + if (!embeddings->IsInstance()) { + // embeddings: (n, h) + NDArray embeddings_nd = Downcast(embeddings); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(embeddings_nd->ndim, 2); + ICHECK_GE(embeddings_nd->shape[0], batch_size * seq_len); + ICHECK_EQ(embeddings_nd->shape[1], hidden_size_); + ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); + ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); + embeddings_dref_or_nd = + embeddings_nd.CreateView({batch_size * seq_len, hidden_size_}, embeddings_nd->dtype); + + if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) { + // Model has no support for fuse_embed_hidden_states or this is the first model (base model) + return embeddings_nd.CreateView({batch_size, seq_len, hidden_size_}, embeddings_nd->dtype); + } + } else { + ShapeTuple embedding_shape{batch_size, seq_len, hidden_size_}; + embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); + + if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) { + // Model has no support for fuse_embed_hidden_states or this is the first model (base model) + ShapeTuple embedding_shape{batch_size, seq_len, hidden_size_}; + return ft_.nd_view_func_(embeddings, embedding_shape); + } + } + + NDArray hidden_states = Downcast(previous_hidden_states); + CHECK(hidden_states.defined()); + ICHECK_EQ(hidden_states->ndim, 3); + ICHECK_EQ(hidden_states->shape[0], batch_size); + ICHECK_EQ(hidden_states->shape[1], seq_len); + ICHECK_EQ(hidden_states->shape[2], hidden_size_); + ICHECK_EQ(hidden_states->device.device_type, device_.device_type); + ICHECK_EQ(hidden_states->device.device_id, device_.device_id); + NDArray hidden_states_2d = + hidden_states.CreateView({batch_size * seq_len, hidden_size_}, hidden_states->dtype); + auto hidden_states_dref_or_nd = + ft_.CopyToWorker0(hidden_states_2d, "hidden_states_2d", + {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); + + ObjectRef ret = + ft_.fuse_embed_hidden_func_(embeddings_dref_or_nd, hidden_states_dref_or_nd, params_); + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + if (!ret->IsInstance()) { + NDArray fused = Downcast(ret); + return fused.CreateView({batch_size, seq_len, hidden_size_}, fused->dtype); + } else { + ShapeTuple fused_shape{batch_size, seq_len, hidden_size_}; + return ft_.nd_view_func_(ret, fused_shape); + } + } + NDArray BatchPrefill(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) final { NVTXScopedRange nvtx_scope("BatchPrefill"); @@ -187,6 +404,74 @@ class ModelImpl : public ModelObj { return logits; } + NDArray BatchPrefillToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) final { + NVTXScopedRange nvtx_scope("BatchPrefillToLastHidden"); + CHECK(!seq_ids.empty()); + CHECK_EQ(seq_ids.size(), lengths.size()); + int num_sequences = seq_ids.size(); + int total_length = 0; + + for (int i = 
0; i < num_sequences; ++i) { + total_length += lengths[i]; + } + + ObjectRef hidden_states_dref_or_nd; + if (!hidden_states->IsInstance()) { + // hidden_states: (1, n, h) + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], 1); + ICHECK_EQ(hidden_states_nd->shape[1], total_length); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + hidden_states_dref_or_nd = + hidden_states_nd.CreateView({1, total_length, hidden_size_}, hidden_states_nd->dtype); + } else { + ShapeTuple hidden_states_shape{1, total_length, hidden_size_}; + hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); + } + + CHECK(ft_.prefill_to_last_hidden_func_.defined()) + << "`prefill_to_last_hidden_states` function is not found in the model."; + ICHECK(ft_.kv_cache_begin_forward_func_.defined()); + ICHECK(ft_.kv_cache_end_forward_func_.defined()); + ICHECK(kv_cache_.defined()) << "KV cache has not been initialized."; + + // Begin forward with the sequence ids and new lengths. + IntTuple seq_ids_tuple(seq_ids); + IntTuple lengths_tuple(lengths.begin(), lengths.end()); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + + // args: embeddings, logit_pos, kv_cache, params + ObjectRef ret; + if (seq_ids.size() == 1) { + CHECK(ft_.single_batch_prefill_to_last_hidden_func_.defined()) + << "`single_batch_prefill_to_last_hidden_states` function is not found in the model."; + ret = ft_.single_batch_prefill_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, + params_); + } else { + ret = ft_.prefill_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); + } + NDArray last_hidden_states; + if (ft_.use_disco) { + Array result = Downcast(ret)->DebugGetFromRemote(0); + last_hidden_states = Downcast(result[0]); + } else { + last_hidden_states = Downcast>(ret)[0]; + } + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + ft_.kv_cache_end_forward_func_(kv_cache_); + + // hidden_states: (1, total_length, v) + ICHECK_EQ(last_hidden_states->ndim, 3); + ICHECK_EQ(last_hidden_states->shape[0], 1); + ICHECK_EQ(last_hidden_states->shape[1], total_length); + return last_hidden_states; + } + NDArray BatchDecode(const ObjectRef& embeddings, const std::vector& seq_ids) final { NVTXScopedRange nvtx_scope("BatchDecode"); int num_sequence = seq_ids.size(); @@ -247,6 +532,67 @@ class ModelImpl : public ModelObj { return logits; } + NDArray BatchDecodeToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids) final { + NVTXScopedRange nvtx_scope("BatchDecodeToLastHidden"); + int num_sequence = seq_ids.size(); + + CHECK(ft_.decode_to_last_hidden_func_.defined()) + << "`batch_decode_to_last_hidden_states` function is not found in the model."; + ICHECK(ft_.kv_cache_begin_forward_func_.defined()); + ICHECK(ft_.kv_cache_end_forward_func_.defined()); + ICHECK(kv_cache_.defined()) << "KV cache has not been initialized."; + + ObjectRef hidden_states_dref_or_nd; + if (!hidden_states->IsInstance()) { + // hidden_states: (1, n, h) + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], num_sequence); + ICHECK_EQ(hidden_states_nd->shape[1], 1); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + hidden_states_dref_or_nd = + hidden_states_nd.CreateView({num_sequence, 1, hidden_size_}, hidden_states_nd->dtype); + } else { + ShapeTuple hidden_states_shape{num_sequence, 1, 
hidden_size_}; + hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); + } + + // Reserve in KV cache for the lengths of the input. + // Begin forward with the sequence ids and new lengths. + IntTuple seq_ids_tuple(seq_ids); + IntTuple lengths_tuple(std::vector(/*n=*/seq_ids.size(), /*v=*/1)); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + + // args: embeddings, kv_cache, params + ObjectRef ret; + if (seq_ids.size() == 1) { + CHECK(ft_.single_batch_decode_to_last_hidden_func_.defined()) + << "`decode_to_last_hidden_states` function is not found in the model."; + ret = ft_.single_batch_decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, + params_); + } else { + ret = ft_.decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); + } + NDArray last_hidden_states; + if (ft_.use_disco) { + Array result = Downcast(ret)->DebugGetFromRemote(0); + last_hidden_states = Downcast(result[0]); + } else { + last_hidden_states = Downcast>(ret)[0]; + } + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + ft_.kv_cache_end_forward_func_(kv_cache_); + + // hidden_states: (b, 1, v) + ICHECK_EQ(last_hidden_states->ndim, 3); + ICHECK_EQ(last_hidden_states->shape[0], num_sequence); + ICHECK_EQ(last_hidden_states->shape[1], 1); + return last_hidden_states; + } + NDArray BatchVerify(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) final { NVTXScopedRange nvtx_scope("BatchVerify"); @@ -307,34 +653,77 @@ class ModelImpl : public ModelObj { return logits; } - /*********************** KV Cache Management ***********************/ + NDArray BatchVerifyToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) final { + NVTXScopedRange nvtx_scope("BatchVerifyToLastHidden"); + CHECK(!seq_ids.empty()); + CHECK_EQ(seq_ids.size(), lengths.size()); + int num_sequences = seq_ids.size(); + int total_length = 0; + for (int i = 0; i < num_sequences; ++i) { + total_length += lengths[i]; + } - LogitProcessor CreateLogitProcessor(int max_num_token, - Optional trace_recorder) { - return LogitProcessor(max_num_token, vocab_size_, &this->ft_, device_, - std::move(trace_recorder)); - } + CHECK(ft_.verify_to_last_hidden_func_.defined()) + << "`batch_verify_to_last_hidden_states` function is not found in the model."; + ICHECK(ft_.kv_cache_begin_forward_func_.defined()); + ICHECK(ft_.kv_cache_end_forward_func_.defined()); + ICHECK(kv_cache_.defined()) << "KV cache has not been initialized."; - Sampler CreateSampler(int max_num_sample, int num_models, - Optional trace_recorder) { - if (num_models > 1) { // speculative decoding uses cpu sampler - return Sampler::CreateCPUSampler(std::move(trace_recorder)); - } else if (Sampler::SupportGPUSampler(device_)) { - return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_, - std::move(trace_recorder)); + ObjectRef hidden_states_dref_or_nd; + if (!hidden_states->IsInstance()) { + // hidden_states: (1, n, h) + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], 1); + ICHECK_EQ(hidden_states_nd->shape[1], total_length); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + hidden_states_dref_or_nd = + hidden_states_nd.CreateView({1, total_length, hidden_size_}, hidden_states_nd->dtype); } else { - return Sampler::CreateCPUSampler(std::move(trace_recorder)); + ShapeTuple 
hidden_states_shape{1, total_length, hidden_size_}; + hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); } + + // Begin forward with the sequence ids and new lengths. + IntTuple seq_ids_tuple(seq_ids); + IntTuple lengths_tuple(lengths.begin(), lengths.end()); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + + // args: embeddings, logit_pos, kv_cache, params + ObjectRef ret = ft_.verify_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); + NDArray last_hidden_states; + if (ft_.use_disco) { + Array result = Downcast(ret)->DebugGetFromRemote(0); + last_hidden_states = Downcast(result[0]); + } else { + last_hidden_states = Downcast>(ret)[0]; + } + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + ft_.kv_cache_end_forward_func_(kv_cache_); + + // hidden_states: (1, total_length, v) + ICHECK_EQ(last_hidden_states->ndim, 3); + ICHECK_EQ(last_hidden_states->shape[0], 1); + ICHECK_EQ(last_hidden_states->shape[1], total_length); + return last_hidden_states; } - void CreateKVCache(KVCacheConfig kv_cache_config) final { - IntTuple max_num_sequence{kv_cache_config->max_num_sequence}; - IntTuple max_total_sequence_length{kv_cache_config->max_total_sequence_length}; - IntTuple prefill_chunk_size{kv_cache_config->prefill_chunk_size}; - IntTuple page_size{kv_cache_config->page_size}; + /*********************** KV Cache Management ***********************/ + + void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, + int prefill_chunk_size) final { + IntTuple max_num_sequence_tuple{max_num_sequence}; + IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; + IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; + IntTuple page_size_tuple{page_size}; IntTuple support_sliding_window{sliding_window_size_ != -1}; - kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence, max_total_sequence_length, - prefill_chunk_size, page_size, support_sliding_window); + kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple, + prefill_chunk_size_tuple, page_size_tuple, + support_sliding_window); local_kv_cache_ = ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; } @@ -371,6 +760,24 @@ class ModelImpl : public ModelObj { /*********************** Utilities ***********************/ + LogitProcessor CreateLogitProcessor(int max_num_token, + Optional trace_recorder) { + return LogitProcessor(max_num_token, vocab_size_, &this->ft_, device_, + std::move(trace_recorder)); + } + + Sampler CreateSampler(int max_num_sample, int num_models, + Optional trace_recorder) { + if (num_models > 1) { // speculative decoding uses cpu sampler + return Sampler::CreateCPUSampler(std::move(trace_recorder)); + } else if (Sampler::SupportGPUSampler(device_)) { + return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_, + std::move(trace_recorder)); + } else { + return Sampler::CreateCPUSampler(std::move(trace_recorder)); + } + } + int EstimateHostCPURequirement() const final { CHECK_NE(num_shards_, -1) << "The model has not been initialized"; return num_shards_ > 1 ? num_shards_ : 0; @@ -400,6 +807,26 @@ class ModelImpl : public ModelObj { return embedding; } + ObjectRef AllocHiddenStatesTensor() final { + // Allocate the hidden_states tensor. + // Use the same function as embeddings. 
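A hedged sketch of how the reworked `CreateKVCache` interface above might be invoked; the numeric values are placeholders, `model` is assumed to come from `Model::Create`, and the include path assumes the `cpp/` directory is visible to the compiler.

```
#include "serve/model.h"

void InitKVCache(mlc::llm::serve::Model model) {
  // Plain integers now replace the former KVCacheConfig object.
  model->CreateKVCache(/*page_size=*/16,
                       /*max_num_sequence=*/32,
                       /*max_total_sequence_length=*/16384,
                       /*prefill_chunk_size=*/2048);
}
```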
+ ObjectRef hidden_states = ft_.alloc_embedding_tensor_func_(); + // Get the shape of the hidden_states tensor for hidden size. + ShapeTuple hidden_states_shape; + if (ft_.use_disco) { + ICHECK(hidden_states->IsInstance()); + ObjectRef shape_ref = ft_.nd_get_shape_func_(hidden_states); + hidden_states_shape = Downcast(shape_ref)->DebugGetFromRemote(0); + } else { + NDArray hidden_states_nd = Downcast(hidden_states); + hidden_states_shape = hidden_states_nd.Shape(); + } + ICHECK_EQ(hidden_states_shape.size(), 2); + ICHECK_EQ(hidden_states_shape[0], prefill_chunk_size_); + this->hidden_size_ = hidden_states_shape[1]; + return hidden_states; + } + void Reset() final { // Reset the KV cache. if (kv_cache_.defined()) { @@ -407,6 +834,12 @@ class ModelImpl : public ModelObj { } } + /************** Debug/Profile **************/ + + void DebugCallFuncOnAllAllWorker(const String& func_name) final { + ft_.DebugCallFuncOnAllAllWorker(func_name); + } + private: /*! \brief Load model configuration from JSON. */ picojson::object LoadModelConfigJSON(const std::string& config_str) { diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 4e57d499ef..da532f83e8 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -39,6 +39,11 @@ struct ModelWorkspace { * model parallelism is not enabled, or a DRef when using tensor model parallelism. */ ObjectRef embeddings{nullptr}; + /*! + * \brief The hidden_states tensor. It can be either an NDArray when tensor + * model parallelism is not enabled, or a DRef when using tensor model parallelism. + */ + ObjectRef hidden_states{nullptr}; }; /*! @@ -91,6 +96,61 @@ class ModelObj : public Object { */ virtual ObjectRef ImageEmbed(const NDArray& image, ObjectRef* dst = nullptr, int offset = 0) = 0; + /*! + * \brief Fuse the embeddings and hidden_states. + * \param embeddings The embedding of the input to be prefilled. + * \param previous_hidden_states The hidden_states from previous base model. + * \param batch_size Batch size. + * \param seq_len Sequence length. + * \return The fused hidden_states. + */ + virtual ObjectRef FuseEmbedHidden(const ObjectRef& embeddings, + const ObjectRef& previous_hidden_states, int batch_size, + int seq_len) = 0; + + /*! + * \brief Return if the model has lm_head so that we can get logits. + */ + virtual bool CanGetLogits() = 0; + + /*! + * \brief Compute logits for last hidden_states. + * \param last_hidden_states The last hidden_states to compute logits for. + * \param batch_size The batch size of last_hidden_states + * \param seq_len The length of tokens in last_hidden_states + * \return The computed logits. + */ + virtual NDArray GetLogits(const ObjectRef& last_hidden_states, int batch_size, int seq_len) = 0; + + /*! + * \brief Compute logits for last hidden_states in a batch. + * \param last_hidden_states The last hidden_states to compute logits for. + * \param seq_ids The id of the sequence in the KV cache. + * \param lengths The length of each sequence to prefill. + * \return The computed logits. + */ + virtual NDArray BatchGetLogits(const ObjectRef& last_hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; + + /*! + * \brief Select desired hidden_states for last hidden_states in a batch. + * \param last_hidden_states The last hidden_states to select from. + * \param seq_ids The id of the sequence in the KV cache. + * \param lengths The length of each sequence to prefill. + * \return The last hidden_states for the batch. 
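A minimal sketch of the guard implied by `CanGetLogits`: a draft model compiled without an lm_head cannot serve `GetLogits`, so callers are expected to check first. It assumes the MLC LLM `cpp/serve` headers and that `model` comes from `Model::Create`.

```
#include "serve/model.h"

tvm::runtime::NDArray HiddenToLogits(mlc::llm::serve::Model model,
                                     const tvm::runtime::ObjectRef& last_hidden_states,
                                     int batch_size, int seq_len) {
  CHECK(model->CanGetLogits()) << "Model library does not expose get_logits.";
  // last_hidden_states is expected to be shaped (batch_size, seq_len, hidden_size).
  return model->GetLogits(last_hidden_states, batch_size, seq_len);
}
```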
+ */ + virtual NDArray BatchSelectLastHidden(const ObjectRef& last_hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; + + /*! + * \brief Concat a list of 1D hidden_states to 2D tensor. + * \param hidden_states The hidden_states to concat. + * \param dst The copy destination. + */ + virtual NDArray ConcatLastHidden(std::vector& hidden_states, ObjectRef* dst) = 0; + /*! * \brief Batch prefill function. Embedding in, logits out. * The embedding order of sequences in `embedding_arr` follows @@ -103,6 +163,18 @@ class ModelObj : public Object { virtual NDArray BatchPrefill(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) = 0; + /*! + * \brief Batch prefill function. Input hidden_states are computed from + * input embeddings and previous hidden_states, output last hidden_states. + * \param hidden_states The hidden_states of the input to be prefilled. + * \param seq_id The id of the sequence in the KV cache. + * \param lengths The length of each sequence to prefill. + * \return The hidden_states for the next token. + */ + virtual NDArray BatchPrefillToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; + /*! * \brief Batch decode function. Embedding in, logits out. * The embedding order of sequences in `embeddings` follows @@ -113,6 +185,16 @@ class ModelObj : public Object { */ virtual NDArray BatchDecode(const ObjectRef& embeddings, const std::vector& seq_ids) = 0; + /*! + * \brief Batch decode function. Input hidden_states are computed from + * input embeddings and previous hidden_states, output last hidden_states. + * \param hidden_states The hidden_states of last generated token in the entire batch. + * \param seq_id The id of the sequence in the KV cache. + * \return The hidden_states for the next token for each sequence in the batch. + */ + virtual NDArray BatchDecodeToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids) = 0; + /*! * \brief Batch verify function. Embedding in, logits out. * \param embeddings The embedding of the input to be verified. @@ -126,13 +208,35 @@ class ModelObj : public Object { virtual NDArray BatchVerify(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) = 0; + /*! + * \brief Batch verify function. Input hidden_states are computed from + * input embeddings and previous hidden_states, output last hidden_states. + * \param hidden_states The hidden_states of the input to be verified. + * \param seq_id The id of the sequence in the KV cache. + * \param lengths The length of each sequence to verify. + * \return The hidden_states for the draft token for each sequence in the batch. + * \note The function runs for **every** sequence in the batch. + * That is to say, it does not accept "running a verify step for a subset + * of the full batch". + */ + virtual NDArray BatchVerifyToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; + /*********************** KV Cache Management ***********************/ /*! * \brief Create the KV cache inside the model with regard to the input config. - * \param kv_cache_config The configuration of KV cache. + * \param page_size The number of consecutive tokens handled in each page in paged KV cache. + * \param max_num_sequence The maximum number of sequences that are allowed to be + * processed by the KV cache at any time. 
+ * \param max_total_sequence_length The maximum length allowed for a single sequence + * in the engine. + * \param prefill_chunk_size The maximum total number of tokens whose KV data + * are allowed to exist in the KV cache at any time. */ - virtual void CreateKVCache(KVCacheConfig kv_cache_config) = 0; + virtual void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, + int prefill_chunk_size) = 0; /*! \brief Add a new sequence with the given sequence id to the KV cache. */ virtual void AddNewSequence(int64_t seq_id) = 0; @@ -188,9 +292,17 @@ class ModelObj : public Object { /*! \brief Allocate an embedding tensor with the prefill chunk size. */ virtual ObjectRef AllocEmbeddingTensor() = 0; + /*! \brief Allocate an hidden_states tensor with the prefill chunk size. */ + virtual ObjectRef AllocHiddenStatesTensor() = 0; + /*! \brief Reset the model KV cache and other statistics. */ virtual void Reset() = 0; + /************** Debug/Profile **************/ + + /*! \brief Call the given global function on all workers. Only for debug purpose. */ + virtual void DebugCallFuncOnAllAllWorker(const String& func_name) = 0; + static constexpr const char* _type_key = "mlc.serve.Model"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; @@ -201,15 +313,14 @@ class Model : public ObjectRef { public: /*! * \brief Create the runtime module for LLM functions. - * \param reload_lib The model library. It might be a path to the binary - * file or an executable module that is pre-loaded. + * \param reload_lib_path The model library path. * \param model_path The path to the model weight parameters. * \param device The device to run the model on. * \param max_num_sequence The maximum number of sequences to be processed * \param trace_enabled A boolean indicating whether tracing is enabled. * \return The created runtime module. */ - TVM_DLL static Model Create(TVMArgValue reload_lib, String model_path, DLDevice device, + TVM_DLL static Model Create(String reload_lib_path, String model_path, DLDevice device, int max_num_sequence, bool trace_enabled); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj); diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 2a035ad387..b1f5ae27a2 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -59,9 +59,11 @@ void RequestModelStateNode::CommitToken(SampleResult sampled_token) { } } -void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, NDArray prob_dist) { +void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, NDArray prob_dist, + NDArray last_hidden_on_device) { draft_output_tokens.push_back(std::move(sampled_token)); draft_output_prob_dist.push_back(std::move(prob_dist)); + draft_last_hidden_on_device.push_back(std::move(last_hidden_on_device)); appeared_token_ids[sampled_token.sampled_token_id.first] += 1; } @@ -116,14 +118,6 @@ RequestStateEntry::RequestStateEntry( DeltaRequestReturn RequestStateEntryNode::GetReturnTokenIds(const Tokenizer& tokenizer, int max_single_sequence_length) { - // - Case 0. There is remaining draft output ==> Unfinished - // All draft outputs are supposed to be processed before finish. 
- for (RequestModelState mstate : this->mstates) { - if (!mstate->draft_output_tokens.empty()) { - return {{}, {}, Optional()}; - } - } - std::vector return_token_ids; std::vector logprob_json_strs; Optional finish_reason; diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 7764a38c3e..950bb6e290 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -70,6 +70,12 @@ class RequestModelStateNode : public Object { * and draft outputs in speculative inference settings. */ std::vector draft_output_prob_dist; + /*! + * \brief The last hidden_states used to get probs in drafting. + * \note We only need this value when we have multiple parallel small models + * and draft outputs in speculative inference settings. + */ + std::vector draft_last_hidden_on_device; /*! \brief The appeared committed and draft tokens and their occurrence times. */ std::unordered_map appeared_token_ids; @@ -95,7 +101,8 @@ class RequestModelStateNode : public Object { /*! \brief Commit a new token into committed_tokens. Update appeared_token_ids. */ void CommitToken(SampleResult sampled_token); /*! \brief Add a draft token into draft_output_tokens. Update appeared_token_ids. */ - void AddDraftToken(SampleResult sampled_token, NDArray prob_dist); + void AddDraftToken(SampleResult sampled_token, NDArray prob_dist, + NDArray draft_last_hidden_on_device = NDArray()); /*! \brief Remove the last token from draft_output_tokens. Update appeared_token_ids. */ void RemoveLastDraftToken(); /*! \brief Remove all draft tokens from draft_output_tokens. Update appeared_token_ids. */ diff --git a/cpp/serve/sampler/cpu_sampler.cc b/cpp/serve/sampler/cpu_sampler.cc index e1316e57f0..02b7e2a81d 100644 --- a/cpp/serve/sampler/cpu_sampler.cc +++ b/cpp/serve/sampler/cpu_sampler.cc @@ -22,7 +22,8 @@ namespace serve { * The input is a batch of distributions, and we use `unit_offset` to specify * which distribution to sample from. * \param prob The input batch of probability distributions. - * \param unit_offset The offset specifying which distribution to sample from. + * \param unit_offset The offset specifying which distribution to output + * \param input_prob_offset The offset specifying which distribution to sample from. * \param top_p The top-p value of sampling. * \param uniform_sample The random number in [0, 1] for sampling. * \param output_prob_dist Optional pointer to store the corresponding probability distribution of @@ -31,7 +32,8 @@ namespace serve { * \note This function is an enhancement of SampleTopPFromProb in TVM Unity. * We will upstream the enhancement after it gets stable. */ -TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, double top_p, double uniform_sample, +TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_offset, double top_p, + double uniform_sample, std::vector* output_prob_dist = nullptr) { // prob: (*, v) // The prob array may have arbitrary ndim and shape. 
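The two offsets introduced in `SampleTopPFromProb` separate the output slot (`unit_offset`) from the input probability row (`input_prob_offset`). The standalone sketch below illustrates the indexing, with a plain arg-max standing in for the real top-p sampling; the probability values are made up.

```
#include <cassert>
#include <vector>

int main() {
  // Two probability rows, flattened as (num_rows, vocab) with vocab = 3.
  std::vector<float> prob = {0.7f, 0.2f, 0.1f,   // row 0
                             0.1f, 0.1f, 0.8f};  // row 1
  const int vocab = 3;
  std::vector<int> sample_indices = {0, 0, 1};  // output slots 0 and 1 both read row 0
  std::vector<int> sampled(sample_indices.size());
  for (int unit_offset = 0; unit_offset < (int)sample_indices.size(); ++unit_offset) {
    int input_prob_offset = sample_indices[unit_offset];
    const float* row = prob.data() + input_prob_offset * vocab;
    // Greedy stand-in for the real top-p sampling: pick the arg-max of the row.
    int best = 0;
    for (int v = 1; v < vocab; ++v) {
      if (row[v] > row[best]) best = v;
    }
    sampled[unit_offset] = best;
  }
  assert((sampled == std::vector<int>{0, 0, 2}));
  return 0;
}
```

Decoupling the two offsets is what lets the CPU sampler write a distinct output distribution per sample even when several samples share one input probability row.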
@@ -50,10 +52,11 @@ TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, double top_p, do int64_t ndata = prob->shape[prob->ndim - 1]; const float* __restrict p_prob = - static_cast(__builtin_assume_aligned(prob->data, 4)) + (unit_offset * ndata); + static_cast(__builtin_assume_aligned(prob->data, 4)) + (input_prob_offset * ndata); constexpr double one = 1.0f - 1e-5f; if (output_prob_dist) { + ICHECK_LT(unit_offset, static_cast(output_prob_dist->size())); if (!(*output_prob_dist)[unit_offset].defined()) { (*output_prob_dist)[unit_offset] = NDArray::Empty({ndata}, prob->dtype, DLDevice{kDLCPU, 0}); } @@ -294,7 +297,7 @@ class CPUSampler : public SamplerObj { RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); // Sample top p from probability. sample_results[i].sampled_token_id = SampleTopPFromProb( - probs_host, sample_indices[i], + probs_host, i, sample_indices[i], generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, rngs[i]->GetRandomNumber(), output_prob_dist); if (output_prob_dist == nullptr) { @@ -341,7 +344,9 @@ class CPUSampler : public SamplerObj { [&](int i) { int verify_start = cum_verify_lengths[i]; int verify_end = cum_verify_lengths[i + 1]; - for (int cur_token_idx = 0; cur_token_idx < verify_end - verify_start; ++cur_token_idx) { + int cur_token_idx = 0; + // Sub 1 to ignore the last prediction. + for (; cur_token_idx < verify_end - verify_start - 1; ++cur_token_idx) { float* p_probs = global_p_probs + (verify_start + cur_token_idx) * vocab_size; int cur_token = draft_output_tokens[i][cur_token_idx].sampled_token_id.first; float q_value = draft_output_tokens[i][cur_token_idx].sampled_token_id.second; @@ -383,7 +388,7 @@ class CPUSampler : public SamplerObj { // sample a new token from the new distribution SampleResult sample_result; sample_result.sampled_token_id = SampleTopPFromProb( - probs_host, verify_start + cur_token_idx, + probs_host, verify_start + cur_token_idx, verify_start + cur_token_idx, generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, rngs[i]->GetRandomNumber()); sample_result.top_prob_tokens = ComputeTopProbs( @@ -391,6 +396,20 @@ class CPUSampler : public SamplerObj { sample_results[i].push_back(sample_result); break; } + // if cur_token_idx == verify_end - verify_start - 1 + // all draft tokens are accepted + // we sample a new token + if (cur_token_idx == verify_end - verify_start - 1) { + SampleResult sample_result; + // sample a new token from the original distribution + sample_result.sampled_token_id = SampleTopPFromProb( + probs_host, verify_start + cur_token_idx, verify_start + cur_token_idx, + generation_cfg[i]->temperature < eps_ ? 
0.0 : generation_cfg[i]->top_p, + rngs[i]->GetRandomNumber()); + sample_result.top_prob_tokens = ComputeTopProbs( + probs_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); + sample_results[i].push_back(sample_result); + } }, 0, num_sequence); RECORD_EVENT(trace_recorder_, request_ids, "finish draft verification"); diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index a290e64b4d..b376523dac 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -92,6 +92,7 @@ class GPUSampler : public SamplerObj { NVTXScopedRange nvtx_scope("BatchSampleTokens"); // probs_on_device: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); + CHECK(output_prob_dist == nullptr) << "GPU sampler does not support collecting output probs."; CHECK_EQ(probs_on_device->ndim, 2); int num_samples = sample_indices.size(); int num_probs = probs_on_device->shape[0]; @@ -100,6 +101,50 @@ class GPUSampler : public SamplerObj { ICHECK_EQ(generation_cfg.size(), num_samples); ICHECK_EQ(rngs.size(), num_samples); + // Since `num_samples` may be larger than `max_num_sample_` in some cases, + // we apply chunking to support large `num_samples`. + std::vector sample_results; + if (num_samples <= max_num_sample_) { + sample_results = ChunkSampleTokensImpl(probs_on_device, sample_indices, generation_cfg, rngs); + } else { + for (int chunk_start = 0; chunk_start < num_samples; chunk_start += max_num_sample_) { + int chunk_end = std::min(chunk_start + max_num_sample_, num_samples); + std::vector sample_indices_chunk(sample_indices.begin() + chunk_start, + sample_indices.begin() + chunk_end); + Array generation_cfg_chunk(generation_cfg.begin() + chunk_start, + generation_cfg.begin() + chunk_end); + std::vector rngs_chunk(rngs.begin() + chunk_start, + rngs.begin() + chunk_end); + std::vector sample_results_chunk = ChunkSampleTokensImpl( + probs_on_device, sample_indices_chunk, generation_cfg_chunk, rngs_chunk); + sample_results.insert(sample_results.end(), sample_results_chunk.begin(), + sample_results_chunk.end()); + } + } + + RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); + return sample_results; + } + + std::vector> BatchVerifyDraftTokens( + NDArray probs_on_device, const Array& request_ids, + const std::vector& cum_verify_lengths, const Array& generation_cfg, + const std::vector& rngs, + const std::vector>& draft_output_tokens, + const std::vector>& draft_output_prob_dist) final { + LOG(FATAL) << "GPU sampler does not support batch verification for now."; + } + + private: + std::vector ChunkSampleTokensImpl(NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& generation_cfg, // + const std::vector& rngs) { + // probs_on_device: (n, v) + int num_samples = sample_indices.size(); + int num_probs = probs_on_device->shape[0]; + int vocab_size = probs_on_device->shape[1]; + // - Generate random numbers. // Copy the random numbers and sample indices. 
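The chunking applied in `BatchSampleTokens` when `num_samples` exceeds `max_num_sample_` amounts to the partitioning below; this is a standalone sketch and the capacity value is hypothetical.

```
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

std::vector<std::pair<int, int>> MakeChunks(int num_samples, int max_num_sample) {
  std::vector<std::pair<int, int>> chunks;  // half-open [chunk_start, chunk_end) ranges
  for (int chunk_start = 0; chunk_start < num_samples; chunk_start += max_num_sample) {
    int chunk_end = std::min(chunk_start + max_num_sample, num_samples);
    chunks.emplace_back(chunk_start, chunk_end);
  }
  return chunks;
}

int main() {
  // 10 requested samples with a capacity of 4 are processed as three GPU launches.
  auto chunks = MakeChunks(/*num_samples=*/10, /*max_num_sample=*/4);
  assert((chunks == std::vector<std::pair<int, int>>{{0, 4}, {4, 8}, {8, 10}}));
  return 0;
}
```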
auto [uniform_samples_device, sample_indices_device] = @@ -148,20 +193,9 @@ class GPUSampler : public SamplerObj { SampleResult{{p_sampled_token_ids[i], sampled_prob}, top_prob_tokens}); } - RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); return sample_results; } - std::vector> BatchVerifyDraftTokens( - NDArray probs_on_device, const Array& request_ids, - const std::vector& cum_verify_lengths, const Array& generation_cfg, - const std::vector& rngs, - const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) final { - LOG(FATAL) << "GPU sampler does not support batch verification for now."; - } - - private: /*! \brief Generate uniform random numbers, and copy the numbers and sample indices to GPU. */ std::pair CopySamplesAndIndicesToGPU(const std::vector& sample_indices, const std::vector& rngs, diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc new file mode 100644 index 0000000000..458d2ae5d7 --- /dev/null +++ b/cpp/serve/threaded_engine.cc @@ -0,0 +1,262 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/threaded_engine.cc + * \brief The implementation for threaded serving engine in MLC LLM. + */ +#include "threaded_engine.h" + +#include +#include +#include + +#include +#include +#include + +#include "engine.h" +#include "request.h" + +namespace mlc { +namespace llm { +namespace serve { + +using tvm::Device; +using namespace tvm::runtime; + +/*! \brief The threaded engine instruction kind. */ +enum class InstructionKind : int { + kAddRequest = 0, + kAbortRequest = 1, + kUnloadEngine = 2, + kReloadEngine = 3, + kDebugCallFuncOnAllAllWorker = 4, +}; + +/*! \brief The implementation of ThreadedEngine. */ +class ThreadedEngineImpl : public ThreadedEngine { + public: + void InitBackgroundEngine(EngineConfig engine_config, + Optional request_stream_callback, + Optional trace_recorder) final { + CHECK(request_stream_callback.defined()) + << "ThreadedEngine requires request stream callback function, but it is not given."; + request_stream_callback_ = request_stream_callback.value(); + + auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.size(), 1); + Array delta_outputs = args[0]; + bool need_notify = false; + { + std::lock_guard lock(request_stream_callback_mutex_); + request_stream_callback_inputs_.push_back(std::move(delta_outputs)); + ++pending_request_stream_callback_cnt_; + need_notify = stream_callback_waiting_; + } + if (need_notify) { + request_stream_callback_cv_.notify_one(); + } + }; + + request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); + background_engine_ = Engine::Create( + std::move(engine_config), std::move(request_stream_callback), std::move(trace_recorder)); + } + + void AddRequest(Request request) final { + bool need_notify = false; + { + std::lock_guard lock(background_loop_mutex_); + instruction_queue_.emplace_back(InstructionKind::kAddRequest, request); + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); + } + } + + void AbortRequest(const String& request_id) final { + bool need_notify = false; + { + std::lock_guard lock(background_loop_mutex_); + instruction_queue_.emplace_back(InstructionKind::kAbortRequest, request_id); + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); + } + } + + void RunBackgroundLoop() final { + // The local vectors that load the requests from 
critical regions. + std::vector> local_instruction_queue; + + while (!exit_now_.load(std::memory_order_relaxed)) { + { + std::unique_lock lock(background_loop_mutex_); + engine_waiting_ = true; + background_loop_cv_.wait(lock, [this] { + return !background_engine_->Empty() || pending_request_operation_cnt_.load() > 0 || + exit_now_.load(std::memory_order_relaxed); + }); + engine_waiting_ = false; + + local_instruction_queue = instruction_queue_; + instruction_queue_.clear(); + pending_request_operation_cnt_ = 0; + } + for (const auto& [kind, arg] : local_instruction_queue) { + if (kind == InstructionKind::kAddRequest) { + background_engine_->AddRequest(Downcast(arg)); + } else if (kind == InstructionKind::kAbortRequest) { + background_engine_->AbortRequest(Downcast(arg)); + } else if (kind == InstructionKind::kUnloadEngine) { + // Todo(mlc-team): implement engine unload + LOG(FATAL) << "Not implemented yet."; + } else if (kind == InstructionKind::kReloadEngine) { + // Todo(mlc-team): implement engine reload + LOG(FATAL) << "Not implemented yet."; + } else if (kind == InstructionKind::kDebugCallFuncOnAllAllWorker) { + background_engine_->DebugCallFuncOnAllAllWorker(Downcast(arg)); + } else { + LOG(FATAL) << "Cannot reach here"; + } + } + background_engine_->Step(); + } + } + + void RunBackgroundStreamBackLoop() final { + // The local vectors that load the request stream callback inputs from critical regions. + std::vector> local_request_stream_callback_inputs; + std::vector flattened_callback_inputs; + + while (!exit_now_.load(std::memory_order_relaxed)) { + { + std::unique_lock lock(request_stream_callback_mutex_); + stream_callback_waiting_ = true; + request_stream_callback_cv_.wait(lock, [this] { + return pending_request_stream_callback_cnt_.load() > 0 || + exit_now_.load(std::memory_order_relaxed); + }); + stream_callback_waiting_ = false; + + local_request_stream_callback_inputs = request_stream_callback_inputs_; + request_stream_callback_inputs_.clear(); + pending_request_stream_callback_cnt_ = 0; + } + for (const Array& callback_inputs : + local_request_stream_callback_inputs) { + for (const RequestStreamOutput& callback_input : callback_inputs) { + flattened_callback_inputs.push_back(callback_input); + } + } + if (!flattened_callback_inputs.empty()) { + request_stream_callback_(Array(flattened_callback_inputs)); + } + flattened_callback_inputs.clear(); + } + } + + void ExitBackgroundLoop() final { + { + std::lock_guard lock(background_loop_mutex_); + exit_now_.store(true); + } + background_loop_cv_.notify_one(); + request_stream_callback_cv_.notify_one(); + } + + /************** Debug/Profile **************/ + + void DebugCallFuncOnAllAllWorker(const String& func_name) final { + bool need_notify = false; + { + std::lock_guard lock(background_loop_mutex_); + instruction_queue_.emplace_back(InstructionKind::kDebugCallFuncOnAllAllWorker, func_name); + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); + } + } + + private: + /*! \brief The background normal engine for request processing. */ + std::unique_ptr background_engine_; + /*! \brief The request stream callback. */ + PackedFunc request_stream_callback_; + + /*! \brief The mutex ensuring only one thread can access critical regions. */ + std::mutex background_loop_mutex_; + std::mutex request_stream_callback_mutex_; + /*! \brief The condition variable preventing threaded engine from spinning. 
*/ + std::condition_variable background_loop_cv_; + std::condition_variable request_stream_callback_cv_; + /*! \brief A boolean flag denoting if the engine needs to exit background loop. */ + std::atomic exit_now_ = false; + + /************** Critical Regions **************/ + /*! + * \brief The instruction queue for the threaded engine. + * The instructions include: + * - requests to add into the background engine, + * - requests to abort from the background engine, + * - engine unload/reload, + * - and other debugging instructions. + * Elements are sended from other threads and consumed by + * the threaded engine in the background loop. + */ + std::vector> instruction_queue_; + /*! + * \brief The delta outputs to pass through callback. + * Elements are sended from the background loop thread and + * consumed by the foreground thread. + */ + std::vector> request_stream_callback_inputs_; + /*! + * \brief Number of pending request operations, should be the size of + * `requests_to_add_` and `requests_to_abort_`. + */ + std::atomic pending_request_operation_cnt_ = 0; + /*! + * \brief Number of pending request stream callback invocations. + * It should be the size of `request_stream_callback_inputs_`. + */ + std::atomic pending_request_stream_callback_cnt_ = 0; + /*! \brief A boolean flag indicating if the engine is waiting for new requests/aborts. */ + bool engine_waiting_ = false; + /*! \brief A boolean flag indicating if the stream callback loop is waiting. */ + bool stream_callback_waiting_ = false; +}; + +/*! \brief The implementation of ThreadedEngine. */ +class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { + public: + TVM_MODULE_VTABLE_BEGIN("mlc.serve.async_threaded_engine"); + TVM_MODULE_VTABLE_ENTRY("init_background_engine", &ThreadedEngineImpl::InitBackgroundEngine); + TVM_MODULE_VTABLE_ENTRY("add_request", &ThreadedEngineImpl::AddRequest); + TVM_MODULE_VTABLE_ENTRY("abort_request", &ThreadedEngineImpl::AbortRequest); + TVM_MODULE_VTABLE_ENTRY("run_background_loop", &ThreadedEngineImpl::RunBackgroundLoop); + TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", + &ThreadedEngineImpl::RunBackgroundStreamBackLoop); + TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &ThreadedEngineImpl::ExitBackgroundLoop); + TVM_MODULE_VTABLE_ENTRY("debug_call_func_on_all_worker", + &ThreadedEngineImpl::DebugCallFuncOnAllAllWorker); + TVM_MODULE_VTABLE_END(); +}; + +TVM_REGISTER_GLOBAL("mlc.serve.create_threaded_engine").set_body_typed([]() { + return Module(make_object()); +}); + +std::unique_ptr ThreadedEngine::Create() { + std::unique_ptr threaded_engine = std::make_unique(); + return std::move(threaded_engine); +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h new file mode 100644 index 0000000000..3d11ba36f1 --- /dev/null +++ b/cpp/serve/threaded_engine.h @@ -0,0 +1,75 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/threaded_engine.h + * \brief The header of threaded serving engine in MLC LLM. + */ +#ifndef MLC_LLM_SERVE_THREADED_ENGINE_H_ +#define MLC_LLM_SERVE_THREADED_ENGINE_H_ + +#include + +#include "data.h" +#include "engine.h" +#include "request.h" + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! + * \brief The interface threaded engine in MLC LLM. + * The threaded engine keeps running a background request processing + * loop on a standalone thread. 
Ensuring thread safety, it exposes + * `AddRequest` and `AbortRequest` to receive new requests or + * abortions from other threads, and the internal request processing + * is backed by a normal engine wrapped inside. + */ +class ThreadedEngine { + public: + /*! \brief Create a ThreadedEngine. */ + static std::unique_ptr Create(); + + virtual ~ThreadedEngine() = default; + + /*! + * \brief Initialize the threaded engine from packed arguments in TVMArgs. + * \param engine_config The engine config. + * \param request_stream_callback The request stream callback function to. + * \param trace_recorder Event trace recorder for requests. + */ + virtual void InitBackgroundEngine(EngineConfig engine_config, + Optional request_stream_callback, + Optional trace_recorder) = 0; + + /*! \brief Starts the background request processing loop. */ + virtual void RunBackgroundLoop() = 0; + + /*! \brief Starts the request stream callback loop. */ + virtual void RunBackgroundStreamBackLoop() = 0; + + /*! + * \brief Notify the ThreadedEngine to exit the background + * request processing loop. This method is invoked by threads + * other than the engine-driving thread. + */ + virtual void ExitBackgroundLoop() = 0; + + /*! \brief Add a new request to the engine. */ + virtual void AddRequest(Request request) = 0; + + /*! \brief Abort the input request (specified by id string) from engine. */ + virtual void AbortRequest(const String& request_id) = 0; + + /************** Debug/Profile **************/ + + /*! \brief Call the given global function on all workers. Only for debug purpose. */ + virtual void DebugCallFuncOnAllAllWorker(const String& func_name) = 0; +}; + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_THREADED_ENGINE_H_ diff --git a/cpp/support/utils.h b/cpp/support/utils.h new file mode 100644 index 0000000000..5360f0496c --- /dev/null +++ b/cpp/support/utils.h @@ -0,0 +1,24 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file utils.h + * \brief Utility functions. + */ +#include +#include +#include + +namespace mlc { +namespace llm { + +inline std::vector Split(const std::string& str, char delim) { + std::string item; + std::istringstream is(str); + std::vector ret; + while (std::getline(is, item, delim)) { + ret.push_back(item); + } + return ret; +} + +} // namespace llm +} // namespace mlc diff --git a/docs/_static/img/project-workflow.svg b/docs/_static/img/project-workflow.svg new file mode 100644 index 0000000000..eac1313a44 --- /dev/null +++ b/docs/_static/img/project-workflow.svg @@ -0,0 +1,1173 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/community/faq.rst b/docs/community/faq.rst index 3913dd9639..4bc6f9deb8 100644 --- a/docs/community/faq.rst +++ b/docs/community/faq.rst @@ -6,7 +6,7 @@ Frequently Asked Questions This is a list of Frequently Asked Questions (FAQ) about the MLC-LLM. Feel free to suggest new entries! ... How can I customize the temperature, and repetition penalty of models? 
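Stepping back to the `ThreadedEngine` interface above, the intended threading pattern can be sketched as follows. It assumes the MLC LLM `cpp/serve` headers; engine initialization and request construction are shown only as comments because they go through the usual config and request APIs.

```
#include <thread>

#include "serve/threaded_engine.h"

void RunEngineSketch(mlc::llm::serve::Request request) {
  auto engine = mlc::llm::serve::ThreadedEngine::Create();
  // engine->InitBackgroundEngine(engine_config, request_stream_callback, trace_recorder);

  // One thread steps the wrapped engine; another flushes streamed outputs back to the
  // registered callback, so callback work never blocks request processing.
  std::thread loop_thread([&] { engine->RunBackgroundLoop(); });
  std::thread stream_thread([&] { engine->RunBackgroundStreamBackLoop(); });

  engine->AddRequest(request);   // enqueued into the instruction queue, then picked up
  // ... outputs arrive through the request stream callback ...
  engine->ExitBackgroundLoop();  // wakes both loops and asks them to return
  loop_thread.join();
  stream_thread.join();
}
```

Note that `ExitBackgroundLoop` notifies both condition variables, so neither loop can stay blocked after shutdown is requested.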
- Please check our :doc:`/get_started/mlc_chat_config` tutorial. + Please check our :ref:`configure-mlc-chat-json` tutorial. ... What's the quantization algorithm MLC-LLM using? Please check our :doc:`/compilation/configure_quantization` tutorial. diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index b30076f018..00beb5cc4d 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -21,7 +21,7 @@ We compile ``RedPajama-INCITE-Chat-3B-v1`` with ``q4f16_1`` as an example for al Before you proceed, make sure you followed :ref:`install-tvm-unity`, a required backend to compile models with MLC LLM. - Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python` to obtain + Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-chat-module` to obtain the CLI app / Python API that can be used to chat with the compiled model. Finally, we strongly recommend you to read :ref:`project-overview` first to get familiarized with the high-level terminologies. diff --git a/docs/compilation/convert_weights.rst b/docs/compilation/convert_weights.rst index 2507687c21..aa65256fd6 100644 --- a/docs/compilation/convert_weights.rst +++ b/docs/compilation/convert_weights.rst @@ -24,7 +24,7 @@ This can be extended to, e.g.: Before you proceed, make sure you followed :ref:`install-tvm-unity`, a required backend to compile models with MLC LLM. - Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python` to obtain + Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-chat-module` to obtain the CLI app / Python API that can be used to chat with the compiled model. Finally, we strongly recommend you to read :ref:`project-overview` first to get familiarized with the high-level terminologies. diff --git a/docs/compilation/get-vicuna-weight.rst b/docs/compilation/get-vicuna-weight.rst deleted file mode 100644 index 2ea4ba5d97..0000000000 --- a/docs/compilation/get-vicuna-weight.rst +++ /dev/null @@ -1,68 +0,0 @@ -Getting Vicuna Weights -====================== - -.. contents:: Table of Contents - :local: - :depth: 2 - -`Vicuna `_ is an open-source chatbot trained by fine-tuning `LLaMA `_ on `ShartGPT `_ data. - -Please note that the official Vicuna weights are delta weights applied to the LLaMA weights in order to comply with the LLaMA license. Users are responsible for applying these delta weights themselves. - -In this tutorial, we will show how to apply the delta weights to LLaMA weights to get Vicuna weights. - -Install FastChat ----------------- - -FastChat offers convenient utility functions for applying the delta to LLaMA weights. You can easily install it using pip. - -.. code-block:: bash - - pip install fschat - -Download HuggingFace LLaMA Weights ----------------------------------- - -The HuggingFace LLaMA weights are hosted using Git-LFS. Therefore, it is necessary to install Git-LFS first (you can ignore this step if git-lfs is already installed). - -.. code-block:: bash - - conda install git-lfs - git lfs install - -Then download the weights (both the LLaMA weight and Vicuna delta weight): - -.. code-block:: bash - - git clone https://huggingface.co/decapoda-research/llama-7b-hf - git clone https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 - - -There is a name misalignment issue in the LLaMA weights and Vicuna delta weights. -Please follow these steps to modify the content of the "config.json" file: - -.. 
code-block:: bash - - sed -i 's/LLaMAForCausalLM/LlamaForCausalLM/g' llama-7b-hf/config.json - -Then use ``fschat`` to apply the delta to LLaMA weights - -.. code-block:: bash - - python3 -m fastchat.model.apply_delta \ - --base-model-path llama-7b-hf \ - --target-model-path vicuna-7b-v1.1 \ - --delta-path vicuna-7b-delta-v1.1 - -You will get the Vicuna weights in ``vicuna-7b-v1.1`` folder, which can be used as input of MLC-LLM to further compile models. - - -(Optional) Move Vicuna Weights to dist folder ---------------------------------------------- - -The default model path of MLC-LLM is ``dist`` folder. Therefore, it is recommended to move the Vicuna weights to ``dist`` folder. - -.. code-block:: bash - - mkdir -p dist/models - mv vicuna-7b-v1.1 dist/models/vicuna-7b-v1.1 diff --git a/docs/conf.py b/docs/conf.py index 0f7ed19014..7743ef2985 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -9,8 +9,6 @@ sys.path.insert(0, os.path.abspath("../python")) sys.path.insert(0, os.path.abspath("../")) autodoc_mock_imports = ["torch"] -# do not load mlc-llm.so in docs -os.environ["SKIP_LOADING_MLCLLM_SO"] = "1" # General information about the project. project = "mlc-llm" diff --git a/docs/deploy/javascript.rst b/docs/deploy/javascript.rst index 57f192f61a..bd92908cff 100644 --- a/docs/deploy/javascript.rst +++ b/docs/deploy/javascript.rst @@ -1,6 +1,6 @@ .. _webllm-runtime: -WebLLM and Javascript API +WebLLM and JavaScript API ========================= .. contents:: Table of Contents diff --git a/docs/deploy/mlc_chat_config.rst b/docs/deploy/mlc_chat_config.rst new file mode 100644 index 0000000000..948d50bddd --- /dev/null +++ b/docs/deploy/mlc_chat_config.rst @@ -0,0 +1,210 @@ +.. _configure-mlc-chat-json: + +Customize MLC Config File in JSON +================================= + +``mlc-chat-config.json`` is required for both compile-time and runtime, hence serving two purposes: + +1. Specify how we compile a model (shown in :ref:`compile-model-libraries`), and +2. Specify conversation behavior in runtime. + +**This page focuses on the second purpose.** We explain the components of a chat +configuration and how to customize them by modifying the file. Additionally, +the runtimes also provide APIs to optionally override some of the configurations. + +In runtime, this file is stored under the directory of each compiled model +(e.g. `RedPajama chat config `__). + + +.. _struct-mlc-chat-conv: + +Structure of MLCChat Configuration +---------------------------------- + +Below is the ``mlc-chat-config.json`` file corresponding to Llama2 model: + +.. code:: json + + // mlc-chat-config.json + { + // 1. Metadata used to specify how to compile a model + "model_type": "llama", + "quantization": "q4f16_1", + "version": "0.1.0", + "model_config": { + "hidden_size": 4096, + "intermediate_size": 11008, + // more fields here... + }, + "vocab_size": 32000, + "context_window_size": 4096, + "sliding_window_size": -1, + "prefill_chunk_size": 4096, + "tensor_parallel_shards": 1, + + // 2. Tokenizer-related fields + "pad_token_id": 0, + "bos_token_id": 1, + "eos_token_id": 2, + "tokenizer_files": [ + "tokenizer.model", + "tokenizer.json", + "tokenizer_config.json" + ] + + // 3. Conversation template related fields + "conv_template": { + "name": "llama-2", + "system_template": "[INST] <>\n{system_message}\n<>\n\n ", + "system_message": "You are a helpful, respectful and honest assistant.", + // more fields here... + }, + + // 4. 
Chat related fields that affect runtime behavior + "mean_gen_len": 128, + "max_gen_len": 512, + "shift_fill_factor": 0.3, + "temperature": 0.6, + "repetition_penalty": 1.0, + "top_p": 0.9 + } + +.. note:: + Fields in the first part of ``mlc-chat-config.json`` (e.g. ``context-window-size``) + is only for compile-time. Changing them during runtime may lead to unexpected behavior. + +**As shown above, the file is divided into three parts. We focus on the third part, which +can be customized to change the behavior of the model.** + +``conv_template`` + .. note:: + Legacy ``mlc-chat-config.json`` may specify a string for this field to look up a registered conversation + template. It will be deprecated in the future. Re-generate config using the latest version of mlc_llm + to make sure this field is a complete JSON object. + + The conversation template that this chat uses. For more information, please refer to :ref:`conversation structure `. + +``temperature`` + The temperature applied to logits before sampling. The default value is ``0.7``. A higher temperature encourages more diverse outputs, while a lower temperature produces more deterministic outputs. + +``repetition_penalty`` + The repetition penalty controls the likelihood of the model generating repeated texts. The default value is set to ``1.0``, indicating that no repetition penalty is applied. Increasing the value reduces the likelihood of repeat text generation. However, setting a high ``repetition_penalty`` may result in the model generating meaningless texts. The ideal choice of repetition penalty may vary among models. + + For more details on how repetition penalty controls text generation, please check out the `CTRL paper `_. + +``top_p`` + This parameter determines the set of tokens from which we sample during decoding. The default value is set to ``0.95``. At each step, we select tokens from the minimal set that has a cumulative probability exceeding the ``top_p`` parameter. + + For additional information on top-p sampling, please refer to this `blog post `_. + +``mean_gen_len`` + The approximated average number of generated tokens in each round. Used to determine whether the maximum window size would be exceeded. + +``max_gen_len`` + This parameter determines the maximum length of the generated text. If it is not set, the model will generate text until it encounters a stop token. + +``shift_fill_factor`` + The fraction of maximum window size to shift when it is exceeded. + +.. _struct-conv: + +Conversation Structure +^^^^^^^^^^^^^^^^^^^^^^ + +MLC-LLM provided a set of pre-defined conversation templates, which you can directly use by +specifying ``--conv-template [name]`` when generating config. Below is a list (not complete) of +supported conversation templates: + +- ``llama-2`` +- ``mistral_default`` +- ``chatml`` +- ``phi-2`` +- ... + +Please refer to `conversation_template.py `_ for the full list of supported templates and their implementations. + +Below is a generic structure of a JSON conversation configuration (we use vicuna as an example): + +.. code:: json + + // mlc-chat-config.json + { + // ... 
+ "conv_template": { + "name": "llama-2", + "system_template": "[INST] <>\n{system_message}\n<>\n\n ", + "system_message": "You are a helpful, respectful and honest assistant.", + "roles": { + "user": "[INST]", + "assistant": "[/INST]", + "tool": "[INST]" + }, + "role_templates": { + "user": "{user_message}", + "assistant": "{assistant_message}", + "tool": "{tool_message}" + }, + "messages": [], + "seps": [ + " " + ], + "role_content_sep": " ", + "role_empty_sep": " ", + "stop_str": [ + "[INST]" + ], + "stop_token_ids": [ + 2 + ], + "function_string": "", + "use_function_calling": false + } + } + +``name`` + Name of the conversation. +``system_template`` + The system prompt template, it optionally contains the system + message placeholder, and the placeholder will be replaced with + the system message below. +``system_message`` + The content of the system prompt (without the template format). +``system_prefix_token_ids`` + The system token ids to be prepended at the beginning of tokenized + generated prompt. +``roles`` + The conversation roles +``role_templates`` + The roles prompt template, it optionally contains the defaults + message placeholders and will be replaced by actual content +``messages`` + The conversation history messages. + Each message is a pair of strings, denoting "(role, content)". + The content can be None. +``seps`` + An array of strings indicating the separators to be used after a user + message and a model message respectively. +``role_content_sep`` + The separator between the role and the content in a message. +``role_empty_sep`` + The separator between the role and empty contents. +``stop_str`` + When the ``stop_str`` is encountered, the model will stop generating output. +``stop_token_ids`` + A list of token IDs that act as stop tokens. +``function_string`` + The function calling string. +``use_function_calling`` + Whether using function calling or not, helps check for output message format in API call. + + +Given a conversation template, the corresponding prompt generated out +from it is in the following format: + +.. code:: text + + <><><><><> + <><><><> + ... + <><><><> + <><> diff --git a/docs/deploy/python.rst b/docs/deploy/python_chat_module.rst similarity index 96% rename from docs/deploy/python.rst rename to docs/deploy/python_chat_module.rst index 38cdec2f85..5776e29138 100644 --- a/docs/deploy/python.rst +++ b/docs/deploy/python_chat_module.rst @@ -1,15 +1,21 @@ -.. _deploy-python: +.. _deploy-python-chat-module: -Python API -========== +Python API (Chat Module) +======================== + +.. note:: + ❗ The Python API with :class:`mlc_llm.ChatModule` introduced in this page will be + deprecated in the near future. + Please go to :ref:`deploy-python-engine` for the latest Python API with complete + OpenAI API support. .. contents:: Table of Contents :local: :depth: 2 -We expose Python API for the MLC-Chat for easy integration into other Python projects. +We expose ChatModule Python API for the MLC-LLM for easy integration into other Python projects. -The Python API is a part of the MLC-Chat package, which we have prepared pre-built pip wheels via +The Python API is a part of the MLC-LLM package, which we have prepared pre-built pip wheels via the :doc:`installation page <../install/mlc_llm>`. Instead of following this page, you could also checkout the following tutorials in @@ -340,7 +346,7 @@ We provide an example below. 
API Reference ------------- -User can initiate a chat module by creating :class:`mlc_llm.ChatModule` class, which is a wrapper of the MLC-Chat model. +User can initiate a chat module by creating :class:`mlc_llm.ChatModule` class, which is a wrapper of the MLC-LLM model. The :class:`mlc_llm.ChatModule` class provides the following methods: .. currentmodule:: mlc_llm diff --git a/docs/deploy/python_engine.rst b/docs/deploy/python_engine.rst new file mode 100644 index 0000000000..c5d9a072a7 --- /dev/null +++ b/docs/deploy/python_engine.rst @@ -0,0 +1,15 @@ +.. _deploy-python-engine: + +Python API +========== + +.. note:: + This page introduces the Python API with LLMEngine in MLC LLM. + If you want to check out the old Python API which uses :class:`mlc_llm.ChatModule`, + please go to :ref:`deploy-python-chat-module` + +.. contents:: Table of Contents + :local: + :depth: 2 + +🚧 Under construction... diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index 621a22fb71..e59abc1257 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -1,6 +1,10 @@ .. _deploy-rest-api: +<<<<<<< HEAD Rest API +======= +REST API +>>>>>>> upstream/main ======== .. contents:: Table of Contents @@ -8,11 +12,12 @@ Rest API :depth: 2 We provide `REST API `_ -for a user to interact with MLC-Chat in their own programs. +for a user to interact with MLC-LLM in their own programs. -Install MLC-Chat Package +Install MLC-LLM Package ------------------------ +<<<<<<< HEAD SERVE is a part of the MLC-Chat package, installation instruction for which we be found here :doc:`<../install/mlc_llm>`. Verify Installation @@ -23,18 +28,73 @@ Verify Installation python -m mlc_llm.serve.server --help You are expected to see the help information of the MLC SERVE. +======= +SERVE is a part of the MLC-LLM package, installation instruction for which can be found :ref:`here `. Once you have install the MLC-LLM package, you can run the following command to check if the installation was successful: -.. _mlcchat_package_build_from_source: +.. code:: bash + + mlc_llm serve --help + +You should see serve help message if the installation was successful. +>>>>>>> upstream/main +Quick start +------------ + +<<<<<<< HEAD +======= +This section provides a quick start guide to work with MLC-LLM REST API. To launch a server, run the following command: + +.. code:: bash + + mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] + +where ``MODEL`` is the model folder after compiling with :ref:`MLC-LLM build process `. Information about other arguments can be found under :ref:`Launch the server ` section. + +Once you have launched the Server, you can use the API in your own program to send requests. Below is an example of using the API to interact with MLC-LLM in Python without Streaming (suppose the server is running on ``http://127.0.0.1:8080/``): + +.. code:: bash + + import requests + + # Get a response using a prompt without streaming + payload = { + "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/", + "messages": [ + {"role": "user", "content": "Write a haiku about apples."}, + ], + "stream": False, + # "n": 1, + "max_tokens": 300, + } + r = requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload) + choices = r.json()["choices"] + for choice in choices: + print(f"{choice['message']['content']}\n") + +------------------------------------------------ + + +.. 
_rest_launch_server: + +>>>>>>> upstream/main Launch the Server ----------------- +<<<<<<< HEAD To launch the MLC Server for MLC-Chat, run the following command in your terminal. .. code:: bash python -m mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] +======= +To launch the MLC Server for MLC-LLM, run the following command in your terminal. + +.. code:: bash + + mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] +>>>>>>> upstream/main MODEL The model folder after compiling with MLC-LLM build process. The parameter can either be the model name with its quantization scheme @@ -71,7 +131,11 @@ The REST API provides the following endpoints: ------------------------------------------------ +<<<<<<< HEAD Get a list of models available for MLC-Chat. +======= + Get a list of models available for MLC-LLM. +>>>>>>> upstream/main **Example** @@ -89,12 +153,15 @@ The REST API provides the following endpoints: print(response.json()) else: print("Error:", response.status_code) +<<<<<<< HEAD .. http:post:: /v1/chat/completions +======= +>>>>>>> upstream/main ------------------------------------------------- +<<<<<<< HEAD Get a response from MLC-Chat using a prompt, either with or without streaming. **Chat Completion Request Object** @@ -197,16 +264,123 @@ The REST API provides the following endpoints: - **system_fingerprint** (*str*, required): A unique identifier for the system generating the streaming completions. - **object** (*Literal["chat.completion.chunk"]*, required, default="chat.completion.chunk"): A literal indicating that this object represents a chunk of a streaming chat completion. +======= +.. http:post:: /v1/chat/completions ------------------------------------------------ + Get a response from MLC-LLM using a prompt, either with or without streaming. + +**Chat Completion Request Object** + +- **messages** (*List[ChatCompletionMessage]*, required): A sequence of messages that have been exchanged in the conversation so far. Each message in the conversation is represented by a `ChatCompletionMessage` object, which includes the following fields: + - **content** (*Optional[Union[str, List[Dict[str, str]]]]*): The text content of the message or structured data in case of tool-generated messages. + - **role** (*Literal["system", "user", "assistant", "tool"]*): The role of the message sender, indicating whether the message is from the system, user, assistant, or a tool. + - **name** (*Optional[str]*): An optional name for the sender of the message. + - **tool_calls** (*Optional[List[ChatToolCall]]*): A list of calls to external tools or functions made within this message, applicable when the role is `tool`. + - **tool_call_id** (*Optional[str]*): A unique identifier for the tool call, relevant when integrating external tools or services. + +- **model** (*str*, required): The model to be used for generating responses. 
+ +- **frequency_penalty** (*float*, optional, default=0.0): Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model’s likelihood to repeat tokens. + +- **presence_penalty** (*float*, optional, default=0.0): Positive values penalize new tokens if they are already present in the text so far, decreasing the model’s likelihood to repeat tokens. + +- **logprobs** (*bool*, optional, default=False): Indicates whether to include log probabilities for each token in the response. + +- **top_logprobs** (*int*, optional, default=0): An integer ranging from 0 to 5. It determines the number of tokens, most likely to appear at each position, to be returned. Each token is accompanied by a log probability. If this parameter is used, 'logprobs' must be set to true. + +- **logit_bias** (*Optional[Dict[int, float]]*): Allows specifying biases for or against specific tokens during generation. + +- **max_tokens** (*Optional[int]*): The maximum number of tokens to generate in the response(s). + +- **n** (*int*, optional, default=1): Number of responses to generate for the given prompt. + +- **seed** (*Optional[int]*): A seed for deterministic generation. Using the same seed and inputs will produce the same output. + +- **stop** (*Optional[Union[str, List[str]]]*): One or more strings that, if encountered, will cause generation to stop. + +- **stream** (*bool*, optional, default=False): If `True`, responses are streamed back as they are generated. + +- **temperature** (*float*, optional, default=1.0): Controls the randomness of the generation. Lower values lead to less random completions. + +- **top_p** (*float*, optional, default=1.0): Nucleus sampling parameter that controls the diversity of the generated responses. + +- **tools** (*Optional[List[ChatTool]]*): Specifies external tools or functions that can be called as part of the chat. + +- **tool_choice** (*Optional[Union[Literal["none", "auto"], Dict]]*): Controls how tools are selected for use in responses. + +- **user** (*Optional[str]*): An optional identifier for the user initiating the request. + +- **ignore_eos** (*bool*, optional, default=False): If `True`, the model will ignore the end-of-sequence token for generating responses. + +- **response_format** (*RequestResponseFormat*, optional): Specifies the format of the response. Can be either "text" or "json_object", with optional schema definition for JSON responses. + +**Returns** + +- If `stream` is `False`, a `ChatCompletionResponse` object containing the generated response(s). +- If `stream` is `True`, a stream of `ChatCompletionStreamResponse` objects, providing a real-time feed of generated responses. + +**ChatCompletionResponseChoice** + +- **finish_reason** (*Optional[Literal["stop", "length", "tool_calls", "error"]]*, optional): The reason the completion process was terminated. It can be due to reaching a stop condition, the maximum length, output of tool calls, or an error. + +- **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices. + +- **message** (*ChatCompletionMessage*, required): The message part of the chat completion, containing the content of the chat response. + +- **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token + +**ChatCompletionStreamResponseChoice** + +- **finish_reason** (*Optional[Literal["stop", "length", "tool_calls"]]*, optional): Specifies why the streaming completion process ended. 
Valid reasons are "stop", "length", and "tool_calls". + +- **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices. + +- **delta** (*ChatCompletionMessage*, required): Represents the incremental update or addition to the chat completion message in the stream. + +- **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token + +**ChatCompletionResponse** +>>>>>>> upstream/main + +- **id** (*str*, required): A unique identifier for the chat completion session. + +- **choices** (*List[ChatCompletionResponseChoice]*, required): A collection of `ChatCompletionResponseChoice` objects, representing the potential responses generated by the model. + +- **created** (*int*, required, default=current time): The UNIX timestamp representing when the response was generated. + +- **model** (*str*, required): The name of the model used to generate the chat completions. + +- **system_fingerprint** (*str*, required): A system-generated fingerprint that uniquely identifies the computational environment. + +- **object** (*Literal["chat.completion"]*, required, default="chat.completion"): A string literal indicating the type of object, here always "chat.completion". + +- **usage** (*UsageInfo*, required, default=empty `UsageInfo` object): Contains information about the API usage for this specific request. + + +<<<<<<< HEAD **Example** +======= +- **id** (*str*, required): A unique identifier for the streaming chat completion session. + +- **choices** (*List[ChatCompletionStreamResponseChoice]*, required): A list of `ChatCompletionStreamResponseChoice` objects, each representing a part of the streaming chat response. + +- **created** (*int*, required, default=current time): The creation time of the streaming response, represented as a UNIX timestamp. + +- **model** (*str*, required): Specifies the model that was used for generating the streaming chat completions. + +- **system_fingerprint** (*str*, required): A unique identifier for the system generating the streaming completions. + +- **object** (*Literal["chat.completion.chunk"]*, required, default="chat.completion.chunk"): A literal indicating that this object represents a chunk of a streaming chat completion. +>>>>>>> upstream/main Once you have launched the Server, you can use the API in your own program. Below is an example of using the API to interact with MLC-Chat in Python without Streaming (suppose the server is running on ``http://127.0.0.1:8080/``): .. code:: bash +<<<<<<< HEAD import requests # Get a response using a prompt without streaming @@ -228,9 +402,14 @@ Once you have launched the Server, you can use the API in your own program. Belo choices = r.json()["choices"] for choice in choices: print(f"{choice['message']['content']}\n") +======= ------------------------------------------------- +**Example** +>>>>>>> upstream/main + +Below is an example of using the API to interact with MLC-LLM in Python with Streaming. +<<<<<<< HEAD Below is an example of using the API to interact with MLC-Chat in Python with Streaming. .. code:: bash @@ -257,6 +436,31 @@ Below is an example of using the API to interact with MLC-Chat in Python with St ------------------------------------------------ +======= +.. 
code:: bash + + import requests + import json + + # Get a response using a prompt with streaming + payload = { + "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/", + "messages": [{"role": "user", "content": "Write a haiku"}], + "stream": True, + } + with requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload, stream=True) as r: + for chunk in r.iter_content(chunk_size=None): + chunk = chunk.decode("utf-8") + if "[DONE]" in chunk[6:]: + break + response = json.loads(chunk[6:]) + content = response["choices"][0]["delta"].get("content", "") + print(content, end="", flush=True) + print("\n") + +------------------------------------------------ + +>>>>>>> upstream/main There is also support for function calling similar to OpenAI (https://platform.openai.com/docs/guides/function-calling). Below is an example on how to use function calling in Python. .. code:: bash diff --git a/docs/get_started/introduction.rst b/docs/get_started/introduction.rst new file mode 100644 index 0000000000..282b4764c2 --- /dev/null +++ b/docs/get_started/introduction.rst @@ -0,0 +1,319 @@ +.. _introduction-to-mlc-llm: + +Introduction to MLC LLM +======================= + +.. contents:: Table of Contents + :local: + :depth: 2 + +Machine Learning Compilation for Large Language Models (MLC LLM) is a high-performance +universal LLM deployment engine. The mission of this project is to enable everyone to develop, +optimize and deploy AI models natively on everyone's devices with ML compilation techniques. + +This page is a quick tutorial to introduce how to try out MLC LLM, and the steps to +deploy your own models with MLC LLM. + +Installation +------------ + +:ref:`MLC LLM ` is available via pip. +It is always recommended to install it in an isolated conda virtual environment. + +To verify the installation, activate your virtual environment, run + +.. code:: bash + + python -c "import mlc_llm; print(mlc_llm.__path__)" + +You are expected to see the installation path of MLC LLM Python package. + + +Chat CLI +-------- + +As the first example, we try out the chat CLI in MLC LLM with 4-bit quantized 7B Llama-2 model. +You can run MLC chat through a one-liner command: + +.. code:: bash + + mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + +It may take 1-2 minutes for the first time running this command. +After waiting, this command launch a chat interface where you can enter your prompt and chat with the model. + +.. code:: + + You can use the following special commands: + /help print the special commands + /exit quit the cli + /stats print out the latest stats (token/sec) + /reset restart a fresh chat + /set [overrides] override settings in the generation config. For example, + `/set temperature=0.5;max_gen_len=100;stop=end,stop` + Note: Separate stop words in the `stop` option with commas (,). + Multi-line input: Use escape+enter to start a new line. + + [INST]: What's the meaning of life? + [/INST]: + Ah, a question that has puzzled philosophers and theologians for centuries! ... + + +The figure below shows what run under the hood of this chat CLI command. +For the first time running the command, there are three major phases. + +- **Phase 1. Pre-quantized weight download.** This phase automatically downloads pre-quantized Llama-2 model from `Hugging Face `_ and saves it to your local cache directory. +- **Phase 2. 
Model compilation.** This phase automatically optimizes the Llama-2 model to accelerate model inference on GPU with techniques of machine learning compilation in the `Apache TVM `_ compiler, and generates the binary model library that enables the execution of language models on your local GPU. +- **Phase 3. Chat runtime.** This phase consumes the model library built in phase 2 and the model weights downloaded in phase 1, and launches a platform-native chat runtime to drive the execution of the Llama-2 model. + +We cache the pre-quantized model weights and compiled model library locally. +Therefore, phases 1 and 2 will only execute **once** over multiple runs. + +.. figure:: /_static/img/project-workflow.svg + :width: 700 + :align: center + :alt: Project Workflow + + Workflow in MLC LLM + +| + +.. _introduction-to-mlc-llm-python-api: + +Python API +---------- + +In the second example, we run the Llama-2 model with the chat completion Python API of MLC LLM. +You can save the code below into a Python file and run it. + +.. code:: python + + from mlc_llm import LLMEngine + + # Create engine + model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" + engine = LLMEngine(model) + + # Run chat completion in OpenAI API. + for response in engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, + ): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) + print("\n") + + engine.terminate() + +.. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-engine-api.jpg + :width: 500 + :align: center + + MLC LLM Python API + +This code example first creates an :class:`mlc_llm.LLMEngine` instance with the 4-bit quantized Llama-2 model. +**We design the Python API** :class:`mlc_llm.LLMEngine` **to align with the OpenAI API**, +which means you can use :class:`mlc_llm.LLMEngine` in the same way as using +`OpenAI's Python package `_ +for both synchronous and asynchronous generation. + +In this code example, we use the synchronous chat completion interface and iterate over +all the stream responses. +If you want to run without streaming, you can run + +.. code:: python + + response = engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=False, + ) + print(response) + +You can also try different arguments supported in the `OpenAI chat completion API `_. +If you would like to do concurrent asynchronous generation, you can use :class:`mlc_llm.AsyncLLMEngine` instead. + +REST Server +----------- + +For the third example, we launch a REST server to serve the 4-bit quantized Llama-2 model +for OpenAI chat completion requests. The server can be launched from the command line with + +.. code:: bash + + mlc_llm serve HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + +The server is hosted at ``http://127.0.0.1:8000`` by default, and you can use ``--host`` and ``--port`` +to set a different host and port. +When the server is ready (showing ``INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)``), +we can open a new shell and send a cURL request via the following command: + +.. code:: bash + + curl -X POST \ + -H "Content-Type: application/json" \ + -d '{ + "model": "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC", + "messages": [ + {"role": "user", "content": "Hello! Our project is MLC LLM.
What is the name of our project?"} + ] + }' \ + http://127.0.0.1:8000/v1/chat/completions + +The server will process this request and send back the response. +Similar to :ref:`introduction-to-mlc-llm-python-api`, you can pass the argument ``"stream": true`` +to request streaming responses. + + +Deploy Your Own Model +--------------------- + +So far, we have been using pre-converted model weights from Hugging Face. +This section introduces the core workflow regarding how you can *run your own models with MLC LLM*. + +We use the `Phi-2 `_ model as the example. +Assuming the Phi-2 model is downloaded and placed under ``models/phi-2``, +there are two major steps to prepare your own models. + +- **Step 1. Generate MLC config.** The first step is to generate the configuration file of MLC LLM. + + .. code:: bash + + export LOCAL_MODEL_PATH=models/phi-2 # The path where the model resides locally. + export MLC_MODEL_PATH=dist/phi-2-MLC/ # The path to place the model processed by MLC. + export QUANTIZATION=q0f16 # The choice of quantization. + export CONV_TEMPLATE=phi-2 # The choice of conversation template. + mlc_llm gen_config $LOCAL_MODEL_PATH \ + --quantization $QUANTIZATION \ + --conv-template $CONV_TEMPLATE \ + -o $MLC_MODEL_PATH + + The config generation command takes in the local model path, the target path of the MLC output, + the conversation template name in MLC, and the quantization name in MLC. + Here the quantization ``q0f16`` means float16 without quantization, + and the conversation template ``phi-2`` is the Phi-2 model's template in MLC. + + If you want to enable tensor parallelism on multiple GPUs, add the argument + ``--tensor-parallel-shards $NGPU`` to the config generation command. + + - `The full list of supported quantizations in MLC `_. You can try different quantization methods with MLC LLM. Typical quantization methods are ``q4f16_1`` for 4-bit group quantization and ``q4f16_ft`` for 4-bit FasterTransformer format quantization. + - `The full list of conversation templates in MLC `_. + +- **Step 2. Convert model weights.** In this step, we convert the model weights to the MLC format. + + .. code:: bash + + mlc_llm convert_weight $LOCAL_MODEL_PATH \ + --quantization $QUANTIZATION \ + -o $MLC_MODEL_PATH + + This step consumes the raw model weights and converts them to the MLC format. + The converted weights will be stored under ``$MLC_MODEL_PATH``, + which is the same directory where the config file generated in Step 1 resides. + +Now, you can try to run your own model with the chat CLI: + +.. code:: bash + + mlc_llm chat $MLC_MODEL_PATH + +For the first run, model compilation will be triggered automatically to optimize the +model for GPU acceleration and generate the binary model library. +The chat interface will be displayed after model JIT compilation finishes. +You can also use this model in the Python API, MLC serve, and other scenarios. + +(Optional) Compile Model Library +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In previous sections, model libraries are compiled when the :class:`mlc_llm.LLMEngine` launches, +which is what we call "JIT (Just-in-Time) model compilation". +In some cases, it is beneficial to explicitly compile the model libraries. +We can deploy LLMs with reduced dependencies by shipping the library for deployment without going through compilation. +It will also enable advanced options such as cross-compiling the libraries for web and mobile deployments. + + +Below is an example command of compiling model libraries in MLC LLM: + +..
code:: bash + + export MODEL_LIB_PATH=$MLC_MODEL_PATH/lib.so # ".dylib" for Intel Macs. + # ".dll" for Windows. + # ".wasm" for web. + # ".tar" for iPhone/Android. + mlc_llm compile $MLC_MODEL_PATH -o $MODEL_LIB_PATH + +At runtime, we need to specify this model library path to use it. For example, + +.. code:: bash + + # For chat CLI + mlc_llm chat $MLC_MODEL_PATH --model-lib-path $MODEL_LIB_PATH + # For REST server + mlc_llm serve $MLC_MODEL_PATH --model-lib-path $MODEL_LIB_PATH + +.. code:: python + + from mlc_llm import LLMEngine + + # For Python API + model = "models/phi-2" + model_lib_path = "models/phi-2/lib.so" + engine = LLMEngine(model, model_lib_path=model_lib_path) + +:ref:`compile-model-libraries` introduces the model compilation command in detail, +where you can find instructions and example commands to compile models for different +hardware backends, such as WebGPU, iOS, and Android. + +Universal Deployment +-------------------- + +MLC LLM is a high-performance universal deployment solution for large language models, +enabling native deployment of any large language model with native APIs and compiler acceleration. +So far, we have gone through several examples running on a local GPU environment. +The project supports multiple kinds of GPU backends. + +You can use the ``--device`` option in compilation and runtime to pick a specific GPU backend. +For example, if you have an NVIDIA or AMD GPU, you can try to use the option below +to run chat through the Vulkan backend. Vulkan-based LLM applications run in less typical +environments (e.g. SteamDeck). + +.. code:: bash + + mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC --device vulkan + +The same core LLM runtime engine powers all the backends, enabling the same model to be deployed across backends as +long as they fit within the memory and computing budget of the corresponding hardware backend. +We also leverage machine learning compilation to build backend-specialized optimizations to +get the best performance on the targeted backend when possible, and reuse key insights and optimizations +across the backends we support. + +Please check out the "Summary and What to Do Next" section below to find out more about different deployment scenarios, +such as WebGPU-based browser deployment, mobile, and other settings. + +Summary and What to Do Next +--------------------------- + +To briefly summarize this page, + +- We went through three examples (chat CLI, Python API, and REST server) of MLC LLM, +- We introduced how to convert model weights for your own models to run with MLC LLM, and (optionally) how to compile your models. +- We also discussed the universal deployment capability of MLC LLM. + +Next, please feel free to check out the pages below for quick start examples and more detailed information +on specific platforms: + +- :ref:`Quick start examples ` for Python API, chat CLI, REST server, web browser, iOS and Android. +- Depending on your use case, check out our API documentation and tutorial pages: + + - :ref:`webllm-runtime` + - :ref:`deploy-rest-api` + - :ref:`deploy-cli` + - :ref:`deploy-python-engine` + - :ref:`deploy-ios` + - :ref:`deploy-android` + - :ref:`deploy-ide-integration` + +- :ref:`Convert model weight to MLC format `, if you want to run your own models. +- :ref:`Compile model libraries `, if you want to deploy to web/iOS/Android or control the model optimizations. +- Report any problem or ask any question: open new issues in our `GitHub repo `_.
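The introduction above mentions :class:`mlc_llm.AsyncLLMEngine` for concurrent asynchronous generation but does not show it in action. Below is a minimal illustrative sketch, assuming the async engine mirrors the synchronous OpenAI-style `LLMEngine` interface shown above; the exact constructor and streaming-iteration details are assumptions rather than confirmed API.

```python
# Minimal sketch of concurrent asynchronous generation.
# Assumed API surface: AsyncLLMEngine mirroring the synchronous,
# OpenAI-style LLMEngine interface from the introduction above.
import asyncio

from mlc_llm import AsyncLLMEngine

model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC"


async def ask(engine: AsyncLLMEngine, prompt: str) -> None:
    # Stream one chat completion and print the deltas as they arrive.
    stream = await engine.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=model,
        stream=True,
    )
    async for response in stream:
        for choice in response.choices:
            print(choice.delta.content, end="", flush=True)
    print("\n")


async def main() -> None:
    engine = AsyncLLMEngine(model)
    # Issue two requests concurrently against the same engine instance;
    # the serving engine is expected to batch them internally.
    await asyncio.gather(
        ask(engine, "What is the meaning of life?"),
        ask(engine, "Write a haiku about apples."),
    )
    engine.terminate()


asyncio.run(main())
```

Sharing one engine across concurrent requests is the intended usage pattern here; creating a separate engine per request would duplicate weights and defeat batching.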
diff --git a/docs/get_started/project_overview.rst b/docs/get_started/project_overview.rst index 2b6ff7495a..ef631e40c8 100644 --- a/docs/get_started/project_overview.rst +++ b/docs/get_started/project_overview.rst @@ -52,7 +52,7 @@ There are several ways to prepare the model weights and model lib. A default chat config usually comes with the model weight directory. You can further customize the system prompt, temperature, and other options by modifying the JSON file. MLC chat runtimes also provide API to override these options during model reload. -Please refer to :doc:`/get_started/mlc_chat_config` for more details. +Please refer to :ref:`configure-mlc-chat-json` for more details. Runtime Flow Overview @@ -82,7 +82,7 @@ Thank you for reading and learning the high-level concepts. Moving next, feel free to check out documents on the left navigation panel and learn about topics you are interested in. -- :doc:`/get_started/mlc_chat_config` shows how to configure specific chat behavior. +- :ref:`configure-mlc-chat-json` shows how to configure specific chat behavior. - Build and Deploy App section contains guides to build apps and platform-specific MLC chat runtimes. - Compile models section provides guidelines to convert model weights and produce model libs. diff --git a/docs/get_started/quick_start.rst b/docs/get_started/quick_start.rst new file mode 100644 index 0000000000..bd3b41218e --- /dev/null +++ b/docs/get_started/quick_start.rst @@ -0,0 +1,190 @@ +.. _quick-start: + +Quick Start +=========== + +Examples +-------- + +To begin with, try out MLC LLM support for int4-quantized Llama2 7B. +It is recommended to have at least 6GB free VRAM to run it. + +.. tabs:: + + .. tab:: Python + + **Install MLC LLM**. :ref:`MLC LLM ` is available via pip. + It is always recommended to install it in an isolated conda virtual environment. + + **Run chat completion in Python.** The following Python script showcases the Python API of MLC LLM: + + .. code:: python + + from mlc_llm import LLMEngine + + # Create engine + model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" + engine = LLMEngine(model) + + # Run chat completion in OpenAI API. + for response in engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, + ): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) + print("\n") + + engine.terminate() + + .. Todo: link the colab notebook when ready: + + **Documentation and tutorial.** Python API reference and its tutorials are :ref:`available online `. + + .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-engine-api.jpg + :width: 600 + :align: center + + MLC LLM Python API + + .. tab:: REST Server + + **Install MLC LLM**. :ref:`MLC LLM ` is available via pip. + It is always recommended to install it in an isolated conda virtual environment. + + **Launch a REST server.** Run the following command from command line to launch a REST server at ``http://127.0.0.1:8000``. + + .. code:: shell + + mlc_llm serve HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + + **Send requests to server.** When the server is ready (showing ``INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)``), + open a new shell and send a request via the following command: + + .. code:: shell + + curl -X POST \ + -H "Content-Type: application/json" \ + -d '{ + "model": "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC", + "messages": [ + {"role": "user", "content": "Hello! 
Our project is MLC LLM. What is the name of our project?"} + ] + }' \ + http://127.0.0.1:8000/v1/chat/completions + + **Documentation and tutorial.** Check out :ref:`deploy-rest-api` for the REST API reference and tutorial. + Our REST API has complete OpenAI API support. + + .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-serve-request.jpg + :width: 600 + :align: center + + Send HTTP request to REST server in MLC LLM + + .. tab:: Command Line + + **Install MLC LLM**. :ref:`MLC LLM ` is available via pip. + It is always recommended to install it in an isolated conda virtual environment. + + For Windows/Linux users, make sure to have latest :ref:`Vulkan driver ` installed. + + **Run in command line**. + + .. code:: bash + + mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + + + If you are using windows/linux/steamdeck and would like to use vulkan, + we recommend installing necessary vulkan loader dependency via conda + to avoid vulkan not found issues. + + .. code:: bash + + conda install -c conda-forge gcc libvulkan-loader + + + .. tab:: Web Browser + + `WebLLM `__. MLC LLM generates performant code for WebGPU and WebAssembly, + so that LLMs can be run locally in a web browser without server resources. + + **Download pre-quantized weights**. This step is self-contained in WebLLM. + + **Download pre-compiled model library**. WebLLM automatically downloads WebGPU code to execute. + + **Check browser compatibility**. The latest Google Chrome provides WebGPU runtime and `WebGPU Report `__ as a useful tool to verify WebGPU capabilities of your browser. + + .. figure:: https://blog.mlc.ai/img/redpajama/web.gif + :width: 300 + :align: center + + MLC LLM on Web + + .. tab:: iOS + + **Install MLC Chat iOS**. It is available on AppStore: + + .. image:: https://developer.apple.com/assets/elements/badges/download-on-the-app-store.svg + :width: 135 + :target: https://apps.apple.com/us/app/mlc-chat/id6448482937 + + | + + **Requirement**. Llama2-7B model needs an iOS device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. + + **Tutorial and source code**. The source code of the iOS app is fully `open source `__, + and a :ref:`tutorial ` is included in documentation. + + .. figure:: https://blog.mlc.ai/img/redpajama/ios.gif + :width: 300 + :align: center + + MLC Chat on iOS + + .. tab:: Android + + **Install MLC Chat Android**. A prebuilt is available as an APK: + + .. image:: https://seeklogo.com/images/D/download-android-apk-badge-logo-D074C6882B-seeklogo.com.png + :width: 135 + :target: https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android/mlc-chat.apk + + | + + **Requirement**. Llama2-7B model needs a device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. + The demo is tested on + + - Samsung S23 with Snapdragon 8 Gen 2 chip + - Redmi Note 12 Pro with Snapdragon 685 + - Google Pixel phones + + **Tutorial and source code**. The source code of the android app is fully `open source `__, + and a :ref:`tutorial ` is included in documentation. + + .. figure:: https://blog.mlc.ai/img/android/android-recording.gif + :width: 300 + :align: center + + MLC LLM on Android + + +What to Do Next +--------------- + +- Check out :ref:`introduction-to-mlc-llm` for the introduction of a complete workflow in MLC LLM. 
+- Depending on your use case, check out our API documentation and tutorial pages: + + - :ref:`webllm-runtime` + - :ref:`deploy-rest-api` + - :ref:`deploy-cli` + - :ref:`deploy-python-engine` + - :ref:`deploy-ios` + - :ref:`deploy-android` + - :ref:`deploy-ide-integration` + +- `Convert model weight to MLC format `_, if you want to run your own models. +- `Compile model libraries `_, if you want to deploy to web/iOS/Android or control the model optimizations. +- Report any problem or ask any question: open new issues in our `GitHub repo `_. diff --git a/docs/index.rst b/docs/index.rst index 485567b37e..e9835e152d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -5,138 +5,15 @@ Machine Learning Compilation for Large Language Models (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques. -.. _get_started: +Quick Start +----------- -Getting Started ---------------- +Check out :ref:`quick-start` for quick start examples of using MLC LLM. -To begin with, try out MLC LLM support for int4-quantized Llama2 7B. -It is recommended to have at least 6GB free VRAM to run it. +Introduction to MLC LLM +----------------------- -.. tabs:: - - .. tab:: Python - - **Install MLC LLM Python**. :doc:`MLC LLM ` is available via pip. - It is always recommended to install it in an isolated conda virtual environment. - - **Download pre-quantized weights**. The commands below download the int4-quantized Llama2-7B from HuggingFace: - - .. code:: bash - - git lfs install && mkdir dist/ - git clone https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC \ - dist/Llama-2-7b-chat-hf-q4f16_1-MLC - - **Download pre-compiled model library**. The pre-compiled model library is available as below: - - .. code:: bash - - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs - - **Run in Python.** The following Python script showcases the Python API of MLC LLM and its stream capability: - - .. code:: python - - from mlc_llm import ChatModule - from mlc_llm.callback import StreamToStdout - - cm = ChatModule( - model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so - # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so - # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} - ) - cm.generate(prompt="What is the meaning of life?", progress_callback=StreamToStdout(callback_interval=2)) - - **Colab walkthrough.** A Jupyter notebook on `Colab `_ - is provided with detailed walkthrough of the Python API. - - **Documentation and tutorial.** Python API reference and its tutorials are `available online `_. - - .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-api.jpg - :width: 600 - :align: center - - MLC LLM Python API - - .. tab:: Command Line - - **Install MLC LLM**. :doc:`MLC LLM ` is available via pip. - It is always recommended to install it in an isolated conda virtual environment. - - For Windows/Linux users, make sure to have latest :ref:`Vulkan driver ` installed. - - **Run in command line**. - - .. code:: bash - - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC - - .. tab:: Web Browser - - `WebLLM `__. 
MLC LLM generates performant code for WebGPU and WebAssembly, - so that LLMs can be run locally in a web browser without server resources. - - **Download pre-quantized weights**. This step is self-contained in WebLLM. - - **Download pre-compiled model library**. WebLLM automatically downloads WebGPU code to execute. - - **Check browser compatibility**. The latest Google Chrome provides WebGPU runtime and `WebGPU Report `__ as a useful tool to verify WebGPU capabilities of your browser. - - .. figure:: https://blog.mlc.ai/img/redpajama/web.gif - :width: 300 - :align: center - - MLC LLM on Web - - .. tab:: iOS - - **Install MLC Chat iOS**. It is available on AppStore: - - .. image:: https://developer.apple.com/assets/elements/badges/download-on-the-app-store.svg - :width: 135 - :target: https://apps.apple.com/us/app/mlc-chat/id6448482937 - - | - - **Requirement**. Llama2-7B model needs an iOS device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. - - **Tutorial and source code**. The source code of the iOS app is fully `open source `__, - and a :doc:`tutorial ` is included in documentation. - - .. figure:: https://blog.mlc.ai/img/redpajama/ios.gif - :width: 300 - :align: center - - MLC Chat on iOS - - .. tab:: Android - - **Install MLC Chat Android**. A prebuilt is available as an APK: - - .. image:: https://seeklogo.com/images/D/download-android-apk-badge-logo-D074C6882B-seeklogo.com.png - :width: 135 - :target: https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android/mlc-chat.apk - - | - - **Requirement**. Llama2-7B model needs a device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. - The demo is tested on - - - Samsung S23 with Snapdragon 8 Gen 2 chip - - Redmi Note 12 Pro with Snapdragon 685 - - Google Pixel phones - - **Tutorial and source code**. The source code of the android app is fully `open source `__, - and a :doc:`tutorial ` is included in documentation. - - .. figure:: https://blog.mlc.ai/img/android/android-recording.gif - :width: 300 - :align: center - - MLC LLM on Android +Check out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a complete workflow in MLC LLM. .. toctree:: @@ -144,8 +21,8 @@ It is recommended to have at least 6GB free VRAM to run it. :caption: Get Started :hidden: - get_started/project_overview.rst - get_started/mlc_chat_config.rst + get_started/quick_start.rst + get_started/introduction.rst .. toctree:: :maxdepth: 1 @@ -155,10 +32,11 @@ It is recommended to have at least 6GB free VRAM to run it. deploy/javascript.rst deploy/rest.rst deploy/cli.rst - deploy/python.rst + deploy/python_engine.rst deploy/ios.rst deploy/android.rst deploy/ide_integration.rst + deploy/mlc_chat_config.rst .. toctree:: :maxdepth: 1 @@ -176,7 +54,6 @@ It is recommended to have at least 6GB free VRAM to run it. :hidden: prebuilt_models.rst - prebuilt_models_deprecated.rst .. toctree:: :maxdepth: 1 diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst index 3003abdc72..c6602559ae 100644 --- a/docs/install/mlc_llm.rst +++ b/docs/install/mlc_llm.rst @@ -61,10 +61,17 @@ Select your operating system/compute platform and run the command in your termin .. tab:: Vulkan - Supported in all Linux packages. + Supported in all Linux packages. Checkout the following instructions + to install the latest vulkan loader to avoid vulkan not found issue. .. note:: + + .. 
code-block:: bash + + conda install -c conda-forge gcc libvulkan-loader + + If encountering issues with GLIBC not found, please install the latest glibc in conda: .. code-block:: bash diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst index 7fbd3d08ad..849152cce6 100644 --- a/docs/install/tvm.rst +++ b/docs/install/tvm.rst @@ -160,7 +160,8 @@ While it is generally recommended to always use the prebuilt TVM Unity, if you r conda create -n tvm-build-venv -c conda-forge \ "llvmdev>=15" \ "cmake>=3.24" \ - git + git \ + python=3.11 # enter the build environment conda activate tvm-build-venv diff --git a/docs/prebuilt_models.rst b/docs/prebuilt_models.rst index e299f68138..f97909a515 100644 --- a/docs/prebuilt_models.rst +++ b/docs/prebuilt_models.rst @@ -44,7 +44,7 @@ We quickly go over how to use prebuilt models for each platform. You can find de **Prebuilt Models on CLI / Python** -For more, please see :doc:`the CLI page `, and the :doc:`the Python page `. +For more, please see :ref:`the CLI page `, and the :ref:`the Python page `. .. collapse:: Click to show details @@ -71,7 +71,7 @@ For more, please see :doc:`the CLI page `, and the :doc:`the Python mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC - To run the model with Python API, see :doc:`the Python page ` (all other downloading steps are the same as CLI). + To run the model with Python API, see :ref:`the Python page ` (all other downloading steps are the same as CLI). .. for a blank line diff --git a/examples/python/sample_mlc_engine.py b/examples/python/sample_mlc_engine.py new file mode 100644 index 0000000000..e26e17f1e2 --- /dev/null +++ b/examples/python/sample_mlc_engine.py @@ -0,0 +1,17 @@ +from mlc_llm import LLMEngine + +# Create engine +model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" +engine = LLMEngine(model) + +# Run chat completion in OpenAI API. +for response in engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, +): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) +print("\n") + +engine.terminate() diff --git a/python/mlc_llm/__init__.py b/python/mlc_llm/__init__.py index f68567e772..1654010664 100644 --- a/python/mlc_llm/__init__.py +++ b/python/mlc_llm/__init__.py @@ -2,6 +2,9 @@ MLC Chat is the app runtime of MLC LLM. """ + # from . 
import protocol, serve # from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig from .libinfo import __version__ + +# from .serve import AsyncLLMEngine, LLMEngine diff --git a/python/mlc_llm/base.py b/python/mlc_llm/base.py index 13c7ba9f84..308426d210 100644 --- a/python/mlc_llm/base.py +++ b/python/mlc_llm/base.py @@ -1,4 +1,5 @@ """Load MLC LLM library and _ffi_api functions.""" + import ctypes import os import sys @@ -23,6 +24,24 @@ def _load_mlc_llm_lib(): return ctypes.CDLL(lib_path[0]), lib_path[0] +@tvm.register_func("mlc.debug_cuda_profiler_start") +def _debug_cuda_profiler_start() -> None: + """Start cuda profiler.""" + import cuda # pylint: disable=import-outside-toplevel + import cuda.cudart # pylint: disable=import-outside-toplevel + + cuda.cudart.cudaProfilerStart() # pylint: disable=c-extension-no-member + + +@tvm.register_func("mlc.debug_cuda_profiler_stop") +def _debug_cuda_profiler_stop() -> None: + """Stop cuda profiler.""" + import cuda # pylint: disable=import-outside-toplevel + import cuda.cudart # pylint: disable=import-outside-toplevel + + cuda.cudart.cudaProfilerStop() # pylint: disable=c-extension-no-member + + # only load once here if SKIP_LOADING_MLCLLM_SO == "0": _LIB, _LIB_PATH = _load_mlc_llm_lib() diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index 4ad2319390..9f7c1c3580 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -4,6 +4,7 @@ from mlc_llm.help import HELP from mlc_llm.interface.serve import serve +from mlc_llm.serve.config import SpeculativeMode from mlc_llm.support.argparse import ArgumentParser @@ -29,15 +30,33 @@ def main(argv): help=HELP["model_lib_path"] + ' (default: "%(default)s")', ) parser.add_argument( - "--max-batch-size", - type=int, - default=80, - help=HELP["max_batch_size"] + ' (default: "%(default)s")', + "--mode", + type=str, + choices=["local", "interactive", "server"], + default="local", + help=HELP["mode_serve"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--additional-models", type=str, nargs="*", help=HELP["additional_models_serve"] ) + parser.add_argument("--max-batch-size", type=int, help=HELP["max_batch_size"]) parser.add_argument( "--max-total-seq-length", type=int, help=HELP["max_total_sequence_length_serve"] ) parser.add_argument("--prefill-chunk-size", type=int, help=HELP["prefill_chunk_size_serve"]) + parser.add_argument( + "--gpu-memory-utilization", type=float, help=HELP["gpu_memory_utilization_serve"] + ) + parser.add_argument( + "--speculative-mode", + type=str, + choices=["DISABLE", "SMALL_DRAFT", "EAGLE"], + default="DISABLE", + help=HELP["speculative_mode_serve"], + ) + parser.add_argument( + "--spec-draft-length", type=int, default=4, help=HELP["spec_draft_length_serve"] + ) parser.add_argument("--enable-tracing", action="store_true", help=HELP["enable_tracing_serve"]) parser.add_argument( "--host", @@ -76,9 +95,14 @@ def main(argv): model=parsed.model, device=parsed.device, model_lib_path=parsed.model_lib_path, + mode=parsed.mode, + additional_models=parsed.additional_models, max_batch_size=parsed.max_batch_size, max_total_sequence_length=parsed.max_total_seq_length, prefill_chunk_size=parsed.prefill_chunk_size, + gpu_memory_utilization=parsed.gpu_memory_utilization, + speculative_mode=SpeculativeMode[parsed.speculative_mode], + spec_draft_length=parsed.spec_draft_length, enable_tracing=parsed.enable_tracing, host=parsed.host, port=parsed.port, diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py 
b/python/mlc_llm/compiler_pass/attach_sampler.py index 2d28730a9b..1b7b0328a9 100644 --- a/python/mlc_llm/compiler_pass/attach_sampler.py +++ b/python/mlc_llm/compiler_pass/attach_sampler.py @@ -20,6 +20,7 @@ def __init__(self, target: tvm.target.Target, variable_bounds: Dict[str, int]): "num_samples": max_batch_size, "num_positions": 6 * max_batch_size, } + self.non_negative_var = ["vocab_size"] self.target = target def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: @@ -29,7 +30,15 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR return mod bb = relax.BlockBuilder(mod) - vocab_size = mod["prefill"].ret_struct_info.fields[0].shape[-1] + # Prefill method exists in base models. + # Prefill_to_last_hidden method exists in base model and speculative small models + if "prefill" in mod: + vocab_size = mod["prefill"].ret_struct_info.fields[0].shape[-1] + else: + assert ( + "prefill_to_last_hidden_states" in mod + ), "Everay model should either has 'prefill' or 'prefill_to_last_hidden_states' method" + vocab_size = mod["prefill_to_last_hidden_states"].ret_struct_info.fields[0].shape[-1] gv_names = [ gv.name_hint for gv in [ @@ -42,7 +51,11 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR mod = bb.finalize() for gv_name in gv_names: - mod[gv_name] = mod[gv_name].with_attr("tir_var_upper_bound", self.variable_bounds) + mod[gv_name] = ( + mod[gv_name] + .with_attr("tir_var_upper_bound", self.variable_bounds) + .with_attr("tir_non_negative_var", self.non_negative_var) + ) return mod diff --git a/python/mlc_llm/compiler_pass/attach_support_info.py b/python/mlc_llm/compiler_pass/attach_support_info.py index dbeb621fdc..f4a332f115 100644 --- a/python/mlc_llm/compiler_pass/attach_support_info.py +++ b/python/mlc_llm/compiler_pass/attach_support_info.py @@ -13,12 +13,15 @@ class AttachVariableBounds: # pylint: disable=too-few-public-methods def __init__(self, variable_bounds: Dict[str, int]): # Specifically for RWKV workloads, which contains -1 max_seq_len self.variable_bounds = {k: v for k, v in variable_bounds.items() if v > 0} + self.non_negative_var = ["vocab_size"] def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: """Entrypoint""" for g_var, func in mod.functions_items(): if isinstance(func, relax.Function): - mod[g_var] = func.with_attr("tir_var_upper_bound", self.variable_bounds) + mod[g_var] = func.with_attr("tir_var_upper_bound", self.variable_bounds).with_attr( + "tir_non_negative_var", self.non_negative_var + ) return mod diff --git a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py index 20e4c7bdd9..d9d478cd1f 100644 --- a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py +++ b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py @@ -155,7 +155,7 @@ def create_flashinfer_paged_kv_cache( in self.metadata["model_type"] ) # filter by attention group size - or kwargs["num_attention_heads"] // kwargs["num_key_value_heads"] not in [1, 4, 8] + or kwargs["num_attention_heads"] // kwargs["num_key_value_heads"] not in [1, 4, 6, 8] ): return diff --git a/python/mlc_llm/compiler_pass/estimate_memory_usage.py b/python/mlc_llm/compiler_pass/estimate_memory_usage.py index bd2fb03d38..d69d99109d 100644 --- a/python/mlc_llm/compiler_pass/estimate_memory_usage.py +++ b/python/mlc_llm/compiler_pass/estimate_memory_usage.py @@ -23,14 +23,16 @@ def transform_module(self, mod: IRModule, 
_ctx: tvm.transform.PassContext) -> IR """Entrypoint""" mod = mod.clone() + func_name = "_metadata" + def _emit_metadata(metadata): bb = relax.BlockBuilder() # pylint: disable=invalid-name - with bb.function("_metadata", params=[]): + with bb.function(func_name, params=[]): bb.emit_func_output(relax.StringImm(json.dumps(metadata))) - return bb.finalize()["_metadata"] + return bb.finalize()[func_name] self.metadata["memory_usage"] = _MemoryEstimator().run(mod) - mod["_metadata"] = _emit_metadata(self.metadata) + mod[func_name] = _emit_metadata(self.metadata) return mod diff --git a/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py b/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py index f8a64c8cda..0943828933 100644 --- a/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py +++ b/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py @@ -15,7 +15,7 @@ def transform_module( ) -> IRModule: """IRModule-level transformation""" seq = [] - for n_aux_tensor in [1, 2, 3, 4]: + for n_aux_tensor in [0, 1, 2, 3, 4]: for match_ewise in [0, 1, 2, 6]: if match_ewise == 6 and n_aux_tensor != 4: continue diff --git a/python/mlc_llm/compiler_pass/lift_global_buffer_alloc.py b/python/mlc_llm/compiler_pass/lift_global_buffer_alloc.py index 444c2cf3ef..1bda34b387 100644 --- a/python/mlc_llm/compiler_pass/lift_global_buffer_alloc.py +++ b/python/mlc_llm/compiler_pass/lift_global_buffer_alloc.py @@ -1,96 +1,103 @@ """A compiler pass that lifts TIR-level global allocation to Relax.""" -from typing import Dict, List, Tuple, Optional + +from typing import Dict, List, Tuple import tvm from tvm import relax, tir from tvm.ir.module import IRModule +from tvm.relax.analysis import remove_all_unused from tvm.relax.expr_functor import PyExprMutator, mutator -def LiftTIRGlobalBufferAlloc(): - @mutator - class TIRGlobalAllocRewriter(PyExprMutator): - def __init__(self, mod: IRModule): - super().__init__(mod) - self.mod = mod - - def transform(self) -> IRModule: - self.mod = self.builder_.get() - for gv, func in self.mod.functions.items(): - if isinstance(func, relax.Function): - updated_func = self.visit_expr(func) - self.builder_.update_func(gv, updated_func) - return self.builder_.get() - - def visit_call_(self, call: relax.Call): - call = self.visit_expr_post_order(call) - if call.op != tvm.ir.Op.get("relax.call_tir"): +@tvm.transform.module_pass(opt_level=0, name="LiftTIRGlobalBufferAlloc") +class LiftTIRGlobalBufferAlloc: # pylint: disable=too-few-public-methods + """A compiler pass that lifts TIR-level global allocation to Relax.""" + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + return _TIRGlobalAllocRewriter(mod).transform() + + +@mutator +class _TIRGlobalAllocRewriter(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod: IRModule): + super().__init__(mod) + self.mod = mod + self.gv2new_tensor_sinfo: Dict[ + tvm.ir.GlobalVar, Tuple[tvm.ir.GlobalVar, List[relax.TensorStructInfo], tir.PrimFunc] + ] = {} + + def transform(self) -> IRModule: + """Entry point of the transformation""" + for g_var, func in self.mod.functions_items(): + # TODO(@eric): This is a temporary hack to get around with two functions for BYOC. 
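Both attribute-attaching hunks above (the sampler pass and `AttachVariableBounds`) follow the same pattern: collect the hint values in the pass constructor, then chain `with_attr` calls on every Relax function. A condensed sketch of that pattern, using only the APIs that appear in the hunks (the pass name and the bound values below are illustrative, not taken from the patch):

```
import tvm
from tvm import relax
from tvm.ir.module import IRModule


@tvm.transform.module_pass(opt_level=0, name="AttachShapeHintsSketch")
class AttachShapeHintsSketch:  # illustrative sketch, not part of this patch
    """Attach TIR shape-variable hints to every Relax function in a module."""

    def __init__(self):
        self.variable_bounds = {"seq_len": 4096}  # illustrative upper bounds
        self.non_negative_var = ["vocab_size"]    # as in the hunks above

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """Entrypoint"""
        for g_var, func in mod.functions_items():
            if isinstance(func, relax.Function):
                mod[g_var] = func.with_attr(
                    "tir_var_upper_bound", self.variable_bounds
                ).with_attr("tir_non_negative_var", self.non_negative_var)
        return mod
```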
+ if isinstance(func, tir.PrimFunc) and g_var.name_hint not in [ + "is_bfloat16_dtype", + "is_float32_dtype", + ]: + updated_func, tensor_sinfo_list = remove_global_buf_alloc(func) + if len(tensor_sinfo_list) > 0: + new_gv = self.builder_.add_func(updated_func, g_var.name_hint) + self.gv2new_tensor_sinfo[g_var] = (new_gv, tensor_sinfo_list, func) + + self.mod = self.builder_.get() + for g_var, func in self.mod.functions_items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + updated_func = remove_all_unused(updated_func) + self.builder_.update_func(g_var, updated_func) + + mod = self.builder_.get() + return relax.transform.DeadCodeElimination()(mod) + + def visit_call_(self, call: relax.Call): # pylint: disable=arguments-renamed + call = self.visit_expr_post_order(call) + if ( + call.op != tvm.ir.Op.get("relax.call_tir") + or call.args[0] not in self.gv2new_tensor_sinfo + ): + return call + + g_var = call.args[0] + new_gv, tensor_sinfo, func_before_update = self.gv2new_tensor_sinfo[g_var] + + assert len(call.sinfo_args) == 1 + if any(_has_symbolic_var(sinfo) for sinfo in tensor_sinfo): + tensor_sinfo, success = _resolve_tir_var_mapping(func_before_update, call, tensor_sinfo) + if not success: + # Cannot resolve TIR var mapping. Fall back to no lifting. + self.gv2new_tensor_sinfo.pop(g_var) return call - old_gvar = call.args[0] - - func_before_update = self.mod.functions[old_gvar] - updates = remove_global_buf_alloc(func_before_update) - if updates is None: - return call - updated_func, tensor_sinfo = updates - - assert len(call.sinfo_args) == 1 - if any(_has_symbolic_var(sinfo) for sinfo in tensor_sinfo): - tensor_sinfo, success = _resolve_tir_var_mapping( - func_before_update, call, tensor_sinfo - ) - if not success: - # Cannot resolve TIR var mapping. Fall back to no lifting. 
- return call - - new_gvar = self.builder_.add_func(updated_func, old_gvar.name_hint) - new_args = [new_gvar, *call.args[1:]] - - if isinstance(call.sinfo_args[0], relax.TensorStructInfo): - new_call = relax.Call( - call.op, - args=new_args, - sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args) + tensor_sinfo)], - attrs=call.attrs, - ) - emitted_tuple = self.builder_.emit(new_call) - return relax.TupleGetItem(emitted_tuple, 0) - elif isinstance(call.sinfo_args[0], relax.TupleStructInfo): - return relax.Call( - call.op, - args=new_args, - sinfo_args=[ - relax.TupleStructInfo(list(call.sinfo_args[0].fields) + tensor_sinfo) - ], - attrs=call.attrs, - ) - else: - raise TypeError( - f"Expected {call.op} to return either R.Tensor or R.Tuple, " - f"but instead returned {call.sinfo_args[0]}" - ) - - @tvm.transform.module_pass(opt_level=0, name="LiftTIRGlobalBufferAlloc.Inner") - def transform_module(mod: IRModule, _: tvm.transform.PassContext) -> IRModule: - return TIRGlobalAllocRewriter(mod).transform() - - return tvm.ir.transform.Sequential( - [ - transform_module, - tvm.relax.transform.DeadCodeElimination(), - ], - name="LiftTIRGlobalBufferAlloc", - ) + args = list(call.args) + args[0] = new_gv + if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + new_call = relax.Call( + call.op, + args=args, + sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args) + tensor_sinfo)], + attrs=call.attrs, + ) + emitted_tuple = self.builder_.emit(new_call) + return relax.TupleGetItem(emitted_tuple, 0) + assert isinstance(call.sinfo_args[0], relax.TupleStructInfo) + return relax.Call( + call.op, + args=args, + sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args[0].fields) + tensor_sinfo)], + attrs=call.attrs, + ) def remove_global_buf_alloc( func: tir.PrimFunc, -) -> Optional[Tuple[tir.PrimFunc, List[relax.TensorStructInfo]]]: +) -> Tuple[tir.PrimFunc, List[relax.TensorStructInfo]]: """Remove the global buffer allocation for a given TIR PrimFunc.""" - if not isinstance(func.body, tir.BlockRealize): - return None - + assert isinstance(func.body, tir.BlockRealize) params = list(func.params) buffer_map = dict(func.buffer_map) tensor_sinfo = [] @@ -113,7 +120,7 @@ def remove_global_buf_alloc( alloc_buffers.append(buf_alloc) if len(tensor_sinfo) == 0: - return None + return func, [] assert len(prev_root_block.iter_vars) == 0 assert len(prev_root_block.reads) == 0 diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index 5976517c53..1b2a06feab 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -335,7 +335,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: roles={"user": "", "assistant": ""}, seps=["\n"], role_content_sep=": ", - role_empty_sep=": ", + role_empty_sep=":", stop_str=[""], stop_token_ids=[0], ) @@ -475,3 +475,23 @@ def get_conv_template(name: str) -> Optional[Conversation]: system_prefix_token_ids=[1], ) ) + +# GLM +ConvTemplateRegistry.register_conv_template( + Conversation( + name="glm", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={ + "user": "问", + "assistant": "答", + "tool": "问", + }, + seps=["\n\n"], + role_content_sep=": ", + role_empty_sep=":", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[64790, 64792], + ) +) diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index 13335c99c1..b4321ebdec 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -159,4 +159,54 @@ to get the Chrome Trace. 
For example, "curl -X POST http://127.0.0.1:8000/debug/dump_event_trace -H "Content-Type: application/json" -d '{"model": "dist/llama"}'" """.strip(), + "mode_serve": """ +The engine mode in MLC LLM. We provide three preset modes: "local", "interactive" and "server". +The default mode is "local". +The choice of mode decides the values of "--max-batch-size", "--max-total-seq-length" and +"--prefill-chunk-size" when they are not explicitly specified. +1. Mode "local" refers to the local server deployment which has low request concurrency. + So the max batch size will be set to 4, and max total sequence length and prefill chunk size + are set to the context window size (or sliding window size) of the model. +2. Mode "interactive" refers to the interactive use of server, which has at most 1 concurrent + request. So the max batch size will be set to 1, and max total sequence length and prefill + chunk size are set to the context window size (or sliding window size) of the model. +3. Mode "server" refers to the large server use case which may handle many concurrent request + and want to use GPU memory as much as possible. In this mode, we will automatically infer + the largest possible max batch size and max total sequence length. +You can manually specify arguments "--max-batch-size", "--max-total-seq-length" and +"--prefill-chunk-size" to override the automatic inferred values. +""".strip(), + "additional_models_serve": """ +The model paths and (optional) model library paths of additional models (other than the main model). +When engine is enabled with speculative decoding, additional models are needed. +The way of specifying additional models is: +"--additional-models model_path_1 model_path_2 ..." or +"--additional-models model_path_1:model_lib_path_1 model_path_2 ...". +When the model lib path of a model is not given, JIT model compilation will be activated +to compile the model automatically. +""", + "gpu_memory_utilization_serve": """ +A number in (0, 1) denoting the fraction of GPU memory used by the server in total. +It is used to infer to maximum possible KV cache capacity. +When it is unspecified, it defaults to 0.90. +Under mode "local" or "interactive", the actual memory usage may be significantly smaller than +this number. Under mode "server", the actual memory usage may be slightly larger than this number. +""", + "speculative_mode_serve": """ +The speculative decoding mode. Right now three options are supported: + - DISABLE, where speculative decoding is not enabled, + - SMALL_DRAFT, denoting the normal speculative decoding (small draft) style, + - EAGLE, denoting the eagle-style speculative decoding. +The default mode is "DISABLE". +""", + "spec_draft_length_serve": """ +The number of draft tokens to generate in speculative proposal. The default values is 4. +""", + "engine_config_serve": """ +The LLMEngine execution configuration. +Currently speculative decoding mode is specified via engine config. +For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" to +specify the eagle-style speculative decoding. +Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. 
+""", } diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py index f3a6092f6d..2d0d668672 100644 --- a/python/mlc_llm/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -103,7 +103,13 @@ def _flashinfer(target) -> bool: def _cublas_gemm(target, quantization) -> bool: """correct cublas_gemm flag""" - if not (target.kind.name == "cuda" and quantization.name in ["q0f16", "q0f32"]): + if not target.kind.name == "cuda": + return False + if not ( + quantization.name in ["q0f16", "q0f32"] + or "e4m3" in quantization.name + or "e5m2" in quantization.name + ): return False return self.cublas_gemm diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index e0d401920a..d22aa7d231 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -289,6 +289,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b "rwkv_world", "rwkv", "gorilla", + "gorilla-openfunctions-v2", "guanaco", "dolly", "oasst", diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index c9b9b161b5..c5696ef473 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -1,12 +1,14 @@ """Python entrypoint of serve.""" -from typing import Any, Optional +from typing import Any, List, Literal, Optional import fastapi import uvicorn from fastapi.middleware.cors import CORSMiddleware -from mlc_llm.serve import async_engine, config +from mlc_llm.protocol import error_protocol +from mlc_llm.serve import engine +from mlc_llm.serve.config import SpeculativeMode from mlc_llm.serve.entrypoints import debug_entrypoints, openai_entrypoints from mlc_llm.serve.server import ServerContext @@ -15,9 +17,14 @@ def serve( model: str, device: str, model_lib_path: Optional[str], - max_batch_size: int, + mode: Literal["local", "interactive", "server"], + additional_models: List[str], + max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], + gpu_memory_utilization: Optional[float], + speculative_mode: SpeculativeMode, + spec_draft_length: int, enable_tracing: bool, host: str, port: int, @@ -27,24 +34,24 @@ def serve( allow_headers: Any, ): # pylint: disable=too-many-arguments, too-many-locals """Serve the model with the specified configuration.""" - # Initialize model loading info and KV cache config - model_info = async_engine.ModelInfo( + # Create engine and start the background loop + async_engine = engine.AsyncLLMEngine( model=model, - model_lib_path=model_lib_path, device=device, - ) - kv_cache_config = config.KVCacheConfig( - max_num_sequence=max_batch_size, + model_lib_path=model_lib_path, + mode=mode, + additional_models=additional_models, + max_batch_size=max_batch_size, max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, - ) - # Create engine and start the background loop - engine = async_engine.AsyncThreadedEngine( - model_info, kv_cache_config, enable_tracing=enable_tracing + gpu_memory_utilization=gpu_memory_utilization, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + enable_tracing=enable_tracing, ) with ServerContext() as server_context: - server_context.add_model(model, engine) + server_context.add_model(model, async_engine) app = fastapi.FastAPI() app.add_middleware( @@ -57,4 +64,7 @@ def serve( app.include_router(openai_entrypoints.app) app.include_router(debug_entrypoints.app) + 
app.exception_handler(error_protocol.BadRequestError)( + error_protocol.bad_request_error_handler + ) uvicorn.run(app, host=host, port=port, log_level="info") diff --git a/python/mlc_llm/loader/huggingface_loader.py b/python/mlc_llm/loader/huggingface_loader.py index 0a4bb7649c..8cdec59523 100644 --- a/python/mlc_llm/loader/huggingface_loader.py +++ b/python/mlc_llm/loader/huggingface_loader.py @@ -1,4 +1,5 @@ """A weight loader for HuggingFace's PyTorch format""" + import gc import json from collections import OrderedDict, defaultdict diff --git a/python/mlc_llm/model/chatglm3/__init__.py b/python/mlc_llm/model/chatglm3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/chatglm3/chatglm3_loader.py b/python/mlc_llm/model/chatglm3/chatglm3_loader.py new file mode 100644 index 0000000000..677514f491 --- /dev/null +++ b/python/mlc_llm/model/chatglm3/chatglm3_loader.py @@ -0,0 +1,63 @@ +""" +This file specifies how MLC's ChatGLM3 parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .chatglm3_model import ChatGLMForCausalLM, GLMConfig + + +def huggingface(model_config: GLMConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : GLMConfig + The configuration of the Baichuan model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = ChatGLMForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + mlc_name = "transformer.embedding.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + ["transformer.embedding.word_embeddings.weight"], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_llm/model/chatglm3/chatglm3_model.py b/python/mlc_llm/model/chatglm3/chatglm3_model.py new file mode 100644 index 0000000000..f7e81019e0 --- /dev/null +++ b/python/mlc_llm/model/chatglm3/chatglm3_model.py @@ -0,0 +1,438 @@ +""" +Implementation for CHATGLM3 architecture. 
+TODO: add docstring +""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class GLMConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the ChatGLM model.""" + + hidden_size: int + num_layers: int + kv_channels: int + num_attention_heads: int + ffn_hidden_size: int + layernorm_epsilon: float + post_layer_norm: bool + rmsnorm: bool + add_bias_linear: bool + add_qkv_bias: bool + apply_query_key_layer_scaling: bool + multi_query_attention: bool + multi_query_group_num: int + vocab_size: int = 0 + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + head_dim: int = 0 + max_batch_size: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.vocab_size == 0: + for name in ["padded_vocab_size"]: + if name in self.kwargs: + self.vocab_size = self.kwargs.pop(name) + if self.context_window_size == 0: + for name in ["max_position_embeddings", "seq_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." 
+ ) + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + + +# pylint: disable=invalid-name,missing-docstring + + +class GLMAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: GLMConfig): + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads // config.tensor_parallel_shards + self.multi_query_attention = config.multi_query_attention + self.num_key_value_heads = ( + config.multi_query_group_num + if config.multi_query_attention + else config.num_attention_heads + ) // config.tensor_parallel_shards + self.head_dim = config.head_dim + self.query_key_value = nn.Linear( + config.hidden_size, + (2 * self.num_key_value_heads + self.num_heads) * self.head_dim, + bias=config.add_bias_linear or config.add_qkv_bias, + ) + self.dense = nn.Linear( + self.num_heads * self.head_dim, config.hidden_size, bias=config.add_bias_linear + ) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads + b, s, _ = hidden_states.shape + qkv = self.query_key_value(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, h_q), + (b, s, h_q * d), + ) + attn_output = self.dense(output) + return attn_output + + +class GLMMLP(nn.Module): + def __init__(self, config: GLMConfig): + self.ffn_hidden_size = config.ffn_hidden_size // config.tensor_parallel_shards + + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + self.ffn_hidden_size * 2, + bias=config.add_bias_linear, + ) + self.dense_4h_to_h = nn.Linear( + self.ffn_hidden_size, + config.hidden_size, + bias=config.add_bias_linear, + ) + + def swiglu(x): + x = nn.chunk(x, 2, dim=-1) + return nn.silu(x[0]) * x[1] + + self.activation_func = swiglu + + def forward(self, x): + intermediate_parallel = self.dense_h_to_4h(x) + intermediate_parallel = self.activation_func(intermediate_parallel) + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(nn.Module): + def __init__(self, config: GLMConfig): + self.self_attention = GLMAttention(config=config) + self.mlp = GLMMLP(config) + self.input_layernorm = nn.RMSNorm( + config.hidden_size, -1, config.layernorm_epsilon, bias=False + ) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, -1, config.layernorm_epsilon, bias=False + ) + + def _set_tp(): + def _set(layer, hint): + layer.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attention.num_heads * hd + k = self.self_attention.num_key_value_heads * hd + v = self.self_attention.num_key_value_heads * hd + _set( + self.self_attention.query_key_value.weight, + tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]), + ) + if config.add_bias_linear or config.add_qkv_bias: + _set( + self.self_attention.query_key_value.bias, + 
tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]), + ) + _set(self.self_attention.dense.weight, tp.ShardSingleDim("_shard_dense_weight", dim=1)) + if config.add_bias_linear: + _set(self.self_attention.dense.bias, tp.ShardSingleDim("_shard_dense_bias", dim=0)) + _set( + self.mlp.dense_h_to_4h.weight, + tp.ShardSingleDim("_shard_dense_h_to_4h_weight", dim=0), + ) + if config.add_bias_linear: + _set( + self.mlp.dense_h_to_4h.bias, + tp.ShardSingleDim("_shard_dense_h_to_4h_bias", dim=0), + ) + _set(self.mlp.dense_4h_to_h.weight, tp.ShardSingleDim("_shard_dense_4h_to_h", dim=1)) + if config.add_bias_linear: + _set( + self.mlp.dense_4h_to_h.bias, + tp.ShardSingleDim("_shard_dense_4h_to_h_bias", dim=1), + ) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attention(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) + hidden_states = self._apply_residual(out, residual=hidden_states) + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = self._apply_residual(out, residual=hidden_states) + return hidden_states + + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + + +class GLMTransformer(nn.Module): + """Transformer class.""" + + def __init__(self, config: GLMConfig): + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + self.layers = nn.ModuleList([GLMBlock(config) for _ in range(config.num_layers)]) + + if self.post_layer_norm: + if config.rmsnorm: + self.final_layernorm = nn.RMSNorm( + config.hidden_size, -1, config.layernorm_epsilon, bias=False + ) + else: + self.final_layernorm = nn.LayerNorm(config.hidden_size, config.layernorm_epsilon) + + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class ChatGLMModel(nn.Module): + def __init__(self, config: GLMConfig): + self.embedding = nn.Embedding(config.vocab_size, config.hidden_size) + self.encoder = GLMTransformer(config) + self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + hidden_states = self.encoder(hidden_states, paged_kv_cache) + return hidden_states + + +class ChatGLMForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: GLMConfig): + self.transformer = ChatGLMModel(config) + self.num_hidden_layers = config.num_layers + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = ( + config.multi_query_group_num + if config.multi_query_attention + else config.num_attention_heads + ) + self.head_dim = config.head_dim + self.vocab_size = config.vocab_size + self.rope_theta = 10000 + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + 
op_ext.configure() + + hidden_states = self.transformer(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.transformer.output_layer(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + return self.transformer.embedding(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.transformer(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.transformer.output_layer(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.transformer(input_embed, paged_kv_cache) + logits = self.transformer.output_layer(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( # pylint: disable=too-many-arguments + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + 
"$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_llm/model/chatglm3/chatglm3_quantization.py b/python/mlc_llm/model/chatglm3/chatglm3_quantization.py new file mode 100644 index 0000000000..26b404daa8 --- /dev/null +++ b/python/mlc_llm/model/chatglm3/chatglm3_quantization.py @@ -0,0 +1,53 @@ +"""This file specifies how MLC's ChatGLM parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .chatglm3_model import ChatGLMForCausalLM, GLMConfig + + +def group_quant( + model_config: GLMConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a ChatGLM-architecture model using group quantization.""" + model: nn.Module = ChatGLMForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: GLMConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a ChatGLM-architecture model using FasterTransformer quantization.""" + model: nn.Module = ChatGLMForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: GLMConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a ChatGLM model without quantization.""" + model: nn.Module = ChatGLMForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/eagle/__init__.py b/python/mlc_llm/model/eagle/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/eagle/eagle_loader.py b/python/mlc_llm/model/eagle/eagle_loader.py new file mode 100644 index 0000000000..36ffee8a6c --- /dev/null +++ b/python/mlc_llm/model/eagle/eagle_loader.py @@ -0,0 +1,172 @@ +""" +This file specifies how MLC's 
EAGLE parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .eagle_model import EagleConfig, EagleForCasualLM +from .eagle_quantization import awq_quant + + +def huggingface(model_config: EagleConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : EagleConfig + The configuration of the Eagle model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = EagleForCasualLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping + + +def awq(model_config: EagleConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of AWQ parameters. + Parameters + ---------- + model_config : EagleConfig + The configuration of the Eagle model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to AWQ. 
+ """ + model, _ = awq_quant(model_config, quantization) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), # type: ignore[attr-defined] + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"layers.{i}.self_attn" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{attn}.qkv_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.{quantize_suffix}", + f"{attn}.k_proj.{quantize_suffix}", + f"{attn}.v_proj.{quantize_suffix}", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate( + [q, k, v], + axis=1, # AWQ GEMM would transpose the weight + ).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # Concat gate and up in MLP + mlp = f"layers.{i}.mlp" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{mlp}.gate_up_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.{quantize_suffix}", + f"{mlp}.up_proj.{quantize_suffix}", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate( + [gate, up], + axis=1, # AWQ GEMM would transpose the weight + ).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype), + ) + return mapping diff --git a/python/mlc_llm/model/eagle/eagle_model.py b/python/mlc_llm/model/eagle/eagle_model.py new file mode 100644 index 0000000000..355618df09 --- /dev/null +++ b/python/mlc_llm/model/eagle/eagle_model.py @@ -0,0 +1,244 @@ +""" +Implementation for EAGLE architecture. 
+""" + +import dataclasses +from typing import Optional + +from tvm import tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm import op as op_ext +from mlc_llm.model.llama.llama_model import LlamaAttention, LlamaConfig, LlamaFFN +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class EagleConfig(LlamaConfig): + """Configuration of the Eagle model.""" + + bias: bool = True # Whether to use bias in the fc layers + + +# pylint: disable=invalid-name,missing-docstring + + +class EagleDecoderLayer(nn.Module): + def __init__(self, config: EagleConfig, index: int): + rms_norm_eps = config.rms_norm_eps + self.self_attn = LlamaAttention(config) + self.mlp = LlamaFFN(config) + self.index = index + if self.index != 0: + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + + def _set_tp(): + def _set(layer, hint): + layer.weight.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_q_heads * hd + k = self.self_attn.num_kv_heads * hd + v = self.self_attn.num_kv_heads * hd + i = self.mlp.intermediate_size + _set(self.self_attn.qkv_proj, tp.ShardSingleDim("_shard_qkv", segs=[q, k, v], dim=0)) + _set(self.self_attn.o_proj, tp.ShardSingleDim("_shard_o", dim=1)) + _set(self.mlp.gate_up_proj, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0)) + _set(self.mlp.down_proj, tp.ShardSingleDim("_shard_mlp_down", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + if self.index != 0: + hidden_states = self.input_layernorm(hidden_states) + out = self.self_attn(hidden_states, paged_kv_cache, layer_id) + hidden_states = self._apply_residual(out, residual=hidden_states) + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = self._apply_residual(out, residual=hidden_states) + return hidden_states + + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + + +class EagleForCasualLM(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: EagleConfig): + # Put the model definition here to align with EAGLE's original structure + assert config.hidden_size % config.num_attention_heads == 0 + self.embed_tokens = nn.Embedding("vocab_size", config.hidden_size) + self.layers = nn.ModuleList( + [EagleDecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.fc = nn.Linear( + in_features=2 * config.hidden_size, out_features=config.hidden_size, bias=config.bias + ) + + self.num_hidden_layers = config.num_hidden_layers + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + self.vocab_size = config.vocab_size + self.rope_theta = config.position_embedding_base + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + + def fuse_embed_hidden_states(self, input_embed: Tensor, hidden_states: Tensor): + hidden_states = op.concat([input_embed, hidden_states], dim=-1) + hidden_states = self.fc(hidden_states) + return hidden_states 
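`fuse_embed_hidden_states` above is the distinctive EAGLE step: the draft model consumes the token embedding concatenated with the base model's last hidden state, projected back to `hidden_size` by the new `fc` layer. A small NumPy sketch of just the shape arithmetic (sizes are illustrative, and the dense math stands in for the actual `nn.Linear`):

```
import numpy as np

hidden_size, seq_len = 4096, 8  # illustrative sizes

input_embed = np.random.randn(seq_len, hidden_size).astype("float32")
hidden_states = np.random.randn(seq_len, hidden_size).astype("float32")

# Concatenate along the feature axis, then project 2H -> H, mirroring
# fc = nn.Linear(2 * hidden_size, hidden_size, bias=config.bias).
fused = np.concatenate([input_embed, hidden_states], axis=-1)  # (seq_len, 2H)
w_fc = np.random.randn(hidden_size, 2 * hidden_size).astype("float32")
b_fc = np.zeros(hidden_size, dtype="float32")
projected = fused @ w_fc.T + b_fc                              # (seq_len, H)
assert projected.shape == (seq_len, hidden_size)
```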
+ + def forward_to_last_hidden_states(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache): + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + return hidden_states + + def forward(self, input_embed: Tensor, hidden_states: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = self.fuse_embed_hidden_states(input_embed, hidden_states) + hidden_states = self.forward_to_last_hidden_states(hidden_states, paged_kv_cache) + return hidden_states + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + hidden_states: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.forward_to_last_hidden_states(hidden_states, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + return hidden_states + + def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + return self.embed_tokens(input_ids) + + def prefill_to_last_hidden_states(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.forward_to_last_hidden_states(hidden_states, paged_kv_cache) + return hidden_states, paged_kv_cache + + def decode_to_last_hidden_states(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.forward_to_last_hidden_states(hidden_states, paged_kv_cache) + return hidden_states, paged_kv_cache + + def batch_prefill_to_last_hidden_states( + self, + hidden_states: Tensor, + paged_kv_cache: PagedKVCache, + ): + hidden_states = self.batch_forward(hidden_states, paged_kv_cache) + return hidden_states, paged_kv_cache + + def batch_decode_to_last_hidden_states( + self, hidden_states: Tensor, paged_kv_cache: PagedKVCache + ): + hidden_states = self.batch_forward(hidden_states, paged_kv_cache) + return hidden_states, paged_kv_cache + + def create_paged_kv_cache( # pylint: disable=too-many-arguments + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "fuse_embed_hidden_states": { + "input_embed": nn.spec.Tensor(["length", self.hidden_size], self.dtype), + "hidden_states": nn.spec.Tensor(["length", self.hidden_size], self.dtype), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill_to_last_hidden_states": { + "hidden_states": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + 
"effect_mode": "none", + }, + }, + "decode_to_last_hidden_states": { + "hidden_states": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill_to_last_hidden_states": { + "hidden_states": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode_to_last_hidden_states": { + "hidden_states": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_llm/model/eagle/eagle_quantization.py b/python/mlc_llm/model/eagle/eagle_quantization.py new file mode 100644 index 0000000000..a926f7d9dd --- /dev/null +++ b/python/mlc_llm/model/eagle/eagle_quantization.py @@ -0,0 +1,70 @@ +"""This file specifies how MLC's Eagle parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize + +from .eagle_model import EagleConfig, EagleForCasualLM + + +def group_quant( + model_config: EagleConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Eagle-architecture model using group quantization.""" + model: nn.Module = EagleForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: EagleConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Eagle-architecture model using FasterTransformer quantization.""" + model: nn.Module = EagleForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: EagleConfig, + quantization: AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Eagle-architecture model using Activation-aware Weight Quantization(AWQ).""" + model: nn.Module = EagleForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: EagleConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Eagle model without quantization.""" + model: nn.Module = EagleForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/gemma/gemma_model.py b/python/mlc_llm/model/gemma/gemma_model.py index 5950ab2972..118f3ce856 100644 --- a/python/mlc_llm/model/gemma/gemma_model.py +++ b/python/mlc_llm/model/gemma/gemma_model.py @@ -39,7 +39,7 @@ class 
GemmaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): - if self.hidden_act != "gelu": + if self.hidden_act not in ("gelu", "gelu_pytorch_tanh"): raise ValueError("Only GeLU is supported as the activation for gemma.") if self.attention_bias: raise ValueError('Only "False" attention_bias is supported for gemma') @@ -115,7 +115,7 @@ def __init__(self, config: GemmaConfig): def forward(self, x: Tensor): concat_x1_x2 = self.gate_up_proj(x) x1, x2 = op.split(concat_x1_x2, 2, axis=-1) - return self.down_proj(op.gelu(x1) * x2) + return self.down_proj(op.gelu(x1, approximate="tanh") * x2) class GemmaAttention(nn.Module): # pylint: disable=too-many-instance-attributes diff --git a/python/mlc_llm/model/llava/llava_model.py b/python/mlc_llm/model/llava/llava_model.py index 30963f990c..1498c13fdb 100644 --- a/python/mlc_llm/model/llava/llava_model.py +++ b/python/mlc_llm/model/llava/llava_model.py @@ -23,10 +23,12 @@ from tvm.relax.op import arange, strided_slice from mlc_llm import op as op_ext +from mlc_llm.model.model_preset import MODEL_PRESETS from mlc_llm.nn import PagedKVCache, RopeMode from ...support.config import ConfigBase from ..llama.llama_model import LlamaConfig, LlamaForCasualLM +from ..mistral.mistral_model import MistralConfig, MistralForCasualLM logger = logging.getLogger(__name__) @@ -45,12 +47,15 @@ class LlavaVisionConfig(ConfigBase): # pylint: disable=too-many-instance-attrib patch_size: int projection_dim: int vocab_size: int - dtype: str = "float16" num_channels: int = 3 layer_norm_eps: float = 1e-06 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) +CONFIG_MAP = {"LlamaForCausalLM": LlamaConfig, "MistralForCausalLM": MistralConfig} +ARCHITECTURE_MAP = {"LlamaForCausalLM": LlamaForCasualLM, "MistralForCausalLM": MistralForCasualLM} + + @dataclasses.dataclass class LlavaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes """ @@ -61,11 +66,12 @@ class LlavaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes text_config: LlamaConfig vision_config: LlavaVisionConfig vocab_size: int - context_window_size: int = 0 - prefill_chunk_size: int = 0 + context_window_size: int = -1 + sliding_window_size: int = -1 + prefill_chunk_size: int = -1 tensor_parallel_shards: int = 1 - dtype: str = "float16" max_batch_size: int = 1 + text_architecture: str = "LlamaForCausalLM" kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -81,41 +87,54 @@ def __post_init__(self): self.vision_config = LlavaVisionConfig.from_dict(vision_config_dict) text_config_dict: Dict[str, Any] - if isinstance(self.text_config, LlamaConfig): + if isinstance(self.text_config, ConfigBase): text_config_dict = dataclasses.asdict(self.text_config) else: text_config_dict = dict(self.text_config) if "_name_or_path" in text_config_dict: - if text_config_dict["_name_or_path"] == "meta-llama/Llama-2-7b-hf": - text_config_dict["hidden_size"] = text_config_dict.pop("hidden_size", 4096) - text_config_dict["intermediate_size"] = text_config_dict.pop( - "intermediate_size", 11008 - ) - text_config_dict["num_attention_heads"] = text_config_dict.pop( - "num_attention_heads", 32 - ) - text_config_dict["num_hidden_layers"] = text_config_dict.pop( - "num_hidden_layers", 32 - ) - text_config_dict["rms_norm_eps"] = text_config_dict.pop("rms_norm_eps", 1e-06) - text_config_dict["vocab_size"] = text_config_dict.pop("vocab_size", 32064) - 
text_config_dict["context_window_size"] = text_config_dict.pop( - "context_window_size", 4096 - ) - else: - raise ValueError("Unsupported text model") + hf_config = self.get_hf_config(text_config_dict) + text_config_dict.update(hf_config) + architectures = text_config_dict["architectures"] + assert len(architectures) == 1 + self.text_architecture = architectures[0] else: for k, v in text_config_dict.pop("kwargs", {}).items(): text_config_dict[k] = v - self.text_config = LlamaConfig.from_dict(text_config_dict) - - if self.context_window_size <= 0: - self.context_window_size = self.text_config.context_window_size + self.text_config = CONFIG_MAP[self.text_architecture].from_dict(text_config_dict) + + for k in ["context_window_size", "sliding_window_size", "prefill_chunk_size"]: + if getattr(self, k) <= 0: + if hasattr(self.text_config, k): + setattr(self, k, getattr(self.text_config, k)) + + def get_hf_config(self, text_config_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Get the Hugging Face config of the text model + """ + + hf_config: Dict[str, Any] + try: + # pylint: disable=import-outside-toplevel, import-error + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained(text_config_dict["_name_or_path"]).to_dict() + except (ImportError, OSError) as e: + # If transformers is not installed, get the config from preset + # Llama2 is gated so it throws an OSError. Get the config from preset instead + preset_mapping = { + "meta-llama/Llama-2-7b-hf": "llama2_7b", + "meta-llama/Llama-2-13b-hf": "llama2_13b", + "lmsys/vicuna-7b-v1.5": "llama2_7b", + "mistralai/Mistral-7B-v0.1": "mistral_7b", + } + if text_config_dict["_name_or_path"] in preset_mapping: + hf_config = MODEL_PRESETS[preset_mapping[text_config_dict["_name_or_path"]]] + else: + raise ValueError("Unsupported text model") from e - if self.prefill_chunk_size <= 0: - self.prefill_chunk_size = self.text_config.prefill_chunk_size + return hf_config # pylint: disable=missing-docstring @@ -128,21 +147,18 @@ def __init__(self, config: LlavaVisionConfig): self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size - self.class_embedding = nn.Parameter((self.embed_dim,), dtype=config.dtype) + self.class_embedding = nn.Parameter((self.embed_dim,)) self.patch_embedding = Conv2D( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False, - dtype=config.dtype, ) self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches + 1 - self.position_embedding = nn.Embedding( - num=self.num_positions, dim=self.embed_dim, dtype=config.dtype - ) + self.position_embedding = nn.Embedding(num=self.num_positions, dim=self.embed_dim) def forward(self, pixel_values: Tensor) -> Tensor: batch_size = pixel_values.shape[0] @@ -194,8 +210,8 @@ class CLIPMLP(Module): def __init__(self, config: LlavaVisionConfig): super().__init__() self.activation_fn = LlavaQuickGELU() - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, dtype=config.dtype) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, dtype=config.dtype) + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: Tensor) -> Tensor: hidden_states = self.fc1(hidden_states) @@ -216,10 +232,10 @@ def __init__(self, config: LlavaVisionConfig): f" and `num_heads`: {self.num_heads})." 
) self.scale = self.head_dim**-0.5 - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) def _shape(self, tensor: Tensor, seq_len: int, bsz: int): reshape_tensor = reshape(tensor, shape=(bsz, seq_len, self.num_heads, self.head_dim)) @@ -263,13 +279,9 @@ def __init__(self, config: LlavaVisionConfig): super().__init__() self.embed_dim = config.hidden_size self.self_attn = CLIPAttention(config) - self.layer_norm1 = nn.LayerNorm( - normalized_shape=self.embed_dim, eps=config.layer_norm_eps, dtype=config.dtype - ) + self.layer_norm1 = nn.LayerNorm(normalized_shape=self.embed_dim, eps=config.layer_norm_eps) self.mlp = CLIPMLP(config) - self.layer_norm2 = nn.LayerNorm( - normalized_shape=self.embed_dim, eps=config.layer_norm_eps, dtype=config.dtype - ) + self.layer_norm2 = nn.LayerNorm(normalized_shape=self.embed_dim, eps=config.layer_norm_eps) def forward(self, hidden_states: Tensor) -> Tensor: residual = hidden_states @@ -308,9 +320,9 @@ def __init__(self, config: LlavaVisionConfig): super().__init__() embed_dim = config.hidden_size self.embeddings = CLIPVisionEmbeddings(config) - self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps, dtype=config.dtype) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.encoder = CLIPEncoder(config) - self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps, dtype=config.dtype) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward(self, pixel_values: Tensor) -> Tensor: hidden_states = self.embeddings(pixel_values) @@ -353,9 +365,15 @@ def __init__(self, config: LlavaConfig): self.config = config self.vision_tower = CLIPVisionModel(config.vision_config) self.multi_modal_projector = LlavaMultiModalProjector(config) - self.language_model = LlamaForCasualLM(config.text_config) + self.language_model = ARCHITECTURE_MAP[config.text_architecture](config.text_config) self.vocab_size = config.vocab_size - self.dtype = config.dtype + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + self.language_model.to(dtype=dtype) + if dtype is not None: + self.dtype = dtype def _embed_input_ids(self, input_ids: Tensor) -> Tensor: return self.language_model.embed(input_ids) diff --git a/python/mlc_llm/model/mixtral/mixtral_model.py b/python/mlc_llm/model/mixtral/mixtral_model.py index ec8025f3dc..db41dc31ce 100644 --- a/python/mlc_llm/model/mixtral/mixtral_model.py +++ b/python/mlc_llm/model/mixtral/mixtral_model.py @@ -49,11 +49,13 @@ def __init__(self, config: MixtralConfig): self.num_local_experts, in_features=config.hidden_size, out_features=2 * self.intermediate_size, + tensor_parallel_shards=config.tensor_parallel_shards, ) self.e2 = MixtralExperts( self.num_local_experts, in_features=self.intermediate_size, out_features=config.hidden_size, + tensor_parallel_shards=config.tensor_parallel_shards, ) self.dtype = "float32" diff --git a/python/mlc_llm/model/mixtral/mixtral_quantization.py b/python/mlc_llm/model/mixtral/mixtral_quantization.py index 
0e8130e051..1b5dc1e9bd 100644 --- a/python/mlc_llm/model/mixtral/mixtral_quantization.py +++ b/python/mlc_llm/model/mixtral/mixtral_quantization.py @@ -1,11 +1,18 @@ """This file specifies how MLC's Mistral parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn from mlc_llm.loader import QuantizeMapping -from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.quantization import ( + AWQQuantize, + FTQuantize, + GroupQuantize, + NoQuantize, + PerTensorQuantize, +) from .mixtral_model import MixtralConfig, MixtralForCasualLM @@ -59,3 +66,19 @@ def no_quant( model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) return model, quant_map + + +def per_tensor_quant( + model_config: MixtralConfig, + quantization: PerTensorQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mixtral model using per-tensor quantization.""" + model: nn.Module = MixtralForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index c301634522..1c513e15d3 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -9,6 +9,8 @@ from mlc_llm.quantization.quantization import Quantization from .baichuan import baichuan_loader, baichuan_model, baichuan_quantization +from .chatglm3 import chatglm3_loader, chatglm3_model, chatglm3_quantization +from .eagle import eagle_loader, eagle_model, eagle_quantization from .gemma import gemma_loader, gemma_model, gemma_quantization from .gpt2 import gpt2_loader, gpt2_model, gpt2_quantization from .gpt_bigcode import gpt_bigcode_loader, gpt_bigcode_model, gpt_bigcode_quantization @@ -141,6 +143,7 @@ class Model: "no-quant": mixtral_quantization.no_quant, "group-quant": mixtral_quantization.group_quant, "ft-quant": mixtral_quantization.ft_quant, + "per-tensor-quant": mixtral_quantization.per_tensor_quant, }, ), "gpt_neox": Model( @@ -324,4 +327,33 @@ class Model: "group-quant": rwkv6_quantization.group_quant, }, ), + "chatglm": Model( + name="chatglm", + model=chatglm3_model.ChatGLMForCausalLM, + config=chatglm3_model.GLMConfig, + source={ + "huggingface-torch": chatglm3_loader.huggingface, + "huggingface-safetensor": chatglm3_loader.huggingface, + }, + quantize={ + "no-quant": chatglm3_quantization.no_quant, + "group-quant": chatglm3_quantization.group_quant, + }, + ), + "eagle": Model( + name="eagle", + model=eagle_model.EagleForCasualLM, + config=eagle_model.EagleConfig, + source={ + "huggingface-torch": eagle_loader.huggingface, + "huggingface-safetensor": eagle_loader.huggingface, + "awq": eagle_loader.awq, + }, + quantize={ + "no-quant": eagle_quantization.no_quant, + "group-quant": eagle_quantization.group_quant, + "ft-quant": eagle_quantization.ft_quant, + "awq": eagle_quantization.awq_quant, + }, + ), } diff --git a/python/mlc_llm/model/model_preset.py b/python/mlc_llm/model/model_preset.py index 8e87217d35..3bfe1cb891 100644 --- a/python/mlc_llm/model/model_preset.py +++ b/python/mlc_llm/model/model_preset.py @@ -623,4 +623,41 @@ "vision_feature_select_strategy": "default", "vocab_size": 32064, }, + "chatglm": { + "architectures": ["ChatGLMModel"], + "model_type": "chatglm", + "auto_map": { + "AutoConfig": "configuration_chatglm.ChatGLMConfig", + "AutoModel": 
"modeling_chatglm.ChatGLMForConditionalGeneration", + "AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration", + }, + "add_bias_linear": False, + "add_qkv_bias": True, + "apply_query_key_layer_scaling": True, + "apply_residual_connection_post_layernorm": False, + "attention_dropout": 0.0, + "attention_softmax_in_fp32": True, + "bias_dropout_fusion": True, + "ffn_hidden_size": 13696, + "fp32_residual_connection": False, + "hidden_dropout": 0.0, + "hidden_size": 4096, + "kv_channels": 128, + "layernorm_epsilon": 1e-05, + "multi_query_attention": True, + "multi_query_group_num": 2, + "num_attention_heads": 32, + "num_layers": 28, + "original_rope": True, + "padded_vocab_size": 65024, + "post_layer_norm": True, + "rmsnorm": True, + "seq_length": 8192, + "use_cache": True, + "torch_dtype": "float16", + "transformers_version": "4.30.2", + "tie_word_embeddings": False, + "eos_token_id": 2, + "pad_token_id": 0, + }, } diff --git a/python/mlc_llm/op/attention.py b/python/mlc_llm/op/attention.py index 801dbd66ba..dc41a5f5ef 100644 --- a/python/mlc_llm/op/attention.py +++ b/python/mlc_llm/op/attention.py @@ -103,12 +103,12 @@ def _fallback(): and k.dtype == "float16" and v.dtype == "float16" ): - if group_size not in [1, 4, 8]: + if group_size not in [1, 4, 6, 8]: global WARN_FLASHINFER_GROUP_SIZE # pylint: disable=global-statement if not WARN_FLASHINFER_GROUP_SIZE: WARN_FLASHINFER_GROUP_SIZE = True logger.warning( - "FlashInfer only supports group size in [1, 4, 8], but got %d. Skip and " + "FlashInfer only supports group size in [1, 4, 6, 8], but got %d. Skip and " "fallback to default implementation.", group_size, ) diff --git a/python/mlc_llm/op/cutlass.py b/python/mlc_llm/op/cutlass.py index 275d61f20a..6b0e21578e 100644 --- a/python/mlc_llm/op/cutlass.py +++ b/python/mlc_llm/op/cutlass.py @@ -45,23 +45,22 @@ def group_gemm( assert x.ndim == 2 assert weight.ndim == 3 assert indptr.ndim == 1 - assert weight.shape[2] == x.shape[1] assert weight.shape[0] == indptr.shape[0] assert indptr.dtype == "int64" out_dtype = out_dtype if out_dtype else x.dtype weight_dtype = weight_dtype if weight_dtype else weight.dtype - if x.dtype == "e5m2_float8" and weight.dtype == "e5m2_float8" and out_dtype == "float16": + if x.dtype == "e5m2_float8" and weight_dtype == "e5m2_float8" and out_dtype == "float16": func_name = "cutlass.group_gemm_e5m2_e5m2_fp16" - elif x.dtype == "e4m3_float8" and weight.dtype == "e5m2_float8" and out_dtype == "float16": + elif x.dtype == "e4m3_float8" and weight_dtype == "e5m2_float8" and out_dtype == "float16": func_name = "cutlass.group_gemm_e4m3_e5m2_fp16" - elif x.dtype == "e4m3_float8" and weight.dtype == "e4m3_float8" and out_dtype == "float16": + elif x.dtype == "e4m3_float8" and weight_dtype == "e4m3_float8" and out_dtype == "float16": func_name = "cutlass.group_gemm_e4m3_e4m3_fp16" - elif x.dtype == "float16" and weight.dtype == "float16" and out_dtype == "float16": + elif x.dtype == "float16" and weight_dtype == "float16" and out_dtype == "float16": func_name = "cutlass.group_gemm_fp16_sm90" else: raise NotImplementedError( - f"Unsupported data type: x={x.dtype}, weight={weight.dtype}, out={out_dtype}" + f"Unsupported data type: x={x.dtype}, weight={weight_dtype}, out={out_dtype}" ) if "float8" in x.dtype: diff --git a/python/mlc_llm/op/moe_matmul.py b/python/mlc_llm/op/moe_matmul.py index c0d880c76c..6978d8ba0e 100644 --- a/python/mlc_llm/op/moe_matmul.py +++ b/python/mlc_llm/op/moe_matmul.py @@ -1,5 +1,8 @@ """Mixture of Experts operators""" -from tvm 
import DataType, tir, DataTypeCode + +from typing import Literal, Optional + +from tvm import DataType, tir from tvm.relax.frontend.nn import Tensor, op from tvm.script import tir as T @@ -123,7 +126,7 @@ def dequantize_gemv( # pylint: disable=too-many-arguments num_elem_per_storage = DataType(storage_dtype).bits // quantize_dtype_bits num_group = (in_features + group_size - 1) // group_size num_storage = group_size // num_elem_per_storage * num_group - + def _dequantize(w, s, e, i, j): tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype) tir_max_int = tir.const((2 ** (quantize_dtype_bits - 1)) - 1, model_dtype) @@ -138,9 +141,11 @@ def _dequantize_e4m3(w, s, e, i, j): w = w[e, i, j // num_elem_per_storage] s = s[e, i, j // group_size] shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) - w = tir.reinterpret(DataType(quantize_dtype), tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask)).astype(model_dtype) + w = tir.reinterpret( + DataType(quantize_dtype), tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask) + ).astype(model_dtype) return w * s - + if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float: dequantize_func = _dequantize_e4m3 elif DataType(quantize_dtype).type_code == DataTypeCode.E5M2Float: @@ -157,7 +162,6 @@ def access_x(x, e, j): assert indptr.shape == [1, experts_per_tok] and indptr.dtype == "int32" assert x_leading_dim in [1, experts_per_tok] - @T.prim_func(private=True) def _func( x: T.Buffer((x_leading_dim, in_features), model_dtype), @@ -189,6 +193,7 @@ def _func( out=Tensor.placeholder([experts_per_tok, out_features], model_dtype), ) + def dequantize_gemv_no_scale( # pylint: disable=too-many-arguments x: Tensor, w: Tensor, @@ -240,9 +245,11 @@ def _dequantize_e5m2(w, e, i, j): tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype) w = w[e, i, j // num_elem_per_storage] shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) - w = tir.reinterpret(DataType(quantize_dtype), tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask)).astype(model_dtype) + w = tir.reinterpret( + DataType(quantize_dtype), tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask) + ).astype(model_dtype) return w - + if DataType(quantize_dtype).type_code == DataTypeCode.E5M2Float: dequantize_func = _dequantize_e5m2 else: @@ -256,7 +263,6 @@ def access_x(x, e, j): assert indptr.shape == [1, experts_per_tok] and indptr.dtype == "int32" assert x_leading_dim in [1, experts_per_tok] - @T.prim_func(private=True) def _func( x: T.Buffer((x_leading_dim, in_features), model_dtype), @@ -288,6 +294,124 @@ def _func( ) +def dequantize_float8_gemv( + x: Tensor, + w: Tensor, + scale: Optional[Tensor], + indptr: Tensor, + quantize_dtype: Literal["e5m2_float8", "e4m3_float8"], +) -> Tensor: + """GEMV for project-in (e1-e3) or project-out (e2) in MLP but the weight is quantized in + fp8 e5m2 or e4m3. It needs to be dequantized before the GEMV computation. + + Parameters + ---------- + x : Tensor + For project-in, the input tensor of shape (1, in_features); and for project-out, the input + shape is (experts_per_tok, in_features), where `experts_per_tok` is the number of activated + experts per token. + + w : Tensor + The quantized weight tensor of shape (local_experts, out_features, in_features) + + scale : Optional[Tensor] + The optional scale tensor of shape (1,) + + indptr : Tensor + The index pointer tensor of shape (1, experts_per_tok), where `experts_per_tok` is the + number of activated experts per token. 
+ + quantize_dtype : Literal["e5m2_float8", "e4m3_float8"] + The quantize dtype of the weight tensor, which is either e5m2_float8 or e4m3_float8. + """ + (x_leading_dim, in_features), model_dtype = x.shape, x.dtype + (local_experts, out_features, _), storage_dtype = w.shape, w.dtype + _, experts_per_tok = indptr.shape + quantize_dtype_bits = DataType(quantize_dtype).bits + num_elem_per_storage = DataType(storage_dtype).bits // quantize_dtype_bits + num_storage = tir.ceildiv(in_features, num_elem_per_storage) + + def _dequantize(w, s, e, i, j): + if num_elem_per_storage == 1: + w = tir.reinterpret(quantize_dtype, w[e, i, j]) + else: + tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype) + w = w[e, i, j // num_elem_per_storage] + shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) + w = tir.reinterpret( + quantize_dtype, + tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask).astype("uint8"), + ) + w = w.astype(model_dtype) + if s is not None: + w = w * s[0] + return w + + def access_x(x, e, j): + return x[0, j] if x_leading_dim == 1 else x[e, j] + + @T.prim_func(private=True) + def _func_with_scale( + x: T.Buffer((x_leading_dim, in_features), model_dtype), + w: T.Buffer((local_experts, out_features, num_storage), storage_dtype), + scale: T.Buffer((1,), model_dtype), + indptr: T.Buffer((1, experts_per_tok), "int32"), + o: T.Buffer((experts_per_tok, out_features), model_dtype), + ): + T.func_attr({"op_pattern": 4, "tir.noalias": True}) # kOutEWiseFusable + for expert_id in T.thread_binding(experts_per_tok, thread="blockIdx.y"): + with T.block("gemv_o"): + e = T.axis.spatial(experts_per_tok, expert_id) + y = T.alloc_buffer((out_features, in_features), model_dtype) + for i1, i2 in T.grid(out_features, in_features): + with T.block("dequantize"): + i, j = T.axis.remap("SS", [i1, i2]) + y[i, j] = _dequantize(w, scale, indptr[0, e], i, j) + for i1, i2 in T.grid(out_features, in_features): + with T.block("gemv"): + i, j = T.axis.remap("SR", [i1, i2]) + with T.init(): + o[e, i] = T.cast(T.float16(0), model_dtype) + o[e, i] += access_x(x, e, j) * y[i, j] + + @T.prim_func(private=True) + def _func_without_scale( + x: T.Buffer((x_leading_dim, in_features), model_dtype), + w: T.Buffer((local_experts, out_features, num_storage), storage_dtype), + indptr: T.Buffer((1, experts_per_tok), "int32"), + o: T.Buffer((experts_per_tok, out_features), model_dtype), + ): + T.func_attr({"op_pattern": 4, "tir.noalias": True}) # kOutEWiseFusable + for expert_id in T.thread_binding(experts_per_tok, thread="blockIdx.y"): + with T.block("gemv_o"): + e = T.axis.spatial(experts_per_tok, expert_id) + y = T.alloc_buffer((out_features, in_features), model_dtype) + for i1, i2 in T.grid(out_features, in_features): + with T.block("dequantize"): + i, j = T.axis.remap("SS", [i1, i2]) + y[i, j] = _dequantize(w, None, indptr[0, e], i, j) + for i1, i2 in T.grid(out_features, in_features): + with T.block("gemv"): + i, j = T.axis.remap("SR", [i1, i2]) + with T.init(): + o[e, i] = T.cast(T.float16(0), model_dtype) + o[e, i] += access_x(x, e, j) * y[i, j] + + if scale is not None: + return op.tensor_ir_op( + _func_with_scale, + "moe_dequantize_gemv", + args=[x, w, scale, indptr], + out=Tensor.placeholder([experts_per_tok, out_features], model_dtype), + ) + return op.tensor_ir_op( + _func_without_scale, + "moe_dequantize_gemv", + args=[x, w, indptr], + out=Tensor.placeholder([experts_per_tok, out_features], model_dtype), + ) + + def group_gemm(x: Tensor, w: Tensor, indptr: Tensor): # pylint: 
disable=too-many-statements """Group GEMM in MoE models. @@ -525,14 +649,18 @@ def _dequantize_e4m3(w, s, e, i, j): w = w[e, i, j // num_elem_per_storage] s = s[e, i, j // group_size] shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) - w = tir.reinterpret(DataType(quantize_dtype), tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask)).astype(model_dtype) + w = tir.reinterpret( + DataType(quantize_dtype), tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask) + ).astype(model_dtype) return w * s - def _dequantize_e5m2(w, s, e, i, j): # TODO(jmcmahan): scale argument? + def _dequantize_e5m2(w, s, e, i, j): # TODO(jmcmahan): scale argument? tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype) w = w[e, i, j // num_elem_per_storage] shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) - w = tir.reinterpret(DataType(quantize_dtype), tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask)).astype(model_dtype) + w = tir.reinterpret( + DataType(quantize_dtype), tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask) + ).astype(model_dtype) return w if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float: @@ -541,7 +669,7 @@ def _dequantize_e5m2(w, s, e, i, j): # TODO(jmcmahan): scale argument? dequantize_func = _dequantize_e5m2 else: dequantize_func = _dequantize - + Ne, N, K = num_local_experts, out_features, in_features BLK_M, BLK_N, BLK_K = 8, 128, 32 TX, TY, CTA_COUNT = 8, 32, 1024 @@ -689,6 +817,7 @@ def _cooperative_fetch(block, vec_len): out=Tensor.placeholder([x.shape[0], out_features], model_dtype), ) + def dequantize_group_gemm_no_scale( x: Tensor, w: Tensor, @@ -738,14 +867,16 @@ def _dequantize_e5m2(w, e, i, j): tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype) w = w[e, i, j // num_elem_per_storage] shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) - w = tir.reinterpret(DataType(quantize_dtype), tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask)).astype(model_dtype) + w = tir.reinterpret( + DataType(quantize_dtype), tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask) + ).astype(model_dtype) return w if DataType(quantize_dtype).type_code == DataTypeCode.E5M2Float: dequantize_func = _dequantize_e5m2 else: assert False, "Only FP8 E5M2 is supported for no-scale dequantization" - + Ne, N, K = num_local_experts, out_features, in_features BLK_M, BLK_N, BLK_K = 8, 128, 32 TX, TY, CTA_COUNT = 8, 32, 1024 diff --git a/python/mlc_llm/op/moe_misc.py b/python/mlc_llm/op/moe_misc.py index 6dc7f33265..ff5e50c60c 100644 --- a/python/mlc_llm/op/moe_misc.py +++ b/python/mlc_llm/op/moe_misc.py @@ -5,9 +5,6 @@ from tvm import te, tir from tvm.relax.frontend.nn import Tensor, op from tvm.script import tir as T -from tvm.target import Target -from tvm.topi.cuda.scan import inclusive_scan -from tvm.topi.cuda.sort import topk as topi_topk # mypy: disable-error-code="attr-defined,name-defined" # pylint: disable=line-too-long,too-many-locals,invalid-name @@ -120,7 +117,9 @@ def topk_softmax_func( Tensor.placeholder([batch_size, 2], index_dtype), ), ) - expert_score, expert_indices = op.tensor_expr_op(topi_topk, "topk", args=[x, k, -1, "both", False, index_dtype]) # type: ignore[list-item] + expert_score, expert_indices = op.topk( + x, k, axis=-1, ret_type="both", largest=True, dtype=index_dtype + ) expert_score = op.softmax(expert_score.astype("float32"), axis=-1).astype(dtype) return expert_score, expert_indices @@ -203,14 +202,8 @@ def moe_cumsum(expert_indices: 
Tensor, num_local_experts: int) -> Tensor: .permute_dims(1, 0) .reshape(batch_size * num_local_experts) ) - with Target.current(allow_none=True) or Target( - { - "kind": "cuda", - "max_num_threads": 1024, - "arch": "sm_50", - } - ): - return op.tensor_expr_op(inclusive_scan, "cumsum", args=[expert_mask, 0, "int32"]) # type: ignore[list-item] + + return op.cumsum(expert_mask, axis=0, exclusive=False, dtype="int32") def get_indices(cumsum: Tensor, expert_indices: Tensor) -> Tuple[Tensor, Tensor]: diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index 1c2a3cb2e4..482cce54c8 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -5,8 +5,6 @@ from pydantic import BaseModel, Field, field_validator -from ..serve import data - # The message placeholders in the message prompts according to roles. class MessagePlaceholders(Enum): @@ -113,17 +111,25 @@ def from_json_dict(cls: Type[T], json_dict: Dict[str, Any]) -> T: return Conversation.model_validate(json_dict) # pylint: disable=too-many-branches - def as_prompt(self, config=None) -> List[Union[str, data.ImageData]]: + def as_prompt(self, config=None) -> List[Any]: """Convert the conversation template and history messages to a single prompt. + + Returns + ------- + prompts : List[Union[str, "mlc_llm.serve.data.Data"]] + The prompts converted from the conversation messages. + We use Any in the signature to avoid cyclic import. """ + from ..serve import data # pylint: disable=import-outside-toplevel + # - Get the system message. system_msg = self.system_template.replace( MessagePlaceholders.SYSTEM.value, self.system_message ) # - Get the message strings. - message_list: List[Union[str, data.ImageData]] = [] + message_list: List[Union[str, data.Data]] = [] separators = list(self.seps) if len(separators) == 1: separators.append(separators[0]) @@ -136,55 +142,48 @@ def as_prompt(self, config=None) -> List[Union[str, data.ImageData]]: if role not in self.roles.keys(): raise ValueError(f'Role "{role}" is not a supported role in {self.roles.keys()}') separator = separators[role == "assistant"] # check assistant role - if content is not None: - role_prefix = ( - "" - # Do not append role prefix if this is the first message and there - # is already a system message - if (not self.add_role_after_system_message and system_msg != "" and i == 0) - else self.roles[role] + self.role_content_sep + + if content is None: + message_list.append(self.roles[role] + self.role_empty_sep) + continue + + role_prefix = ( + "" + # Do not append role prefix if this is the first message and there + # is already a system message + if (not self.add_role_after_system_message and system_msg != "" and i == 0) + else self.roles[role] + self.role_content_sep + ) + if isinstance(content, str): + message_list.append( + role_prefix + + self.role_templates[role].replace( + MessagePlaceholders[role.upper()].value, content + ) + + separator ) - if isinstance(content, str): - message_string = ( - role_prefix - + self.role_templates[role].replace( - MessagePlaceholders[role.upper()].value, content - ) - + separator + continue + + message_list.append(role_prefix) + + for item in content: + assert isinstance(item, dict), "Content should be a string or a list of dicts" + assert "type" in item, "Content item should have a type field" + if item["type"] == "text": + message = self.role_templates[role].replace( + MessagePlaceholders[role.upper()].value, item["text"] ) - 
message_list.append(message_string) + message_list.append(message) + elif item["type"] == "image_url": + assert config is not None, "Model config is required" + image_url = _get_url_from_item(item) + message_list.append(data.ImageData.from_url(image_url, config)) else: - message_list.append(role_prefix) - for item in content: - assert isinstance( - item, dict - ), "Content should be a string or a list of dicts" - assert "type" in item, "Content item should have a type field" - if item["type"] == "text": - message_list.append( - self.role_templates[role].replace( - MessagePlaceholders[role.upper()].value, item["text"] - ) - ) - elif item["type"] == "image_url": - assert config is not None, "Model config is required" - - # pylint: disable=import-outside-toplevel - from ..serve.entrypoints.entrypoint_utils import ( - get_image_from_url, - ) - - image_url = _get_url_from_item(item) - message_list.append(get_image_from_url(image_url, config)) - else: - raise ValueError(f"Unsupported content type: {item['type']}") - - message_list.append(separator) - else: - message_string = self.roles[role] + self.role_empty_sep - message_list.append(message_string) - - prompt = _combine_consecutive_strings(message_list) + raise ValueError(f"Unsupported content type: {item['type']}") + + message_list.append(separator) + + prompt = _combine_consecutive_messages(message_list) if not any(isinstance(item, data.ImageData) for item in message_list): # Replace the last function string placeholder with actual function string @@ -215,11 +214,27 @@ def _get_url_from_item(item: Dict) -> str: return image_url -def _combine_consecutive_strings(lst): - result = [] - for item in lst: - if isinstance(item, str) and result and isinstance(result[-1], str): - result[-1] += item +def _combine_consecutive_messages(messages: List[Any]) -> List[Any]: + """Combining consecutive strings into one. + + Parameters + ---------- + messages : List[Union[str, "mlc_llm.serve.data.Data"]] + The input messages to be combined. + We use Any in the signature to avoid cyclic import. 
+ + Returns + ------- + updated_messages : List[Union[str, "mlc_llm.serve.data.Data"]] + The combined messages + """ + if len(messages) == 0: + return [] + + combined_messages = [messages[0]] + for message in messages[1:]: + if isinstance(message, str) and isinstance(combined_messages[-1], str): + combined_messages[-1] += message else: - result.append(item) - return result + combined_messages.append(message) + return combined_messages diff --git a/python/mlc_llm/protocol/error_protocol.py b/python/mlc_llm/protocol/error_protocol.py new file mode 100644 index 0000000000..83a201f578 --- /dev/null +++ b/python/mlc_llm/protocol/error_protocol.py @@ -0,0 +1,34 @@ +"""Error protocols in MLC LLM""" + +from http import HTTPStatus + +import fastapi +from pydantic import BaseModel + + +class BadRequestError(ValueError): + """The exception for bad requests in engines.""" + + def __init__(self, *args: object) -> None: + super().__init__(*args) + + +class ErrorResponse(BaseModel): + """The class of error response.""" + + object: str = "error" + message: str + code: int = None + + +def create_error_response(status_code: HTTPStatus, message: str) -> fastapi.responses.JSONResponse: + """Create a JSON response that reports error with regarding the input message.""" + return fastapi.responses.JSONResponse( + ErrorResponse(message=message, code=status_code.value).model_dump_json(), + status_code=status_code.value, + ) + + +async def bad_request_error_handler(_request: fastapi.Request, e: BadRequestError): + """The handler of BadRequestError that converts an exception into error response.""" + return create_error_response(status_code=HTTPStatus.BAD_REQUEST, message=e.args[0]) diff --git a/python/mlc_llm/protocol/protocol_utils.py b/python/mlc_llm/protocol/protocol_utils.py index a9a68a1f82..f4273d0302 100644 --- a/python/mlc_llm/protocol/protocol_utils.py +++ b/python/mlc_llm/protocol/protocol_utils.py @@ -2,8 +2,6 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel - from ..serve.config import GenerationConfig from . import RequestProtocol from .openai_api_protocol import ChatCompletionRequest as OpenAIChatCompletionRequest @@ -14,14 +12,6 @@ ) -class ErrorResponse(BaseModel): - """The class of error response.""" - - object: str = "error" - message: str - code: int = None - - def get_unsupported_fields(request: RequestProtocol) -> List[str]: """Get the unsupported fields of the request. Return the list of unsupported field names. 
diff --git a/python/mlc_llm/serve/data.py b/python/mlc_llm/serve/data.py index 8444e3f363..1c56178ad1 100644 --- a/python/mlc_llm/serve/data.py +++ b/python/mlc_llm/serve/data.py @@ -1,8 +1,9 @@ """Classes denoting multi-modality data used in MLC LLM serving""" from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple +import tvm import tvm._ffi from tvm.runtime import Object from tvm.runtime.ndarray import NDArray @@ -81,6 +82,60 @@ def image(self) -> NDArray: def __len__(self): return self.embed_size + @staticmethod + def from_url(url: str, config: Dict) -> "ImageData": # pylint: disable=too-many-locals + """Get the image from the given URL, process and return the image tensor as TVM NDArray.""" + + # pylint: disable=import-outside-toplevel, import-error + import base64 + from io import BytesIO + + import requests + from PIL import Image + from transformers import CLIPImageProcessor + + if url.startswith("data:image"): + # The image is encoded in base64 format + base64_image = url.split(",")[1] + image_data = base64.b64decode(base64_image) + image_tensor = Image.open(BytesIO(image_data)).convert("RGB") + elif url.startswith("http"): + response = requests.get(url, timeout=5) + image_tensor = Image.open(BytesIO(response.content)).convert("RGB") + else: + raise ValueError(f"Unsupported image URL format: {url}") + + image_input_size = ImageData.get_input_size(config) + image_embed_size = ImageData.get_embed_size(config) + + image_processor = CLIPImageProcessor( + size={"shortest_edge": image_input_size}, + crop_size={"height": image_input_size, "width": image_input_size}, + ) + quantization = config["quantization"] + out_dtype = "float16" if "f16" in quantization else "float32" + image_features = tvm.nd.array( + image_processor.preprocess(image_tensor, return_tensors="np")["pixel_values"].astype( + out_dtype + ) + ) + image_data = ImageData(image_features, image_embed_size) + return image_data + + @staticmethod + def get_embed_size(config: Dict) -> int: + """Get the image embedding size from the model config file.""" + image_size = config["model_config"]["vision_config"]["image_size"] + patch_size = config["model_config"]["vision_config"]["patch_size"] + embed_size = (image_size // patch_size) ** 2 + return embed_size + + @staticmethod + def get_input_size(config: Dict) -> int: + """Get the image input size from the model config file.""" + image_size = config["model_config"]["vision_config"]["image_size"] + return image_size + @dataclass class SingleRequestStreamOutput: diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py new file mode 100644 index 0000000000..4c95f6e612 --- /dev/null +++ b/python/mlc_llm/serve/engine_base.py @@ -0,0 +1,1414 @@ +"""The MLC LLM Serving engine base class.""" + +# pylint: disable=too-many-lines + +import ast +import asyncio +import json +import queue +import subprocess +import sys +import threading +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import tvm +from tvm.runtime import Device + +from mlc_llm.chat_module import _get_chat_config, _get_lib_module_path, _get_model_path +from mlc_llm.protocol import openai_api_protocol, protocol_utils +from mlc_llm.protocol.conversation_protocol import Conversation +from mlc_llm.serve import data, engine_utils +from mlc_llm.serve.config import EngineConfig, GenerationConfig, SpeculativeMode +from 
mlc_llm.serve.event_trace_recorder import EventTraceRecorder +from mlc_llm.streamer import TextStreamer +from mlc_llm.support import logging +from mlc_llm.support.auto_device import detect_device +from mlc_llm.support.style import green +from mlc_llm.tokenizer import Tokenizer + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +@dataclass +class ModelInfo: + """The model info dataclass. + + Parameters + ---------- + model : str + The identifier of the input model. + It may be a compiled model's id (e.g., "Llama-2-7b-chat-hf-q4f16_1"), + or a full path to a model directory + (e.g., "dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1") + + model_lib_path : Optional[str] + The path to the compiled library of the model. + E.g., "dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + """ + + model: str + model_lib_path: Optional[str] = None + + +def _parse_models( + model: str, model_lib_path: Optional[str], additional_models: Optional[List[str]] +) -> List[ModelInfo]: + """Parse the specified model paths and model lib paths. + Return a list of ModelInfo, which is a wrapper class of the model path + lib path. + + Each additional model is expected to follow the format of either + "{MODEL_PATH}" or "{MODEL_PATH}:{MODEL_LIB_PATH}". + """ + models = [ModelInfo(model, model_lib_path)] + if additional_models is not None: + for additional_model in additional_models: + splits = additional_model.split(":", maxsplit=1) + if len(splits) == 2: + models.append(ModelInfo(splits[0], splits[1])) + else: + models.append(ModelInfo(splits[0])) + return models + + +def _process_model_args( + models: List[ModelInfo], device: tvm.runtime.Device +) -> Tuple[List[Tuple[str, str]], List[str], Conversation]: + """Process the input ModelInfo to get the engine initialization arguments.""" + conversation: Optional[Conversation] = None + config_file_paths: List[str] = [] + + def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: + nonlocal conversation + + model_path, config_file_path = _get_model_path(model.model) + config_file_paths.append(config_file_path) + chat_config = _get_chat_config(config_file_path, user_chat_config=None) + if conversation is None: + assert isinstance(chat_config.conv_template, Conversation) + conversation = chat_config.conv_template + # Try look up model library, and do JIT compile if model library not found. + try: + model_lib_path = _get_lib_module_path( + model=model.model, + model_path=model_path, + chat_config=chat_config, + model_lib_path=model.model_lib_path, + device_name=device.MASK2STR[device.device_type], + config_file_path=config_file_path, + ) + except FileNotFoundError: + from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel + + model_lib_path = str( + jit.jit( + model_path=Path(model_path), + chat_config=asdict(chat_config), + device=device, + ) + ) + return model_path, model_lib_path + + model_args: List[Tuple[str, str]] = [_convert_model_info(model) for model in models] + + assert conversation is not None + return model_args, config_file_paths, conversation + + +def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-many-locals,too-many-arguments + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_paths: List[str], + model_config_dicts: List[Dict[str, Any]], + max_num_sequence: int, + gpu_memory_utilization: Optional[float], +) -> Tuple[float, float, float, float, float, int]: + """Estimate the memory usage and the max total sequence length (capacity) + that the KV cache can support. 
+ """ + assert len(models) != 0 + + kv_bytes_per_token = 0 + kv_aux_workspace_bytes = 0 + model_workspace_bytes = 0 + logit_processor_workspace_bytes = 0 + params_bytes = 0 + temp_func_bytes = 0 + + for model, model_config_path, model_config_dict in zip( + models, model_config_paths, model_config_dicts + ): + # Read metadata for the parameter size and the temporary memory size. + cmd = [ + sys.executable, + "-m", + "mlc_llm.cli.model_metadata", + model.model_lib_path, + "--print-memory-usage-in-json", + "--mlc-chat-config", + model_config_path, + ] + usage_str = subprocess.check_output(cmd, universal_newlines=True) + usage_json = json.loads(usage_str) + params_bytes += usage_json["params_bytes"] + temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) + + cmd = [ + sys.executable, + "-m", + "mlc_llm.cli.model_metadata", + model.model_lib_path, + "--print-kv-cache-metadata-in-json", + ] + kv_cache_metadata_str = subprocess.check_output(cmd, universal_newlines=True) + kv_cache_metadata = json.loads(kv_cache_metadata_str) + + # Read model config and compute the kv size per token. + model_config = model_config_dict["model_config"] + vocab_size = model_config["vocab_size"] + prefill_chunk_size = model_config["prefill_chunk_size"] + num_layers = kv_cache_metadata["num_hidden_layers"] + head_dim = kv_cache_metadata["head_dim"] + num_qo_heads = kv_cache_metadata["num_attention_heads"] + num_kv_heads = kv_cache_metadata["num_key_value_heads"] + hidden_size = head_dim * num_qo_heads + kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25 + kv_aux_workspace_bytes += ( + (max_num_sequence + 1) * 88 + + prefill_chunk_size * (num_qo_heads + 1) * 8 + + prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 + + 48 * 1024 * 1024 + ) + model_workspace_bytes += ( + prefill_chunk_size * 4 + + max_num_sequence * 4 + + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 + ) + logit_processor_workspace_bytes += ( + max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 + ) + + # Get single-card GPU size. + gpu_size_bytes = device.total_global_memory + if gpu_size_bytes is None: + raise ValueError("Cannot read total GPU global memory from device.") + if gpu_memory_utilization is None: + gpu_memory_utilization = 0.90 + + model_max_total_sequence_length = int( + ( + int(gpu_size_bytes) * gpu_memory_utilization + - params_bytes + - temp_func_bytes + - kv_aux_workspace_bytes + - model_workspace_bytes + - logit_processor_workspace_bytes + ) + / kv_bytes_per_token + ) + if model_max_total_sequence_length <= 0: + raise ValueError( + f"The model weight size {params_bytes} may be larger than available GPU memory " + f"size {gpu_size_bytes * gpu_memory_utilization} bytes." + ) + + if device.device_type == Device.kDLMetal: + # NOTE: Metal runtime has severe performance issues with large buffers. + # To work around the issue, we limit the KV cache capacity to 32768. 
+ model_max_total_sequence_length = min(model_max_total_sequence_length, 32768) + + total_mem_usage_except_kv_cache = ( + params_bytes + + temp_func_bytes + + kv_aux_workspace_bytes + + model_workspace_bytes + + logit_processor_workspace_bytes + ) + return ( + total_mem_usage_except_kv_cache, + params_bytes, + kv_bytes_per_token, + kv_aux_workspace_bytes, + model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes, + int(model_max_total_sequence_length), + ) + + +def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[int, int, int]: + """Read the model config dictionaries, and return the maximum single + sequence length the models can support, the maximum prefill chunk + size the models can support, and the max batch size the models can support. + + Returns + ------- + model_max_single_sequence_length : int + The maximum single sequence length the models can support. + model_max_prefill_chunk_size : int + The maximum prefill chunk size the models can support. + model_max_batch_size : int + The max batch size the models can support. + """ + model_max_single_sequence_length = int(1e9) + model_max_prefill_chunk_size = int(1e9) + model_max_batch_size = int(1e9) + for i, config in enumerate(model_config_dicts): + runtime_context_window_size = config["context_window_size"] + compile_time_context_window_size = config["model_config"]["context_window_size"] + if runtime_context_window_size > compile_time_context_window_size: + raise ValueError( + f"Model {i}'s runtime context window size ({runtime_context_window_size}) is " + "larger than the context window size used at compile time " + f"({compile_time_context_window_size})" + ) + if runtime_context_window_size == -1 and compile_time_context_window_size != -1: + raise ValueError( + f"Model {i}'s runtime context window size (infinite) is " + "larger than the context window size used at compile time " + f"({compile_time_context_window_size})" + ) + if runtime_context_window_size != -1: + model_max_single_sequence_length = min( + model_max_single_sequence_length, runtime_context_window_size + ) + + runtime_prefill_chunk_size = config["prefill_chunk_size"] + compile_time_prefill_chunk_size = config["model_config"]["prefill_chunk_size"] + if runtime_prefill_chunk_size > compile_time_prefill_chunk_size: + raise ValueError( + f"Model {i}'s runtime prefill chunk size ({runtime_prefill_chunk_size}) is " + "larger than the prefill chunk size used at compile time " + f"({compile_time_prefill_chunk_size})" + ) + model_max_prefill_chunk_size = min(model_max_prefill_chunk_size, runtime_prefill_chunk_size) + + model_max_batch_size = min(model_max_batch_size, config["model_config"]["max_batch_size"]) + + assert model_max_prefill_chunk_size != int(1e9) + assert model_max_batch_size != int(1e9) + return model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size + + +def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + mode: Literal["local", "interactive", "server"], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + gpu_memory_utilization: Optional[float], + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_dicts: List[Dict[str, Any]], + model_config_paths: List[str], +) -> Tuple[int, int, int, int]: + """Initialize the KV cache config with user input and GPU memory usage estimation. 
+ The returned four integers are: + - max_batch_size + - max_total_sequence_length + - prefill_chunk_size + - model_max_single_sequence_length + """ + ( + model_max_single_sequence_length, + model_max_prefill_chunk_size, + model_max_batch_size, + ) = _get_model_config_limit(model_config_dicts) + + def infer_args_under_mode( + mode: Literal["local", "interactive", "server"], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + ) -> Tuple[Tuple[int, int, int], List[float]]: + logging_msg = "" + # - max_batch_size + if max_batch_size is None: + max_batch_size = ( + min(4, model_max_batch_size) + if mode == "local" + else (1 if mode == "interactive" else model_max_batch_size) + ) + logging_msg += f"max batch size is set to {max_batch_size}, " + else: + logging_msg += f"max batch size {max_batch_size} is specified by user, " + # - infer the maximum total sequence length that can fit GPU memory. + ( + total_mem_usage_except_kv_cache, + model_params_bytes, + kv_bytes_per_token, + kv_aux_workspace_bytes, + temp_workspace_bytes, + model_max_total_sequence_length, + ) = _estimate_mem_usage_and_max_total_sequence_length( + models, + device, + model_config_paths, + model_config_dicts, + max_batch_size, + gpu_memory_utilization, + ) + # - max_total_sequence_length + if max_total_sequence_length is None: + if mode == "local": + max_total_sequence_length = min( + model_max_total_sequence_length, model_max_single_sequence_length, 8192 + ) + elif mode == "interactive": + max_total_sequence_length = min( + model_max_total_sequence_length, model_max_single_sequence_length + ) + else: + max_total_sequence_length = min( + model_max_total_sequence_length, + max_batch_size * model_max_single_sequence_length, + ) + logging_msg += f"max KV cache token capacity is set to {max_total_sequence_length}, " + else: + logging_msg += ( + f"max KV cache token capacity {max_total_sequence_length} is specified by user. " + ) + # - prefill_chunk_size + if prefill_chunk_size is None: + if mode in ["local", "interactive"]: + prefill_chunk_size = min( + model_max_prefill_chunk_size, + model_max_total_sequence_length, + model_max_single_sequence_length, + ) + else: + prefill_chunk_size = model_max_prefill_chunk_size + logging_msg += f"prefill chunk size is set to {prefill_chunk_size}. " + else: + logging_msg += f"prefill chunk size {prefill_chunk_size} is specified by user. " + + if mode == "local": + logging_msg += ( + "We choose small max batch size and KV cache capacity to use less GPU memory." + ) + elif mode == "interactive": + logging_msg += "We fix max batch size to 1 for interactive single sequence use." + else: + logging_msg += ( + "We use as much GPU memory as possible (within the" + " limit of gpu_memory_utilization)." + ) + logger.info('Under mode "%s", %s', mode, logging_msg) + + # - Construct the KV cache config + # - Estimate total GPU memory usage on single GPU. + return (max_batch_size, max_total_sequence_length, prefill_chunk_size), [ + total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token, + model_params_bytes, + kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes, + temp_workspace_bytes, + ] + + # - Infer KV cache config and estimate memory usage for each mode. 
+ local_kv_cache_config, local_mem_usage_list = infer_args_under_mode( + "local", max_batch_size, max_total_sequence_length, prefill_chunk_size + ) + interactive_kv_cache_config, interactive_mem_usage_list = infer_args_under_mode( + "interactive", max_batch_size, max_total_sequence_length, prefill_chunk_size + ) + server_kv_cache_config, server_mem_usage_list = infer_args_under_mode( + "server", max_batch_size, max_total_sequence_length, prefill_chunk_size + ) + + # - Select the config based on the actual mode. + if mode == "local": + kv_cache_config = local_kv_cache_config + mem_usage_list = local_mem_usage_list + elif mode == "interactive": + kv_cache_config = interactive_kv_cache_config + mem_usage_list = interactive_mem_usage_list + else: + kv_cache_config = server_kv_cache_config + mem_usage_list = server_mem_usage_list + + logger.info( + 'The actual engine mode is "%s". So max batch size is %s, ' + "max KV cache token capacity is %s, prefill chunk size is %s.", + green(mode), + green(str(kv_cache_config[0])), + green(str(kv_cache_config[1])), + green(str(kv_cache_config[2])), + ) + + logger.info( + "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. Temporary buffer: %.2f MB). " + "The actual usage might be slightly larger than the estimated number.", + green("Estimated total single GPU memory usage"), + *list(mem_usage / 1024 / 1024 for mem_usage in mem_usage_list), + ) + # - Final messages + override_msg = "Please override the arguments if you have particular values to set." + if mode in ["local", "interactive"]: + logger.info( + 'Please switch to mode "server" if you want to use more GPU memory ' + "and support more concurrent requests. %s", + override_msg, + ) + else: + logger.info( + 'Please switch to mode "local" or "interactive" if you want to use less GPU memory ' + "or do not have many concurrent requests to process. %s", + override_msg, + ) + + return *kv_cache_config, model_max_single_sequence_length + + +@dataclass +class CallbackStreamOutput: + """The output of LLMEngine._generate and AsyncLLMEngine._generate + + Attributes + ---------- + delta_text : str + The delta text generated since the last output. + + num_delta_tokens : int + The number of delta tokens generated since the last output. + + delta_logprob_json_strs : Optional[List[str]] + The list of logprob JSON strings since the last output, + or None if the request does not require logprobs. + + finish_reason : Optional[str] + The finish reason of the request, or None if unfinished. + """ + + delta_text: str + num_delta_tokens: int + delta_logprob_json_strs: Optional[List[str]] + finish_reason: Optional[str] + + +class AsyncRequestStream: + """The asynchronous stream for requests in AsyncLLMEngine. + + Each request has its own unique stream. + The stream exposes the method `push` for engine to push new generated + delta text to the stream, and the method `finish` for engine to mark + the finish of generation. + + The stream implements `__aiter__` and `__anext__`, which the engine + can use to iterates all the generated tokens in order asynchronously. + """ + + # The asynchronous queue to hold elements of either a list of + # CallbackStreamOutput or an exception. + if sys.version_info >= (3, 9): + _queue: asyncio.Queue[ # pylint: disable=unsubscriptable-object + Union[List[CallbackStreamOutput], Exception] + ] + else: + _queue: asyncio.Queue + # The finish flag. 
+ _finished: bool + + def __init__(self) -> None: + self._queue = asyncio.Queue() + self._finished = False + + def push(self, item_or_exception: Union[List[CallbackStreamOutput], Exception]) -> None: + """Push a new token to the stream.""" + if self._finished: + # No new item is expected after finish. + self._queue.put_nowait( + RuntimeError( + "The request has already finished. " + "The stream is not supposed to accept new items." + ) + ) + return + self._queue.put_nowait(item_or_exception) + + def finish(self) -> None: + """Mark the finish of the generation in the stream.""" + self._queue.put_nowait(StopIteration()) + self._finished = True + + def __aiter__(self): + return self + + async def __anext__(self) -> List[CallbackStreamOutput]: + result = await self._queue.get() + if isinstance(result, StopIteration): + raise StopAsyncIteration + if isinstance(result, Exception): + raise result + return result + + +class EngineState: + """The engine states that the request stream callback function may use. + + This class is used for both AsyncLLMEngine and LLMEngine. + AsyncLLMEngine uses the fields and methods starting with "async", + and LLMEngine uses the ones starting with "sync". + + - For AsyncLLMEngine, the state contains an asynchronous event loop, + the streamers and the number of unfinished generations for each request + being processed. + - For LLMEngine, the state contains a callback output blocking queue, + the text streamers and the number of unfinished requests. + + We use this state class to avoid the callback function from capturing + the AsyncLLMEngine. + + The state also optionally maintains an event trace recorder, which can + provide Chrome tracing when enabled. + """ + + trace_recorder = None + # States used for AsyncLLMEngine + async_event_loop: Optional[asyncio.AbstractEventLoop] = None + async_streamers: Dict[str, Tuple[AsyncRequestStream, List[TextStreamer]]] = {} + async_num_unfinished_generations: Dict[str, int] = {} + # States used for LLMEngine + sync_output_queue: queue.Queue = queue.Queue() + sync_text_streamers: List[TextStreamer] = [] + sync_num_unfinished_generations: int = 0 + + def __init__(self, enable_tracing: bool) -> None: + """Constructor.""" + if enable_tracing: + self.trace_recorder = EventTraceRecorder() + + def record_event(self, request_id: str, event: str) -> None: + """Record a event for the the input request in the trace + recorder when the recorder exists. + + Parameters + ---------- + request_id : str + The subject request of the event. + + event : str + The event in a string name. + It can have one of the following patterns: + - "start xxx", which marks the start of event "xxx", + - "finish xxx", which marks the finish of event "xxx", + - "yyy", which marks the instant event "yyy". + The "starts" and "finishes" will be automatically paired in the trace recorder. + """ + if self.trace_recorder is None: + return + self.trace_recorder.add_event(request_id, event) + + def get_request_stream_callback( + self, kind: Literal["async", "sync"] + ) -> Callable[[List[data.RequestStreamOutput]], None]: + """Construct a callback function and return. + + The callback function has signature + "Callable[[List[data.RequestStreamOutput]], None]", + whose input is a list of "data.RequestStreamOutput". + Each "data.RequestStreamOutput" is the delta output of a request, + generated from the engine. 
+ """ + + f_callback = ( + self._async_request_stream_callback + if kind == "async" + else self._sync_request_stream_callback + ) + + def _callback(delta_outputs: List[data.RequestStreamOutput]) -> None: + f_callback(delta_outputs) + + return _callback + + def async_lazy_init_event_loop(self) -> None: + """Lazily set the asyncio event loop so that the event + loop is the main driving event loop of the process. + """ + if self.async_event_loop is None: + self.async_event_loop = asyncio.get_event_loop() + + def _async_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: + """The request stream callback function for AsyncLLMEngine to stream back + the request generation results. + + Note + ---- + This callback function uses `call_soon_threadsafe` in asyncio to + schedule the invocation in the event loop, so that the underlying + callback logic will be executed asynchronously in the future rather + than right now. + """ + + # Schedule a callback run in the event loop without executing right now. + # NOTE: This function causes GIL during execution. + self.async_event_loop.call_soon_threadsafe( + self._async_request_stream_callback_impl, delta_outputs + ) + + def _async_request_stream_callback_impl( + self, delta_outputs: List[data.RequestStreamOutput] + ) -> None: + """The underlying implementation of request stream callback for AsyncLLMEngine.""" + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + streamers = self.async_streamers.get(request_id, None) + if streamers is None: + continue + + self.record_event(request_id, event="start callback") + stream, text_streamers = streamers + outputs = [] + for stream_output, text_streamer in zip(stream_outputs, text_streamers): + self.record_event(request_id, event="start detokenization") + delta_text = ( + text_streamer.put(stream_output.delta_token_ids) + if len(stream_output.delta_token_ids) > 0 + else "" + ) + if stream_output.finish_reason is not None: + delta_text += text_streamer.finish() + self.record_event(request_id, event="finish detokenization") + + outputs.append( + CallbackStreamOutput( + delta_text=delta_text, + num_delta_tokens=len(stream_output.delta_token_ids), + delta_logprob_json_strs=stream_output.delta_logprob_json_strs, + finish_reason=stream_output.finish_reason, + ) + ) + if stream_output.finish_reason is not None: + self.async_num_unfinished_generations[request_id] -= 1 + + # Push new delta text to the stream. + stream.push(outputs) + if self.async_num_unfinished_generations[request_id] == 0: + stream.finish() + self.async_streamers.pop(request_id, None) + self.async_num_unfinished_generations.pop(request_id, None) + self.record_event(request_id, event="finish callback") + + def _sync_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: + """The request stream callback function for LLMEngine to stream back + the request generation results. + """ + # Put the delta outputs to the queue in the unblocking way. + self.sync_output_queue.put_nowait(delta_outputs) + + +class LLMEngineBase: # pylint: disable=too-many-instance-attributes,too-few-public-methods + """The base engine class, which implements common functions that + are shared by LLMEngine and AsyncLLMEngine. + + This class wraps a threaded engine that runs on a standalone + thread inside and streams back the delta generated results via + callback functions. The internal threaded engine keeps running an + loop that drives the engine. 
+ + LLMEngine and AsyncLLMEngine inherits this LLMEngineBase class, and implements + their own methods to process the delta generated results received + from callback functions and yield the processed delta results in + the forms of standard API protocols. + + Checkout subclasses AsyncLLMEngine/LLMEngine for the docstring of constructor parameters. + """ + + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + kind: Literal["async", "sync"], + model: str, + device: Union[str, tvm.runtime.Device], + model_lib_path: Optional[str], + mode: Literal["local", "interactive", "server"], + additional_models: Optional[List[str]], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + gpu_memory_utilization: Optional[float], + speculative_mode: SpeculativeMode, + spec_draft_length: int, + enable_tracing: bool, + ) -> None: + # - Initialize model loading info. + models = _parse_models(model, model_lib_path, additional_models) + if isinstance(device, str): + device = detect_device(device) + assert isinstance(device, Device) + ( + model_args, + model_config_paths, + self.conv_template, + ) = _process_model_args(models, device) + + # - Load the raw model config into dict + self.model_config_dicts = [] + for i, model_info in enumerate(models): + model_info.model_lib_path = model_args[i][1] + with open(model_config_paths[i], "r", encoding="utf-8") as file: + self.model_config_dicts.append(json.load(file)) + + # - Decide the KV cache config based on mode and user input. + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + ) = _infer_kv_cache_config( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + gpu_memory_utilization, + models, + device, + self.model_config_dicts, + model_config_paths, + ) + self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + + # - Initialize engine state and engine. + self.state = EngineState(enable_tracing) + module = tvm.get_global_func("mlc.serve.create_threaded_engine", allow_missing=False)() + self._ffi = { + key: module[key] + for key in [ + "add_request", + "abort_request", + "run_background_loop", + "run_background_stream_back_loop", + "init_background_engine", + "exit_background_loop", + "debug_call_func_on_all_worker", + ] + } + self.tokenizer = Tokenizer(model_args[0][0]) + + def _background_loop(): + self._ffi["init_background_engine"]( + EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + device=device, + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ), + self.state.get_request_stream_callback(kind), + self.state.trace_recorder, + ) + self._ffi["run_background_loop"]() + + def _background_stream_back_loop(): + self._ffi["run_background_stream_back_loop"]() + + # - Create the background engine-driving thread and start the loop. 
+ self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) + self._background_stream_back_loop_thread: threading.Thread = threading.Thread( + target=_background_stream_back_loop + ) + self._background_loop_thread.start() + self._background_stream_back_loop_thread.start() + self._terminated = False + + def terminate(self): + """Terminate the engine.""" + self._terminated = True + self._ffi["exit_background_loop"]() + self._background_loop_thread.join() + self._background_stream_back_loop_thread.join() + + def _debug_call_func_on_all_worker(self, func_name: str) -> None: + """Call the given global function on all workers. Only for debug purpose.""" + self._ffi["debug_call_func_on_all_worker"](func_name) + + +def process_chat_completion_request( # pylint: disable=too-many-arguments + request: openai_api_protocol.ChatCompletionRequest, + request_id: str, + engine_state: EngineState, + model_config: Dict[str, Any], + f_tokenize: Callable[[str], List[int]], + max_input_sequence_length: int, + conv_template: Conversation, +) -> Tuple[List[Union[List[int], data.Data]], GenerationConfig, bool, int]: + """Process the given ChatCompletionRequest, apply request validity + checks, and return the processed prompts, and other info. + + Parameters + ---------- + request : openai_api_protocol.ChatCompletionRequest + The request to be processed and checked. + + request_id : str + The id of the request. + + engine_state : EngineState + The state of the engine. + + model_config : Dict[str, Any] + The model configuration dictionary. + + f_tokenize : Callable[[str], List[int]] + The tokenizer encode function. + + max_input_sequence_length : int + The maximum allowed total prompt length. + + conv_template : Conversation + The conversation template of the model. + + Returns + ------- + prompts : List[Union[List[int], data.Data]] + The prompts, in a list. + Each element is a list of token ids or a "data.Data" instance. + + generation_cfg : GenerationConfig + The generation config of the request got from the input request. + + use_function_calling : bool + A boolean flag indicating if the request uses function call. + + prompt_length : int + The total prompt length. + """ + engine_state.record_event(request_id, event="receive request") + # - Check if unsupported arguments are specified. + engine_utils.check_unsupported_fields(request) + + # - Process messages and update the conversation template in three steps: + # i. Check the message validity. + # ii. Add the input messages to the conversation template. + # iii. Add the additional message for the assistant. + request.check_message_validity() + # - Check for function calling usage and update the conversation template + request.check_function_call_usage(conv_template) + + for message in request.messages: + role = message.role + content = message.content + if role == "system": + assert isinstance(content, str) + conv_template.system_message = content if content is not None else "" + continue + assert role != "tool", "Internal error: tool role." + conv_template.messages.append((role, content)) + conv_template.messages.append(("assistant", None)) + + # - Get the prompt from template, and encode to token ids. 
+ # - Check prompt length + engine_state.record_event(request_id, event="start tokenization") + prompts = engine_utils.process_prompts( # type: ignore + conv_template.as_prompt(model_config), f_tokenize + ) + engine_state.record_event(request_id, event="finish tokenization") + + if conv_template.system_prefix_token_ids is not None: + if isinstance(prompts[0], list): + prompts[0] = conv_template.system_prefix_token_ids + prompts[0] + else: + prompts.insert(0, conv_template.system_prefix_token_ids) + prompt_length = engine_utils.check_and_get_prompts_length(prompts, max_input_sequence_length) + + # Process generation config. Create request id. + generation_cfg = protocol_utils.get_generation_config( + request, + extra_stop_token_ids=conv_template.stop_token_ids, + extra_stop_str=conv_template.stop_str, + ) + return prompts, generation_cfg, conv_template.use_function_calling, prompt_length + + +def process_chat_completion_stream_output( # pylint: disable=too-many-arguments + delta_outputs: List[CallbackStreamOutput], + request_id: str, + engine_state: EngineState, + model: str, + generation_cfg: GenerationConfig, + use_function_calling: bool, + prompt_length: int, + finish_reasons: List[Optional[str]], + num_completion_tokens: int, +) -> Tuple[Optional[openai_api_protocol.ChatCompletionStreamResponse], int]: + """Process the delta outputs of a single request of ChatCompletion, + convert the delta output to ChatCompletionStreamResponse and return. + + Parameters + ---------- + delta_outputs : List[CallbackStreamOutput] + The delta outputs of a request. + The list length is the number of parallel generation specified by "n". + Each element corresponds to a generation. + + request_id : str + The id of the request. + + engine_state : EngineState + The state of the engine. + + model : str + The requested model. + + generation_cfg : GenerationConfig + The generation config of the request. + + use_function_calling : bool + A boolean flag indicating if the request uses function call. + + prompt_length : int + The total prompt length. + + finish_reasons : List[Optional[str]] + The list of finish reasons of each generation. + The list length is the number of parallel generation specified by "n". + This list is updated in place. + + num_completion_tokens : int + The number of total completion tokens so far. + + Returns + ------- + response : Optional[openai_api_protocol.ChatCompletionStreamResponse] + The converted OpenAI API ChatCompletionStreamResponse instance. + It can be none when there is no content. + + num_completion_tokens : int + The updated number of total completion tokens. + It is sum of the input number and the number of new completion tokens + from the given delta outputs. + """ + assert len(delta_outputs) == generation_cfg.n + choices = [] + num_new_completion_tokens = 0 + for i, delta_output in enumerate(delta_outputs): + finish_reason_updated = False + num_new_completion_tokens += delta_output.num_delta_tokens + if delta_output.finish_reason is not None and finish_reasons[i] is None: + finish_reasons[i] = ( + delta_output.finish_reason if not use_function_calling else "tool_calls" + ) + finish_reason_updated = True + if not finish_reason_updated and delta_output.delta_text == "": + # Ignore empty delta text when finish reason is not updated. 
+ engine_state.record_event(request_id, event="skip empty delta text") + continue + + choices.append( + openai_api_protocol.ChatCompletionStreamResponseChoice( + index=i, + finish_reason=finish_reasons[i], + delta=openai_api_protocol.ChatCompletionMessage( + content=delta_output.delta_text, role="assistant" + ), + logprobs=( + openai_api_protocol.LogProbs( + content=[ + openai_api_protocol.LogProbsContent.model_validate_json( + logprob_json_str + ) + for logprob_json_str in delta_output.delta_logprob_json_strs + ] + ) + if delta_output.delta_logprob_json_strs is not None + else None + ), + ) + ) + + if len(choices) == 0 and num_new_completion_tokens == 0: + # Skip return when there is no delta output and no number of completion tokens. + return None, num_completion_tokens + num_completion_tokens += num_new_completion_tokens + response = openai_api_protocol.ChatCompletionStreamResponse( + id=request_id, + choices=choices, + model=model, + system_fingerprint="", + usage=openai_api_protocol.UsageInfo( + prompt_tokens=prompt_length, + completion_tokens=num_completion_tokens, + ), + ) + engine_state.record_event(request_id, event="yield delta output") + return response, num_completion_tokens + + +def process_completion_request( + request: openai_api_protocol.CompletionRequest, + request_id: str, + engine_state: EngineState, + tokenizer: Tokenizer, + max_input_sequence_length: int, +) -> Tuple[List[int], GenerationConfig, int, Optional[openai_api_protocol.CompletionResponse]]: + """Process the given CompletionRequest, apply request validity + checks, and return the processed prompts, and other info. + + Parameters + ---------- + request : openai_api_protocol.CompletionRequest + The request to be processed and checked. + + request_id : str + The id of the request. + + engine_state : EngineState + The state of the engine. + + tokenizer : Tokenizer + The tokenizer instance of the model. + + max_input_sequence_length : int + The maximum allowed total prompt length. + + Returns + ------- + prompt : List[int] + The prompt in a list of token ids. + + generation_cfg : GenerationConfig + The generation config of the request got from the input request. + + prompt_length : int + The total prompt length. + + echo_response : Optional[openai_api_protocol.CompletionResponse] + The CompletionResponse of the echoing part, when argument "echo" + of the input request is specified. + """ + engine_state.record_event(request_id, event="receive request") + # - Check if unsupported arguments are specified. + engine_utils.check_unsupported_fields(request) + + # - Process prompt and check validity. + engine_state.record_event(request_id, event="start tokenization") + prompts = engine_utils.process_prompts(request.prompt, tokenizer.encode) + engine_state.record_event(request_id, event="finish tokenization") + prompt_length = engine_utils.check_and_get_prompts_length(prompts, max_input_sequence_length) + prompt = prompts[0] + assert isinstance(prompt, list) + + # Process generation config. Create request id. + generation_cfg = protocol_utils.get_generation_config(request) + + # - Echo back the prompt. 
+ echo_response = None + if request.echo: + text = tokenizer.decode(prompt) + response = openai_api_protocol.CompletionResponse( + id=request_id, + choices=[ + openai_api_protocol.CompletionResponseChoice(index=i, text=text) + for i in range(generation_cfg.n) + ], + model=request.model, + usage=openai_api_protocol.UsageInfo( + prompt_tokens=prompt_length, + completion_tokens=0, + ), + ) + echo_response = response + return prompt, generation_cfg, prompt_length, echo_response + + +def process_completion_stream_output( # pylint: disable=too-many-arguments + delta_outputs: List[CallbackStreamOutput], + request_id: str, + engine_state: EngineState, + model: str, + generation_cfg: GenerationConfig, + prompt_length: int, + finish_reasons: List[Optional[str]], + num_completion_tokens: int, +) -> Tuple[Optional[openai_api_protocol.CompletionResponse], int]: + """Process the delta outputs of a single request of Completion, + convert the delta output to CompletionResponse and return. + + Parameters + ---------- + delta_outputs : List[CallbackStreamOutput] + The delta outputs of a request. + The list length is the number of parallel generation specified by "n". + Each element corresponds to a generation. + + request_id : str + The id of the request. + + engine_state : EngineState + The state of the engine. + + model : str + The requested model. + + generation_cfg : GenerationConfig + The generation config of the request. + + prompt_length : int + The total prompt length. + + finish_reasons : List[Optional[str]] + The list of finish reasons of each generation. + The list length is the number of parallel generation specified by "n". + This list is updated in place. + + num_completion_tokens : int + The number of total completion tokens so far. + + Returns + ------- + response : Optional[openai_api_protocol.CompletionResponse] + The converted OpenAI API CompletionResponse instance. + It can be none when there is no content. + + num_completion_tokens : int + The updated number of total completion tokens. + It is sum of the input number and the number of new completion tokens + from the given delta outputs. + """ + assert len(delta_outputs) == generation_cfg.n + choices = [] + num_new_completion_tokens = 0 + for i, delta_output in enumerate(delta_outputs): + finish_reason_updated = False + if delta_output.finish_reason is not None and finish_reasons[i] is None: + finish_reasons[i] = delta_output.finish_reason + finish_reason_updated = True + num_new_completion_tokens += delta_output.num_delta_tokens + if not finish_reason_updated and delta_output.delta_text == "": + # Ignore empty delta text when finish reason is not updated. + continue + + choices.append( + openai_api_protocol.CompletionResponseChoice( + index=i, + finish_reason=finish_reasons[i], + text=delta_output.delta_text, + logprobs=( + openai_api_protocol.LogProbs( + content=[ + openai_api_protocol.LogProbsContent.model_validate_json( + logprob_json_str + ) + for logprob_json_str in delta_output.delta_logprob_json_strs + ] + ) + if delta_output.delta_logprob_json_strs is not None + else None + ), + ) + ) + + if len(choices) == 0 and num_new_completion_tokens == 0: + # Skip return when there is no delta output and no number of completion tokens. 
+        return None, num_completion_tokens + num_completion_tokens += num_new_completion_tokens + response = openai_api_protocol.CompletionResponse( + id=request_id, + choices=choices, + model=model, + usage=openai_api_protocol.UsageInfo( + prompt_tokens=prompt_length, + completion_tokens=num_completion_tokens, + ), + ) + engine_state.record_event(request_id, event="yield delta output") + return response, num_completion_tokens + + +def create_completion_suffix_response( + request: openai_api_protocol.CompletionRequest, + request_id: str, + prompt_length: int, + finish_reasons: List[Optional[str]], + num_completion_tokens: int, +) -> Optional[openai_api_protocol.CompletionResponse]: + """Create the suffix response of a Completion request + when the request requires a suffix. + + Parameters + ---------- + request : openai_api_protocol.CompletionRequest + The request whose suffix response is to be created. + + request_id : str + The id of the request. + + prompt_length : int + The total prompt length. + + finish_reasons : List[Optional[str]] + The list of finish reasons of each generation. + The list length is the number of parallel generation specified by "n". + This list is updated in place. + + num_completion_tokens : int + The number of total completion tokens so far. + + Returns + ------- + suffix_response : Optional[openai_api_protocol.CompletionResponse] + The created OpenAI API CompletionResponse instance for the suffix. + Or None if the request does not require a suffix. + """ + # - Echo the suffix. + if request.suffix is None: + return None + assert all(finish_reason is not None for finish_reason in finish_reasons) + response = openai_api_protocol.CompletionResponse( + id=request_id, + choices=[ + openai_api_protocol.CompletionResponseChoice( + index=i, + finish_reason=finish_reason, + text=request.suffix, + ) + for i, finish_reason in enumerate(finish_reasons) + ], + model=request.model, + usage=openai_api_protocol.UsageInfo( + prompt_tokens=prompt_length, + completion_tokens=num_completion_tokens, + ), + ) + return response + + +def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]: + """Convert a function call string (or a list of them) to a list of JSON objects. + Entries that cannot be parsed as a function call become None.""" + + def parse_function_call(call_str: str): + node = ast.parse(call_str, mode="eval") + call_node = node.body + if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name): + name = call_node.func.id + arguments = {} + for keyword in call_node.keywords: + arguments[keyword.arg] = ast.literal_eval(keyword.value) + return {"name": name, "arguments": arguments} + return None + + if ( + stringified_calls[0] == "[" and stringified_calls[-1] == "]" + ): # hacky way to check if string list + calls = ast.literal_eval(stringified_calls) + else: + calls = [stringified_calls] + function_calls_json = [parse_function_call(call_str) for call_str in calls] + return function_calls_json + + +def process_function_call_output( + output_texts: List[str], finish_reasons: List[str] +) -> Tuple[bool, List[List[openai_api_protocol.ChatToolCall]]]: + """Process the potential function call results output by the model, + according to the finish reasons. + Return whether the output contains function calls, and the list of tool calls. 
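+
+    An illustrative example (the values below are hypothetical, not from this
+    patch): given output_texts ['get_weather(city="Paris")'] and finish_reasons
+    ["tool_calls"], the returned tool_calls_list would contain one ChatToolCall
+    whose function is named "get_weather" with arguments {"city": "Paris"}.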
+ """ + n = len(output_texts) + tool_calls_list: List[List[openai_api_protocol.ChatToolCall]] = [[] for _ in range(n)] + use_function_calling = any(finish_reason == "tool_calls" for finish_reason in finish_reasons) + if use_function_calling: + for i, output_text in enumerate(output_texts): + try: + fn_json_list = convert_function_str_to_json(output_text) + except (SyntaxError, ValueError): + output_text = "Got an invalid function call output from model" + finish_reasons[i] = "error" + else: + tool_calls_list[i] = [ + openai_api_protocol.ChatToolCall( + type="function", + function=openai_api_protocol.ChatFunctionCall( + name=fn_json_obj["name"], arguments=fn_json_obj["arguments"] + ), + ) + for fn_json_obj in fn_json_list + if fn_json_obj is not None + ] + if len(tool_calls_list[i]) == 0: + output_texts[i] = "Got an invalid function call output from model" + finish_reasons[i] = "error" + else: + finish_reasons[i] = "tool_calls" + return use_function_calling, tool_calls_list + + +def wrap_chat_completion_response( # pylint: disable=too-many-arguments + request_id: str, + model: str, + output_texts: List[str], + finish_reasons: List[str], + tool_calls_list: List[List[openai_api_protocol.ChatToolCall]], + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]], + use_function_calling: bool, + num_prompt_tokens: int, + num_completion_tokens: int, +) -> openai_api_protocol.ChatCompletionResponse: + """Wrap the non-streaming chat completion results to ChatCompletionResponse instance.""" + return openai_api_protocol.ChatCompletionResponse( + id=request_id, + choices=[ + openai_api_protocol.ChatCompletionResponseChoice( + index=i, + finish_reason=finish_reasons[i], + message=( + openai_api_protocol.ChatCompletionMessage(role="assistant", content=output_text) + if not use_function_calling or finish_reason == "error" + else openai_api_protocol.ChatCompletionMessage( + role="assistant", tool_calls=tool_calls + ) + ), + logprobs=( + openai_api_protocol.LogProbs(content=logprob_results[i]) + if logprob_results is not None + else None + ), + ) + for i, (output_text, finish_reason, tool_calls) in enumerate( + zip(output_texts, finish_reasons, tool_calls_list) + ) + ], + model=model, + system_fingerprint="", + usage=openai_api_protocol.UsageInfo( + prompt_tokens=num_prompt_tokens, completion_tokens=num_completion_tokens + ), + ) + + +def wrap_completion_response( # pylint: disable=too-many-arguments + request_id: str, + model: str, + output_texts: List[str], + finish_reasons: List[str], + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]], + num_prompt_tokens: int, + num_completion_tokens: int, +) -> openai_api_protocol.CompletionResponse: + """Wrap the non-streaming completion results to CompletionResponse instance.""" + return openai_api_protocol.CompletionResponse( + id=request_id, + choices=[ + openai_api_protocol.CompletionResponseChoice( + index=i, + finish_reason=finish_reason, + text=output_text, + logprobs=( + openai_api_protocol.LogProbs(content=logprob_results[i]) + if logprob_results is not None + else None + ), + ) + for i, (output_text, finish_reason) in enumerate(zip(output_texts, finish_reasons)) + ], + model=model, + usage=openai_api_protocol.UsageInfo( + prompt_tokens=num_prompt_tokens, completion_tokens=num_completion_tokens + ), + ) diff --git a/python/mlc_llm/serve/engine_utils.py b/python/mlc_llm/serve/engine_utils.py new file mode 100644 index 0000000000..d1c96e37d4 --- /dev/null +++ b/python/mlc_llm/serve/engine_utils.py @@ -0,0 +1,97 @@ 
+"""Utility functions for MLC Serve engine""" + +import uuid +from typing import Callable, List, Union + +from mlc_llm.serve import data + +from ..protocol import RequestProtocol, error_protocol, protocol_utils + + +def random_uuid() -> str: + """Generate a random id in hexadecimal string.""" + return uuid.uuid4().hex + + +def check_unsupported_fields(request: RequestProtocol) -> None: + """Check if the request has unsupported fields. Raise BadRequestError if so.""" + unsupported_fields = protocol_utils.get_unsupported_fields(request) + if len(unsupported_fields) != 0: + unsupported_fields = [f'"{field}"' for field in unsupported_fields] + raise error_protocol.BadRequestError( + f'Request fields {", ".join(unsupported_fields)} are not supported right now.', + ) + + +def check_and_get_prompts_length( + prompts: List[Union[List[int], data.ImageData]], max_input_sequence_length: int +) -> int: + """Check if the total prompt length exceeds the max single sequence + sequence length allowed by the served model. Raise BadRequestError if so. + Return the total prompt length. + """ + total_length: int = 0 + for prompt in prompts: + total_length += len(prompt) + if total_length > max_input_sequence_length: + raise error_protocol.BadRequestError( + f"Request prompt has {total_length} tokens in total," + f" larger than the model input length limit {max_input_sequence_length}.", + ) + return total_length + + +def process_prompts( + input_prompts: Union[str, List[int], List[Union[str, List[int], data.ImageData]]], + ftokenize: Callable[[str], List[int]], +) -> List[Union[List[int], data.ImageData]]: + """Convert all input tokens to list of token ids with regard to the + given tokenization function. + For each input prompt, return the list of token ids after tokenization. + """ + error_msg = f"Invalid request prompt {input_prompts}" + + # Case 1. The prompt is a single string. + if isinstance(input_prompts, str): + return [ftokenize(input_prompts)] + + assert isinstance(input_prompts, list) + if len(input_prompts) == 0: + raise error_protocol.BadRequestError(error_msg) + + # Case 2. The prompt is a list of token ids. + if isinstance(input_prompts[0], int): + assert isinstance(input_prompts, list) + if not all(isinstance(token_id, int) for token_id in input_prompts): + raise error_protocol.BadRequestError(error_msg) + return [input_prompts] # type: ignore + + # Case 3. A list of prompts. 
+ output_prompts: List[Union[List[int], data.ImageData]] = [] + for input_prompt in input_prompts: + if isinstance(input_prompt, str): + output_prompts.append(ftokenize(input_prompt)) + elif isinstance(input_prompt, list) and all( + isinstance(token_id, int) for token_id in input_prompt + ): + output_prompts.append(input_prompt) + elif isinstance(input_prompt, data.ImageData): + output_prompts.append(input_prompt) + else: + raise error_protocol.BadRequestError(error_msg) + return output_prompts + + +def convert_prompts_to_data( + prompts: Union[str, List[int], List[Union[str, List[int], data.Data]]] +) -> List[data.Data]: + """Convert the given prompts in the combination of token id lists + and/or data to all data.""" + if isinstance(prompts, data.Data): + return [prompts] + if isinstance(prompts, str): + return [data.TextData(prompts)] + if isinstance(prompts[0], int): + assert isinstance(prompts, list) and all(isinstance(token_id, int) for token_id in prompts) + return [data.TokenData(prompts)] # type: ignore + return [convert_prompts_to_data(x)[0] for x in prompts] # type: ignore diff --git a/python/mlc_llm/serve/grammar.py b/python/mlc_llm/serve/grammar.py index d640c62da2..d5ad862a42 100644 --- a/python/mlc_llm/serve/grammar.py +++ b/python/mlc_llm/serve/grammar.py @@ -137,11 +137,13 @@ def from_schema( separators : Optional[Tuple[str, str]] Two separators used in the schema: comma and colon. Examples: (",", ":"), (", ", ": "). If None, the default separators will be used: (",", ": ") when the indent is not None, - and (", ", ": ") otherwise. Default: None. + and (", ", ": ") otherwise. This follows the convention in json.dumps(). Default: None. strict_mode : bool Whether to use strict mode. In strict mode, the generated grammar will not allow - unevaluatedProperties and unevaluatedItems, i.e. these will be set to false by default. + properties and items that is not specified in the schema. This is equivalent to + setting unevaluatedProperties and unevaluatedItems to false. + This helps LLM to generate accurate output in the grammar-guided generation with JSON schema. Default: True. @@ -150,9 +152,8 @@ def from_schema( grammar : BNFGrammar The generated BNF grammar. """ - indent_converted = -1 if indent is None else indent return _ffi_api.BNFGrammarFromSchema( # type: ignore # pylint: disable=no-member - schema, indent_converted, separators, strict_mode + schema, indent, separators, strict_mode ) @staticmethod @@ -166,6 +167,47 @@ def get_grammar_of_json() -> "BNFGrammar": """ return _ffi_api.BNFGrammarGetGrammarOfJSON() # type: ignore # pylint: disable=no-member + @staticmethod + def debug_json_schema_to_ebnf( + schema: str, + *, + indent: Optional[int] = None, + separators: Optional[Tuple[str, str]] = None, + strict_mode: bool = True + ) -> str: + """Convert JSON schema string to EBNF grammar string. For test purposes. + + Parameters + ---------- + json_schema : str + The JSON schema string. + + indent : Optional[int] + The number of spaces for indentation. If None, the output will be in one line. + Default: None. + + separators : Optional[Tuple[str, str]] + Two separators used in the schema: comma and colon. Examples: (",", ":"), (", ", ": "). + If None, the default separators will be used: (",", ": ") when the indent is not None, + and (", ", ": ") otherwise. This follows the convention in json.dumps(). Default: None. + + strict_mode : bool + Whether to use strict mode. In strict mode, the generated grammar will not allow + properties and items that is not specified in the schema. 
This is equivalent to + setting unevaluatedProperties and unevaluatedItems to false. + + This helps LLM to generate accurate output in the grammar-guided generation with JSON + schema. Default: True. + + Returns + ------- + ebnf_string : str + The EBNF grammar string. + """ + return _ffi_api.DebugJSONSchemaToEBNF( # type: ignore # pylint: disable=no-member + schema, indent, separators, strict_mode + ) + @tvm._ffi.register_object("mlc.serve.GrammarStateMatcher") # pylint: disable=protected-access class GrammarStateMatcher(Object): diff --git a/python/mlc_llm/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py index ed63f6ac51..1d17f8e66a 100644 --- a/python/mlc_llm/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -1,13 +1,17 @@ """The MLC LLM server launched in a subprocess.""" +import os import subprocess import sys import time from pathlib import Path -from typing import Optional +from typing import List, Literal, Optional, Union import psutil import requests +from tvm.runtime import Device + +from mlc_llm.serve.config import SpeculativeMode class PopenServer: # pylint: disable=too-many-instance-attributes @@ -17,11 +21,17 @@ class PopenServer: # pylint: disable=too-many-instance-attributes def __init__( # pylint: disable=too-many-arguments self, model: str, - model_lib_path: str, - device: str = "auto", + device: Union[str, Device] = "auto", *, - max_batch_size: int = 80, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + gpu_memory_utilization: Optional[float] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, enable_tracing: bool = False, host: str = "127.0.0.1", port: int = 8000, @@ -30,37 +40,62 @@ def __init__( # pylint: disable=too-many-arguments self.model = model self.model_lib_path = model_lib_path self.device = device + self.mode = mode + self.additional_models = additional_models self.max_batch_size = max_batch_size self.max_total_sequence_length = max_total_sequence_length + self.prefill_chunk_size = prefill_chunk_size + self.gpu_memory_utilization = gpu_memory_utilization + self.speculative_mode = speculative_mode + self.spec_draft_length = spec_draft_length self.enable_tracing = enable_tracing self.host = host self.port = port self._proc: Optional[subprocess.Popen] = None - def start(self) -> None: + def start(self) -> None: # pylint: disable=too-many-branches """Launch the server in a popen subprocess. Wait until the server becomes ready before return. 
""" cmd = [sys.executable] cmd += ["-m", "mlc_llm", "serve", self.model] - cmd += ["--model-lib-path", self.model_lib_path] + if self.model_lib_path is not None: + cmd += ["--model-lib-path", self.model_lib_path] cmd += ["--device", self.device] - cmd += ["--max-batch-size", str(self.max_batch_size)] + if self.mode is not None: + cmd += ["--mode", self.mode] + if self.additional_models is not None: + cmd += ["--additional-models", *self.additional_models] + if self.max_batch_size is not None: + cmd += ["--max-batch-size", str(self.max_batch_size)] if self.max_total_sequence_length is not None: cmd += ["--max-total-seq-length", str(self.max_total_sequence_length)] + if self.prefill_chunk_size is not None: + cmd += ["--prefill-chunk-size", str(self.prefill_chunk_size)] + if self.speculative_mode != SpeculativeMode.DISABLE: + cmd += [ + "--speculative-mode", + self.speculative_mode.name, + "--spec-draft-length", + str(self.spec_draft_length), + ] + if self.gpu_memory_utilization is not None: + cmd += ["--gpu-memory-utilization", str(self.gpu_memory_utilization)] if self.enable_tracing: cmd += ["--enable-tracing"] cmd += ["--host", self.host] cmd += ["--port", str(self.port)] process_path = str(Path(__file__).resolve().parents[4]) - self._proc = subprocess.Popen(cmd, cwd=process_path) # pylint: disable=consider-using-with + self._proc = subprocess.Popen( # pylint: disable=consider-using-with + cmd, cwd=process_path, env=os.environ + ) # NOTE: DO NOT USE `stdout=subprocess.PIPE, stderr=subprocess.PIPE` # in subprocess.Popen here. PIPE has a fixed-size buffer with may block # and hang forever. # Try to query the server until it is ready. - openai_v1_models_url = "http://127.0.0.1:8000/v1/models" + openai_v1_models_url = f"http://{self.host}:{str(self.port)}/v1/models" query_result = None timeout = 60 attempts = 0.0 diff --git a/python/mlc_llm/serve/server/server_context.py b/python/mlc_llm/serve/server/server_context.py index baad7b5e7d..0a9a1b0b1f 100644 --- a/python/mlc_llm/serve/server/server_context.py +++ b/python/mlc_llm/serve/server/server_context.py @@ -1,12 +1,8 @@ """Server context that shared by multiple entrypoint files.""" -import json from typing import Dict, List, Optional -from ...chat_module import _get_model_path -from ...conversation_template import ConvTemplateRegistry -from ...protocol.conversation_protocol import Conversation -from .. 
import async_engine +from ..engine import AsyncLLMEngine class ServerContext: @@ -17,9 +13,7 @@ class ServerContext: server_context: Optional["ServerContext"] = None def __init__(self): - self._models: Dict[str, async_engine.AsyncThreadedEngine] = {} - self._conv_templates: Dict[str, Conversation] = {} - self._model_configs: Dict[str, Dict] = {} + self._models: Dict[str, AsyncLLMEngine] = {} def __enter__(self): if ServerContext.server_context is not None: @@ -31,46 +25,22 @@ def __exit__(self, exc_type, exc_value, traceback): for model_engine in self._models.values(): model_engine.terminate() self._models.clear() - self._conv_templates.clear() - self._model_configs.clear() @staticmethod def current(): """Returns the current ServerContext.""" return ServerContext.server_context - def add_model(self, hosted_model: str, engine: async_engine.AsyncThreadedEngine) -> None: + def add_model(self, hosted_model: str, engine: AsyncLLMEngine) -> None: """Add a new model to the server context together with the engine.""" if hosted_model in self._models: raise RuntimeError(f"Model {hosted_model} already running.") self._models[hosted_model] = engine - # Get the conversation template. - if engine.conv_template_name is not None: - conv_template = ConvTemplateRegistry.get_conv_template(engine.conv_template_name) - if conv_template is not None: - self._conv_templates[hosted_model] = conv_template - - _, config_file_path = _get_model_path(hosted_model) - with open(config_file_path, "r", encoding="utf-8") as file: - config = json.load(file) - self._model_configs[hosted_model] = config - - def get_engine(self, model: str) -> Optional[async_engine.AsyncThreadedEngine]: + def get_engine(self, model: str) -> Optional[AsyncLLMEngine]: """Get the async engine of the requested model.""" return self._models.get(model, None) - def get_conv_template(self, model: str) -> Optional[Conversation]: - """Get the conversation template of the requested model.""" - conv_template = self._conv_templates.get(model, None) - if conv_template is not None: - return conv_template.model_copy(deep=True) - return None - def get_model_list(self) -> List[str]: """Get the list of models on serve.""" return list(self._models.keys()) - - def get_model_config(self, model: str) -> Optional[Dict]: - """Get the model config path of the requested model.""" - return self._model_configs.get(model, None) diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py new file mode 100644 index 0000000000..23b151d5c7 --- /dev/null +++ b/python/mlc_llm/serve/sync_engine.py @@ -0,0 +1,360 @@ +"""The MLC LLM synchronized engine. + +NOTE: This engine defined in this file directly wraps the underlying +Engine implementation in C++, is not optimized by multi-threading and +does not offer standard OpenAI API interface. + +We do not expose it and use it by default. As of now it mainly serves +the test and debug purpose because of its simplicity. 
+""" + +import json +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union + +import tvm + +from mlc_llm.serve import data +from mlc_llm.serve.config import EngineConfig, GenerationConfig, SpeculativeMode +from mlc_llm.serve.engine_base import ( + _infer_kv_cache_config, + _parse_models, + _process_model_args, + detect_device, +) +from mlc_llm.serve.event_trace_recorder import EventTraceRecorder +from mlc_llm.serve.request import Request +from mlc_llm.streamer import TextStreamer +from mlc_llm.support import logging +from mlc_llm.tokenizer import Tokenizer + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +def _create_tvm_module( + creator: str, ffi_funcs: Sequence[str], creator_args: Optional[List[Any]] = None +) -> Dict[str, Callable]: + """Internal method to create a module.""" + if creator_args is None: + creator_args = [] + module = tvm.get_global_func(creator, allow_missing=False)(*creator_args) + return {key: module[key] for key in ffi_funcs} + + +class SyncLLMEngine: + """The Python interface of synchronize request serving engine for MLC LLM. + + The engine receives requests from the "add_request" method. For + an given request, the engine will keep generating new tokens for + the request until finish (under certain criterion). After finish, + the engine will return the generation result through the callback + function provided by the request. + + NOTE: This engine directly wraps the underlying Engine implementation + in C++, is not optimized by multi-threading and does not offer standard + OpenAI API interface. We do not expose it and use it by default. + As of now it mainly serves the test and debug purpose because of its + simplicity. + + Parameters + ---------- + models : Union[ModelInfo, List[ModelInfo]] + One or a list of model info (specifying which models to load and + which device to load to) to launch the engine. + + kv_cache_config : KVCacheConfig + The configuration of the paged KV cache. + + request_stream_callback : Optional[Callable[[str, data.TokenData, Optional[str]], None]] + The provided callback function to handle the generation + output. It has the signature of `(str, data.TokenData, bool) -> None`, + where + - the first string is the request id, + - the TokenData contains the generated **delta** token ids since + the last invocation of the callback on the specific request, + - the optional string value denotes the finish reason if the + generation of the request is finished, or None if it has not finished. + + The callback function is optional at construction, but it needs to + be set before the engine executing requests. This can be done via + the `set_request_stream_callback` method. Otherwise, the engine will raise + exception. + + engine_config : Optional[EngineConfig] + The Engine execution configuration. + + enable_tracing : bool + A boolean indicating if to enable event logging for requests. 
+ """ + + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + model: str, + device: Union[str, tvm.runtime.Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + gpu_memory_utilization: Optional[float] = None, + enable_tracing: bool = False, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, + request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None, + ): + # - Initialize model loading info. + models = _parse_models(model, model_lib_path, additional_models) + if isinstance(device, str): + device = detect_device(device) + assert isinstance(device, tvm.runtime.Device) + ( + model_args, + model_config_paths, + self.conv_template, + ) = _process_model_args(models, device) + + # - Load the raw model config into dict + self.model_config_dicts = [] + for i, model_info in enumerate(models): + model_info.model_lib_path = model_args[i][1] + with open(model_config_paths[i], "r", encoding="utf-8") as file: + self.model_config_dicts.append(json.load(file)) + + # - Decide the KV cache config based on mode and user input. + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + ) = _infer_kv_cache_config( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + gpu_memory_utilization, + models, + device, + self.model_config_dicts, + model_config_paths, + ) + self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + + self._ffi = _create_tvm_module( + "mlc.serve.create_engine", + ffi_funcs=[ + "init", + "add_request", + "abort_request", + "step", + "stats", + "reset", + "get_request_stream_callback", + "set_request_stream_callback", + ], + ) + self.trace_recorder = EventTraceRecorder() if enable_tracing else None + + self._ffi["init"]( + EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + device=device, + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ), + request_stream_callback, + self.trace_recorder, + ) + self.tokenizer = Tokenizer(model_args[0][0]) + + def generate( # pylint: disable=too-many-locals + self, + prompts: Union[str, List[str], List[int], List[List[int]], List[List[data.Data]]], + generation_config: Union[GenerationConfig, List[GenerationConfig]], + ) -> Tuple[List[List[str]], List[Optional[List[List[str]]]]]: + """Generate texts for a list of input prompts. + Each prompt can be a string or a list of token ids. + The generation for each prompt is independent. + Return the generation results, one for each prompt. + + Parameters + ---------- + prompts : Union[str, List[str], List[int], List[List[int]]] + One or a list of input prompts for text generation. + Each prompt can be a string or a list of token ids. + + generation_config : Union[GenerationConfig, List[GenerationConfig]] + The generation config for each requests. 
+ If the it is a single GenerationConfig instance, + this config will be shared by all the prompts. + Otherwise, one generation config is required for every + prompt. + + Returns + ------- + output_text : List[List[str]] + The text generation results, one list of strings for each input prompt. + The length of each list is the parallel generation `n` in + generation config. + + output_logprobs_str : List[Optional[List[List[str]]]] + The logprob strings of each token for each input prompt, or None + if an input prompt does not require logprobs. + """ + if isinstance(prompts, str): + # `prompts` is a single string. + prompts = [prompts] + else: + assert isinstance(prompts, list), ( + "Input `prompts` is expected to be a string, a list of " + "str, a list of token ids or multiple lists of token ids. " + ) + if len(prompts) == 0: + return [], [] + if isinstance(prompts[0], int): + # `prompts` is a list of token ids + prompts = [prompts] # type: ignore + + num_requests = len(prompts) + if not isinstance(generation_config, list): + generation_config = [generation_config] * num_requests + + assert ( + len(generation_config) == num_requests + ), "Number of generation config and number of prompts mismatch" + + num_finished_generations = 0 + output_texts: List[List[str]] = [] + output_logprobs_str: List[Optional[List[List[str]]]] = [] + text_streamers: List[List[TextStreamer]] = [] + for i in range(num_requests): + output_texts.append([]) + output_logprobs_str.append([] if generation_config[i].logprobs else None) + text_streamers.append([]) + for _ in range(generation_config[i].n): + output_texts[i].append("") + text_streamers[i].append(TextStreamer(self.tokenizer)) + if output_logprobs_str[i] is not None: + output_logprobs_str[i].append([]) + + num_total_generations = sum(cfg.n for cfg in generation_config) + + # Save a copy of the original function callback since `generate` + # overrides the callback function. + # The original callback will be set back later on. + original_callback = self._ffi["get_request_stream_callback"]() + + # Define the callback function for request generation results + def request_stream_callback(delta_outputs: List[data.RequestStreamOutput]): + nonlocal num_finished_generations + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + rid = int(request_id) + + assert len(stream_outputs) == generation_config[rid].n + for i, (stream_output, text_streamer) in enumerate( + zip(stream_outputs, text_streamers[rid]) + ): + if output_logprobs_str[rid] is not None: + assert stream_output.delta_logprob_json_strs is not None + output_logprobs_str[rid][i] += stream_output.delta_logprob_json_strs + + delta_text = ( + text_streamer.put(stream_output.delta_token_ids) + if len(stream_output.delta_token_ids) > 0 + else "" + ) + if stream_output.finish_reason is not None: + delta_text += text_streamer.finish() + + output_texts[rid][i] += delta_text + if stream_output.finish_reason is not None: + num_finished_generations += 1 + + # Override the callback function in engine. + self._ffi["set_request_stream_callback"](request_stream_callback) + + def convert_to_data(prompt: Union[str, List[int], List[data.Data]]) -> List[data.Data]: + if isinstance(prompt, str): + return [data.TextData(prompt)] + if isinstance(prompt[0], int): + return [data.TokenData(prompt)] # type: ignore + return prompt # type: ignore + + # Add requests to engine. 
+        for req_id, (prompt, generation_cfg) in enumerate(zip(prompts, generation_config)): + input_data = convert_to_data(prompt) # type: ignore + self.add_request( + Request( + request_id=str(req_id), + inputs=input_data, + generation_config=generation_cfg, + ) + ) + + while num_finished_generations != num_total_generations: + self.step() + + # Restore the callback function in engine. + self._ffi["set_request_stream_callback"](original_callback) + return output_texts, output_logprobs_str + + def add_request(self, request: Request) -> None: + """Add a new request to the engine. + + Parameters + ---------- + request : Request + The request to add. + """ + self._ffi["add_request"](request) + + def abort_request(self, request_id: str) -> None: + """Abort the generation of the request corresponding to the input request id. + + Parameters + ---------- + request_id : str + The unique id of the request to abort. + """ + self._ffi["abort_request"](request_id) + + def step(self) -> None: + """The main function for the engine to take a step of action. + + At each step, the engine may decide to + - run prefill for one (or more) requests, + - run one-step decode for all the existing requests + ... + + At the end of certain actions (e.g., decode), the engine will + check if any request has finished, and will return the + generation results for those finished requests. + """ + self._ffi["step"]() + + def reset(self) -> None: + """Reset the engine, clean up all running data and statistics.""" + self._ffi["reset"]() + + def stats(self) -> Dict[str, float]: + """The engine runtime statistics. + We collect the following entries: + - single token prefill latency (s/tok): avg latency of processing one token in prefill + - single token decode latency (s/tok): avg latency of processing one token in decode + - engine time for prefill (sec) + - engine time for decode (sec) + - total number of processed tokens in prefill. + - total number of processed tokens in decode. + """ + stats_json_str = self._ffi["stats"]() + return json.loads(stats_json_str) diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index 6e64247ea8..5c61af6f07 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -193,6 +193,24 @@ def build(mod: IRModule, args: "CompileArgs", pipeline=None): return build +def _build_android_so(): + def build(mod: IRModule, args: "CompileArgs", pipeline=None): + output = args.output + mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=False) + assert output.suffix == ".so" + relax.build( + mod, + target=args.target, + pipeline=pipeline, + system_lib=False, + ).export_library( + str(output), + fcompile=ndk.create_shared, + ) + + return build + + def _build_webgpu(): def build(mod: IRModule, args: "CompileArgs", pipeline=None): output = args.output @@ -330,7 +348,9 @@ def detect_system_lib_prefix( prefix_hint : str The hint for the system lib prefix. 
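+
+    For example (an illustrative combination, not taken from this patch): with
+    prefix_hint "auto", model_name "Llama-2-7b", quantization "q4f16_1", and a
+    target_hint starting with "android", the detected prefix is
+    "Llama_2_7b_q4f16_1_".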
""" - if prefix_hint == "auto" and target_hint in ["iphone", "android"]: + if prefix_hint == "auto" and ( + target_hint.startswith("iphone") or target_hint.startswith("android") + ): prefix = f"{model_name}_{quantization}_".replace("-", "_") logger.warning( "%s is automatically picked from the filename, %s, this allows us to use the filename " @@ -370,6 +390,28 @@ def detect_system_lib_prefix( }, "build": _build_android, }, + "android:adreno": { + "target": { + "kind": "opencl", + "device": "adreno", + "host": { + "kind": "llvm", + "mtriple": "aarch64-linux-android", + }, + }, + "build": _build_android, + }, + "android:adreno-so": { + "target": { + "kind": "opencl", + "device": "adreno", + "host": { + "kind": "llvm", + "mtriple": "aarch64-linux-android", + }, + }, + "build": _build_android_so, + }, "metal:x86-64": { "target": { "kind": "metal", @@ -419,6 +461,7 @@ def detect_system_lib_prefix( "max_shared_memory_per_block": 32768, "thread_warp_size": 1, "supports_float16": 1, + "supports_int64": 1, "supports_int16": 1, "supports_int8": 1, "supports_8bit_buffer": 1, diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py index 51e7bae586..2a70154bba 100644 --- a/python/mlc_llm/testing/debug_chat.py +++ b/python/mlc_llm/testing/debug_chat.py @@ -21,7 +21,7 @@ ) from mlc_llm.conversation_template import ConvTemplateRegistry from mlc_llm.help import HELP -from mlc_llm.serve.entrypoints import entrypoint_utils +from mlc_llm.serve import engine_utils from mlc_llm.support.argparse import ArgumentParser from mlc_llm.support.auto_device import detect_device from mlc_llm.support.style import green, red @@ -132,7 +132,7 @@ def __call__(self, func, name, before_run, ret_val, *args): class DebugChat: # pylint: disable=too-many-instance-attributes, too-few-public-methods """A chat interface used only for debugging purpose. - It debugs autoregressive decoding fully in Python via the prefill and + It debugs auto-regressive decoding fully in Python via the prefill and decode interface. It supports debugging instrument (either default or customized) to dump intermediate values for each VM function call. 
@@ -261,7 +261,7 @@ def _tokenize(self, prompt: str) -> tvm.nd.array: "Parsed prompt using conversation template " f"{green(self.conversation.name)}: {parsed_prompt}" ) - tokens = entrypoint_utils.process_prompts(parsed_prompt, self.tokenizer.encode) + tokens = engine_utils.process_prompts(parsed_prompt, self.tokenizer.encode) # type: ignore # TODO: Handle ImageData in DebugChat # pylint: disable=fixme assert len(tokens) == 1, "DebugChat will only handle TextData for now" diff --git a/python/mlc_llm/testing/debug_compare.py b/python/mlc_llm/testing/debug_compare.py new file mode 100644 index 0000000000..b3487e3e48 --- /dev/null +++ b/python/mlc_llm/testing/debug_compare.py @@ -0,0 +1,249 @@ +"""Debug compiled models with TVM instrument""" + +import os +from pathlib import Path +from typing import Dict, List, Set, Tuple + +import tvm +from tvm import rpc, runtime +from tvm.relax.testing.lib_comparator import LibCompareVMInstrument + +from mlc_llm.help import HELP +from mlc_llm.support.argparse import ArgumentParser +from mlc_llm.testing.debug_chat import DebugChat + + +def _print_as_table(sorted_list): + print("=" * 100) + print( + "Name".ljust(50) + + "Time (ms)".ljust(12) + + "Count".ljust(8) + + "Total time (ms)".ljust(18) + + "Percentage (%)" + ) + total_time = sum(record[1][0] * record[1][1] for record in sorted_list) * 1000 + for record in sorted_list: + time = record[1][0] * 1000 + weighted_time = time * record[1][1] + percentage = weighted_time / total_time * 100 + print( + record[0].ljust(50) + + f"{time:.4f}".ljust(12) + + str(record[1][1]).ljust(8) + + f"{weighted_time:.4f}".ljust(18) + + f"{percentage:.2f}" + ) + print(f"Total time: {total_time:.4f} ms") + + +class LibCompare(LibCompareVMInstrument): + """The default debug instrument to use if users don't specify + a customized one. + + This debug instrument will dump the arguments and output of each + VM Call instruction into a .npz file. It will also alert the user + if any function outputs are NaN or INF. + + Parameters + ---------- + mod: runtime.Module + The module of interest to be validated. + + device: runtime.Device + The device to run the target module on. + + time_eval: bool + Whether to time evaluate the functions. + + rtol: float + rtol used in validation + + atol: float + atol used in validation + """ + + def __init__( # pylint: disable=too-many-arguments, unused-argument + self, + mod: runtime.Module, + device: runtime.Device, + debug_dir: Path, + time_eval: bool = True, + rtol: float = 1e-2, + atol: float = 1, + skip_rounds: int = 0, + ): + super().__init__(mod, device, True, rtol, atol) + self.time_eval = time_eval + self.time_eval_results: Dict[str, Tuple[float, int]] = {} + self.visited: Set[str] = set([]) + self.skip_rounds = skip_rounds + self.counter = 0 + + def reset(self, debug_dir: Path): # pylint: disable=unused-argument + """Reset the state of the Instrument class + + Note + ---- + `debug_dir` is not used in this class. 
+ + Parameters + ---------- + debug_out : Path + the directory to dump the .npz files + """ + _print_as_table( + sorted( + self.time_eval_results.items(), + key=lambda x: -(x[1][0] * x[1][1]), + ) + ) + self.time_eval_results = {} + self.visited = set([]) + self.counter = 0 + + def skip_instrument(self, func, name, before_run, ret_val, *args): + if name.startswith("shape_func"): + return True + if self.counter < self.skip_rounds: + self.counter += 1 + print(f"[{self.counter}] Skip validating {name}..") + return True + if name in self.visited: + if self.time_eval and name in self.time_eval_results: + record = self.time_eval_results[name] + self.time_eval_results[name] = (record[0], record[1] + 1) + return True + self.visited.add(name) + return False + + def compare( + self, + name: str, + ref_args: List[tvm.nd.NDArray], + new_args: List[tvm.nd.NDArray], + ret_indices: List[int], + ): + super().compare(name, ref_args, new_args, ret_indices) + + if self.time_eval and name not in self.time_eval_results: + res = self.mod.time_evaluator( + name, self.device, number=20, repeat=3 # , cache_flush_bytes=256 * 10**6 + )(*new_args) + self.time_eval_results[name] = (res.mean, 1) + print(f"Time-eval result {name} on {self.device}:\n {res}") + + +def get_instrument(args): + """Get the debug instrument from the CLI arguments""" + if args.cmp_device is None: + assert args.cmp_lib_path is None, "cmp_lib_path must be None if cmp_device is None" + args.cmp_device = args.device + args.cmp_lib_path = args.model_lib_path + + if args.cmp_device == "iphone": + assert args.cmp_lib_path.endswith(".dylib"), "Require a dylib file for iPhone" + proxy_host = os.environ.get("TVM_RPC_PROXY_HOST", "127.0.0.1") + proxy_port = int(os.environ.get("TVM_RPC_PROXY_PORT", "9090")) + sess = rpc.connect(proxy_host, proxy_port, "iphone") + sess.upload(args.cmp_lib_path) + lib = sess.load_module(os.path.basename(args.cmp_lib_path)) + cmp_device = sess.metal() + elif args.cmp_device == "android": + assert args.cmp_lib_path.endswith(".so"), "Require a so file for Android" + tracker_host = os.environ.get("TVM_TRACKER_HOST", "0.0.0.0") + tracker_port = int(os.environ.get("TVM_TRACKER_PORT", "9190")) + tracker = rpc.connect_tracker(tracker_host, tracker_port) + sess = tracker.request("android") + sess.upload(args.cmp_lib_path) + lib = sess.load_module(os.path.basename(args.cmp_lib_path)) + cmp_device = sess.cl(0) + else: + lib = tvm.runtime.load_module( + os.path.join( + args.artifact_path, + f"{args.model}-{args.quantization.name}-{args.cmp_device}.so", + ) + ) + cmp_device = tvm.device(args.cmp_device) + + return LibCompare( + lib, + cmp_device, + time_eval=args.time_eval, + debug_dir=Path(args.debug_dir), + ) + + +def main(): + """The main function to start a DebugChat CLI""" + + parser = ArgumentParser("MLC LLM Chat Debug Tool") + parser.add_argument( + "prompt", + type=str, + help="The user input prompt.", + ) + parser.add_argument( + "--generate-len", type=int, help="Number of output tokens to generate.", required=True + ) + parser.add_argument( + "--model", + type=str, + help="An MLC model directory that contains `mlc-chat-config.json`", + required=True, + ) + parser.add_argument( + "--model-lib-path", + type=str, + help="The full path to the model library file to use (e.g. 
a ``.so`` file).", + required=True, + ) + parser.add_argument( + "--debug-dir", + type=str, + help="The output folder to store the dumped debug files.", + required=True, + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help=HELP["device_compile"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--cmp-device", + type=str, + default="none", + ) + parser.add_argument( + "--cmp-lib-path", + type=str, + default="none", + ) + parser.add_argument( + "--time-eval", + action="store_true", + help="Whether to time evaluate the functions.", + ) + parsed = parser.parse_args() + instrument = get_instrument(parsed) + debug_chat = DebugChat( + model=parsed.model, + model_lib_path=parsed.model_lib_path, + debug_dir=Path(parsed.debug_dir), + device=parsed.device, + debug_instrument=instrument, + ) + debug_chat.generate(parsed.prompt, parsed.generate_len) + # Only print decode for now + _print_as_table( + sorted( + instrument.time_eval_results.items(), + key=lambda x: -(x[1][0] * x[1][1]), + ) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/python/integration/test_model_compile.py b/tests/python/integration/test_model_compile.py index 2f136f3f16..3ec70b61b3 100644 --- a/tests/python/integration/test_model_compile.py +++ b/tests/python/integration/test_model_compile.py @@ -39,12 +39,13 @@ "max_num_threads": 256, "max_shared_memory_per_block": 32768, "thread_warp_size": 1, - "supports_int16": 1, "supports_float32": 1, + "supports_float16": 1, + "supports_int64": 1, "supports_int32": 1, + "supports_int16": 1, "supports_int8": 1, "supports_16bit_buffer": 1, - "supports_float16": 1, }, "metal": "metal", "wasm": "webgpu", diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py new file mode 100644 index 0000000000..b86fd423a9 --- /dev/null +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -0,0 +1,307 @@ +# pylint: disable=chained-comparison,line-too-long,missing-docstring, +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +import json +import queue +import threading +from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union + +import tvm + +from mlc_llm.protocol import openai_api_protocol +from mlc_llm.serve import engine_utils +from mlc_llm.serve.engine_base import ( + EngineConfig, + SpeculativeMode, + _infer_kv_cache_config, + _parse_models, + _process_model_args, + detect_device, +) +from mlc_llm.tokenizer import Tokenizer + +prompts = [ + "What is the meaning of life?", + "Introduce the history of Pittsburgh to me. Please elaborate in detail.", + "Write a three-day Seattle travel plan. Please elaborate in detail.", + "What is Alaska famous of? Please elaborate in detail.", + "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", + "What are the necessary components to assemble a desktop computer? Please elaborate in detail.", + "Why is Vitamin D important to human beings? Please elaborate in detail.", + "Where is milk tea originated from? Please elaborate in detail.", + "Where is the southernmost place in United States? Please elaborate in detail.", + "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? 
Please elaborate in detail.", +] + + +class EngineState: + sync_queue: queue.Queue + + def get_request_stream_callback(self) -> Callable[[List[str]], None]: + # ChatCompletionStreamResponse + + def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: + self._sync_request_stream_callback(chat_completion_stream_responses_json_str) + + return _callback + + def _sync_request_stream_callback( + self, chat_completion_stream_responses_json_str: List[str] + ) -> None: + # Put the delta outputs to the queue in the unblocking way. + self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) + + +class JSONFFIEngine: + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + model: str, + device: Union[str, tvm.runtime.Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, + gpu_memory_utilization: Optional[float] = None, + ) -> None: + # - Initialize model loading info. + models = _parse_models(model, model_lib_path, additional_models) + if isinstance(device, str): + device = detect_device(device) + assert isinstance(device, tvm.runtime.Device) + ( + model_args, + model_config_paths, + self.conv_template, + ) = _process_model_args(models, device) + + # - Load the raw model config into dict + self.model_config_dicts = [] + for i, model_info in enumerate(models): + model_info.model_lib_path = model_args[i][1] + with open(model_config_paths[i], "r", encoding="utf-8") as file: + self.model_config_dicts.append(json.load(file)) + + # - Decide the KV cache config based on mode and user input. + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + ) = _infer_kv_cache_config( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + gpu_memory_utilization, + models, + device, + self.model_config_dicts, + model_config_paths, + ) + self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + + # - Initialize engine state and engine. 
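+        # The C++ engine is exposed through packed functions created by
+        # "mlc.json_ffi.CreateJSONFFIEngine"; the entry points are cached in a
+        # dict, and the engine loop and stream-back loop each run on a
+        # background thread.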
+ self.state = EngineState() + module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() + self._ffi = { + key: module[key] + for key in [ + "init_background_engine", + "chat_completion", + "abort", + "get_last_error", + "run_background_loop", + "run_background_stream_back_loop", + "exit_background_loop", + ] + } + self.tokenizer = Tokenizer(model_args[0][0]) + + def _background_loop(): + self._ffi["init_background_engine"]( + EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + device=device, + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ), + self.state.get_request_stream_callback(), + None, + ) + self._ffi["run_background_loop"]() + + def _background_stream_back_loop(): + self._ffi["run_background_stream_back_loop"]() + + # Create the background engine-driving thread and start the loop. + self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) + self._background_stream_back_loop_thread: threading.Thread = threading.Thread( + target=_background_stream_back_loop + ) + self._background_loop_thread.start() + self._background_stream_back_loop_thread.start() + self._terminated = False + + def terminate(self): + self._terminated = True + self._ffi["exit_background_loop"]() + self._background_loop_thread.join() + self._background_stream_back_loop_thread.join() + + def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ).model_dump_json(), + n=n, + 
request_id=request_id, + ) + for response in chatcmpl_generator: + yield response + + def _handle_chat_completion( + self, request_json_str: str, n: int, request_id: str + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + self.state.sync_queue = queue.Queue() + num_unfinished_requests = n + + success = bool(self._ffi["chat_completion"](request_json_str, request_id)) + + try: + while num_unfinished_requests > 0: + chat_completion_stream_responses_json_str = self.state.sync_queue.get() + for chat_completion_response_json_str in chat_completion_stream_responses_json_str: + chat_completion_response = ( + openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( + chat_completion_response_json_str + ) + ) + for choice in chat_completion_response.choices: + if choice.finish_reason is not None: + num_unfinished_requests -= 1 + yield chat_completion_response + except Exception as exception: # pylint: disable=broad-exception-caught + self._ffi["abort"](request_id) + raise exception + + +def test_chat_completion(engine: JSONFFIEngine): + num_requests = 2 + max_tokens = 64 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + for rid in range(num_requests): + print(f"chat completion for request {rid}") + for response in engine.chat_completion( + messages=[{"role": "user", "content": [{"type": "text", "text": prompts[rid]}]}], + model=model, + max_tokens=max_tokens, + n=n, + request_id=str(rid), + ): + for choice in response.choices: + assert choice.delta.role == "assistant" + assert isinstance(choice.delta.content[0], Dict) + assert choice.delta.content[0]["type"] == "text" + output_texts[rid][choice.index] += choice.delta.content[0]["text"] + + # Print output. + print("Chat completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + +def test_malformed_request(engine: JSONFFIEngine): + for response in engine._handle_chat_completion("malformed_string", n=1, request_id="123"): + assert len(response.choices) == 1 + assert response.choices[0].finish_reason == "error" + + +if __name__ == "__main__": + # Create engine. 
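+    # NOTE: the model directory and library path below are placeholders for a
+    # locally compiled Llama-2 q0f16 artifact; adjust them to the local setup.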
+ model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = JSONFFIEngine( + model, + model_lib_path=model_lib_path, + max_total_sequence_length=1024, + ) + + test_chat_completion(engine) + test_malformed_request(engine) + + engine.terminate() + del engine diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py index bbd2089f4c..4e541b7437 100644 --- a/tests/python/serve/evaluate_engine.py +++ b/tests/python/serve/evaluate_engine.py @@ -4,8 +4,8 @@ import random from typing import List, Tuple -from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig -from mlc_llm.serve.engine import ModelInfo +from mlc_llm.serve import GenerationConfig +from mlc_llm.serve.sync_engine import SyncLLMEngine def _parse_args(): @@ -13,15 +13,12 @@ def _parse_args(): args.add_argument("--model-lib-path", type=str) args.add_argument("--device", type=str, default="auto") args.add_argument("--batch-size", type=int, default=80) - args.add_argument("--page-size", type=int, default=16) args.add_argument("--max-total-seq-length", type=int) args.add_argument("--seed", type=int, default=0) parsed = args.parse_args() parsed.model = os.path.dirname(parsed.model_lib_path) assert parsed.batch_size % 16 == 0 - assert parsed.page_size == 16 - assert parsed.max_total_seq_length >= 2048 return parsed @@ -43,17 +40,16 @@ def generate_requests( def benchmark(args: argparse.Namespace): random.seed(args.seed) - # Initialize model loading info and KV cache config - model = ModelInfo(args.model, args.model_lib_path, args.device) - kv_cache_config = KVCacheConfig( - page_size=args.page_size, - max_num_sequence=args.batch_size, + # Create engine + engine = SyncLLMEngine( + model=args.model, + device=args.device, + model_lib_path=args.model_lib_path, + mode="server", + max_batch_size=args.batch_size, max_total_sequence_length=args.max_total_seq_length, ) - # Create engine - engine = Engine(model, kv_cache_config) - print(args) for num_requests in [1, 2, 4, 8, 16, 32, 64]: if num_requests > args.batch_size: diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index 286d64a874..ad4fa01a82 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -181,7 +181,7 @@ def check_openai_stream_response( usage = response["usage"] assert isinstance(usage, dict) assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] - assert usage["prompt_tokens"] > 0 + assert usage["prompt_tokens"] >= 0 if completion_tokens is not None: assert usage["completion_tokens"] <= completion_tokens @@ -255,6 +255,7 @@ def test_openai_v1_completions( "prompt": prompt, "max_tokens": max_tokens, "stream": stream, + "ignore_eos": True, } response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) @@ -310,7 +311,7 @@ def test_openai_v1_completions_openai_package( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reasons=["length"], + finish_reasons=["length", "stop"], completion_tokens=max_tokens, ) else: @@ -323,7 +324,7 @@ def test_openai_v1_completions_openai_package( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reasons=["length"], + finish_reasons=["length", "stop"], completion_tokens=max_tokens, ) @@ -362,6 +363,7 @@ def test_openai_v1_completions_echo( "max_tokens": max_tokens, "echo": True, "stream": stream, + "ignore_eos": True, } response = 
requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) @@ -412,6 +414,7 @@ def test_openai_v1_completions_suffix( "max_tokens": max_tokens, "suffix": suffix, "stream": stream, + "ignore_eos": True, } response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) @@ -511,6 +514,7 @@ def test_openai_v1_completions_temperature( "max_tokens": max_tokens, "stream": stream, "temperature": 0.0, + "ignore_eos": True, } response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) @@ -616,6 +620,51 @@ class Schema(BaseModel): "response_format": {"type": "json_object", "schema": schema_str}, } + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reasons=["length"], + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reasons=["length"], + ) + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_json( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + prompt = "Response with a json object:" + max_tokens = 128 + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": max_tokens, + "stream": stream, + "response_format": {"type": "json_object"}, + } + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) if not stream: check_openai_nonstream_response( @@ -664,6 +713,7 @@ def test_openai_v1_completions_logit_bias( "max_tokens": max_tokens, "stream": stream, "logit_bias": {338: -100}, # 338 is " is" in Llama tokenizer. 
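+        # Ignore the EOS token so generation always runs until max_tokens.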
+ "ignore_eos": True, } response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) @@ -710,6 +760,7 @@ def test_openai_v1_completions_presence_frequency_penalty( "stream": stream, "frequency_penalty": 2.0, "presence_penalty": 2.0, + "ignore_eos": True, } response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) @@ -753,6 +804,7 @@ def test_openai_v1_completions_seed( "max_tokens": max_tokens, "stream": False, "seed": 233, + "ignore_eos": True, } response1 = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) diff --git a/tests/python/serve/test_grammar_parser.py b/tests/python/serve/test_grammar_parser.py index 325b0a5117..10eacdf9b9 100644 --- a/tests/python/serve/test_grammar_parser.py +++ b/tests/python/serve/test_grammar_parser.py @@ -17,7 +17,7 @@ def test_bnf_simple(): b ::= (([b])) c ::= (([c])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -36,7 +36,7 @@ def test_ebnf(): c_1 ::= (([acep-z] c_1) | ([acep-z])) d_1 ::= ("" | ([d])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -60,7 +60,7 @@ def test_star_quantifier(): e_star_2 ::= [g]* d_1_choice ::= (([b] [c] [d]) | ([p] [q])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -75,7 +75,7 @@ def test_char(): rest1 ::= ((([\?] [\"] [\'] [\u6d4b] [\u8bd5] [\u3042] [c]) ([\U0001f440]) "")) """ # Disable unwrap_nesting_rules to expose the result before unwrapping. 
- bnf_grammar = BNFGrammar.from_ebnf_string(before, False, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", False, False) after = bnf_grammar.to_string() assert after == expected @@ -90,7 +90,7 @@ def test_space(): """ expected = """main ::= (([a] [b] [c] [d] [e]) | ([f]) | ([g])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -101,7 +101,7 @@ def test_nest(): expected = """main ::= (([a] main_choice) | ([e] [f])) main_choice ::= (([b]) | ([c] [d])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -122,7 +122,7 @@ def test_flatten(): empty_test ::= ("" | ([d]) | ([a])) sequence_test_choice ::= (([c]) | ([d])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -159,7 +159,7 @@ def test_json(): exponent_choice_1 ::= ("" | ([+]) | ([\-])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -176,9 +176,9 @@ def test_to_string_roundtrip(): c_2 ::= [acep-z] d_1 ::= [d] | "" """ - bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main", True, False) output_string_1 = bnf_grammar_1.to_string() - bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, True, False) + bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, "main", True, False) output_string_2 = bnf_grammar_2.to_string() assert output_string_1 == output_string_2 @@ -240,7 +240,8 @@ def test_error(): with pytest.raises( TVMError, - match='TVMError: EBNF parse error at line 1, column 10: There must be a rule named "main"', + match="TVMError: EBNF parse error at line 1, column 10: " + 'The main rule with name "main" is not found.', ): BNFGrammar.from_ebnf_string('a ::= "a"') @@ -256,7 +257,7 @@ def test_to_json(): '4,3,7,8,9,5,1,10,0,2,97,122,4,1,12,5,1,13],"rules":[{"body_expr_id":6,"name":"main"},' '{"body_expr_id":11,"name":"b"},{"body_expr_id":14,"name":"c"}]}' ) - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_json(False) assert after == expected @@ -271,7 +272,7 @@ def test_to_json_roundtrip(): c_2 ::= (([acep-z])) d_1 ::= ("" | ([d])) """ - bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main", True, False) output_json_1 = bnf_grammar_1.to_json(False) bnf_grammar_2 = BNFGrammar.from_json(output_json_1) output_json_2 = bnf_grammar_2.to_json(False) diff --git a/tests/python/serve/test_grammar_state_matcher_custom.py b/tests/python/serve/test_grammar_state_matcher_custom.py index 5bdc8ecc4b..6fc48705d1 100644 --- a/tests/python/serve/test_grammar_state_matcher_custom.py +++ b/tests/python/serve/test_grammar_state_matcher_custom.py @@ -12,7 +12,7 @@ import tvm.testing from pydantic import BaseModel -from mlc_llm.serve import BNFGrammar, GrammarStateMatcher, json_schema_to_ebnf +from mlc_llm.serve import BNFGrammar, GrammarStateMatcher from mlc_llm.tokenizer 
import Tokenizer diff --git a/tests/python/serve/test_json_schema_converter.py b/tests/python/serve/test_json_schema_converter.py index 822199977c..84dbd2cb7b 100644 --- a/tests/python/serve/test_json_schema_converter.py +++ b/tests/python/serve/test_json_schema_converter.py @@ -5,7 +5,7 @@ import tvm.testing from pydantic import BaseModel, Field, TypeAdapter -from mlc_llm.serve import BNFGrammar, GrammarStateMatcher, json_schema_to_ebnf +from mlc_llm.serve import BNFGrammar, GrammarStateMatcher def check_schema_with_grammar( @@ -16,7 +16,7 @@ def check_schema_with_grammar( strict_mode: bool = True, ): schema_str = json.dumps(schema, indent=2) - grammar = json_schema_to_ebnf( + grammar = BNFGrammar.debug_json_schema_to_ebnf( schema_str, indent=indent, separators=separators, strict_mode=strict_mode ) assert grammar == expected_grammar @@ -25,17 +25,14 @@ def check_schema_with_grammar( def check_schema_with_json( schema: Dict[str, Any], json_str: str, - check_accepted=True, + check_accepted: bool = True, indent: Optional[int] = None, separators: Optional[Tuple[str, str]] = None, strict_mode: bool = True, ): - schema_str = json.dumps(schema, indent=2) - - ebnf_grammar_str = json_schema_to_ebnf( - schema_str, indent=indent, separators=separators, strict_mode=strict_mode + ebnf_grammar = BNFGrammar.from_schema( + json.dumps(schema, indent=2), indent=indent, separators=separators, strict_mode=strict_mode ) - ebnf_grammar = BNFGrammar.from_ebnf_string(ebnf_grammar_str) matcher = GrammarStateMatcher(ebnf_grammar) if check_accepted: @@ -47,7 +44,7 @@ def check_schema_with_json( def check_schema_with_instance( schema: Dict[str, Any], instance: BaseModel, - check_accepted=True, + check_accepted: bool = True, indent: Optional[int] = None, separators: Optional[Tuple[str, str]] = None, strict_mode: bool = True, @@ -78,14 +75,14 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_any_array_field ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -main_array_field ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" -main_tuple_field_2 ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" -main_tuple_field ::= "[" "" basic_string ", " basic_integer ", " main_tuple_field_2 "" "]" -main_object_field ::= ("{" "" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" -main_nested_object_field_add ::= ("{" "" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" -main_nested_object_field ::= ("{" "" basic_string ": " main_nested_object_field_add (", " basic_string ": " main_nested_object_field_add)* "" "}") | "{}" -main ::= "{" "" "\"integer_field\"" ": " basic_integer ", " "\"number_field\"" ": " basic_number ", " "\"boolean_field\"" ": " basic_boolean ", " "\"any_array_field\"" ": " main_any_array_field ", " "\"array_field\"" ": " main_array_field ", " "\"tuple_field\"" ": " main_tuple_field ", " "\"object_field\"" ": " main_object_field ", " "\"nested_object_field\"" ": " main_nested_object_field "" "}" +main_prop_3 ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +main_prop_4 ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" +main_prop_5_item_2 ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" +main_prop_5 ::= "[" "" basic_string ", " basic_integer ", " main_prop_5_item_2 "" "]" +main_prop_6 ::= ("{" "" basic_string ": " 
basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" +main_prop_7_addl ::= ("{" "" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" +main_prop_7 ::= ("{" "" basic_string ": " main_prop_7_addl (", " basic_string ": " main_prop_7_addl)* "" "}") | "{}" +main ::= "{" "" "\"integer_field\"" ": " basic_integer ", " "\"number_field\"" ": " basic_number ", " "\"boolean_field\"" ": " basic_boolean ", " "\"any_array_field\"" ": " main_prop_3 ", " "\"array_field\"" ": " main_prop_4 ", " "\"tuple_field\"" ": " main_prop_5 ", " "\"object_field\"" ": " main_prop_6 ", " "\"nested_object_field\"" ": " main_prop_7 "" "}" """ schema = MainModel.model_json_schema() @@ -134,11 +131,11 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any ("," basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" "}") | "{}" -main_array_field ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" -main_tuple_field_2 ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" -main_tuple_field ::= "[" "\n " basic_string ",\n " basic_integer ",\n " main_tuple_field_2 "\n " "]" -main_object_field ::= ("{" "\n " basic_string ": " basic_integer (",\n " basic_string ": " basic_integer)* "\n " "}") | "{}" -main ::= "{" "\n " "\"array_field\"" ": " main_array_field ",\n " "\"tuple_field\"" ": " main_tuple_field ",\n " "\"object_field\"" ": " main_object_field "\n" "}" +main_prop_0 ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" +main_prop_1_item_2 ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" +main_prop_1 ::= "[" "\n " basic_string ",\n " basic_integer ",\n " main_prop_1_item_2 "\n " "]" +main_prop_2 ::= ("{" "\n " basic_string ": " basic_integer (",\n " basic_string ": " basic_integer)* "\n " "}") | "{}" +main ::= "{" "\n " "\"array_field\"" ": " main_prop_0 ",\n " "\"tuple_field\"" ": " main_prop_1 ",\n " "\"object_field\"" ": " main_prop_2 "\n" "}" """ instance = MainModel( @@ -171,10 +168,10 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any ("," basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" "}") | "{}" -main_tuple_field_1 ::= "[" "\n " basic_integer ",\n " basic_integer (",\n " basic_any)* "\n " "]" -main_tuple_field ::= "[" "\n " basic_string ",\n " main_tuple_field_1 (",\n " basic_any)* "\n " "]" -main_foo_field ::= ("{" "\n " basic_string ": " basic_any (",\n " basic_string ": " basic_any)* "\n " "}") | "{}" -main ::= "{" "\n " "\"tuple_field\"" ": " main_tuple_field ",\n " "\"foo_field\"" ": " main_foo_field (",\n " basic_string ": " basic_any)* "\n" "}" +main_prop_0_item_1 ::= "[" "\n " basic_integer ",\n " basic_integer (",\n " basic_any)* "\n " "]" +main_prop_0 ::= "[" "\n " basic_string ",\n " main_prop_0_item_1 (",\n " basic_any)* "\n " "]" +main_prop_1 ::= ("{" "\n " basic_string ": " basic_any (",\n " basic_string ": " basic_any)* "\n " "}") | "{}" +main ::= "{" "\n " "\"tuple_field\"" ": " main_prop_0 ",\n " "\"foo_field\"" ": " main_prop_1 (",\n " basic_string ": " basic_any)* "\n" "}" """ instance_json = """{ @@ -220,12 +217,12 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_bars ::= "\"a\"" -main_str_values ::= 
"\"a\\n\\r\\\"\"" -main_foo ::= ("\"a\"") | ("\"b\"") | ("\"c\"") -main_values ::= ("1") | ("\"a\"") | ("true") -main_field ::= ("\"foo\"") | ("\"bar\"") -main ::= "{" "" "\"bars\"" ": " main_bars ", " "\"str_values\"" ": " main_str_values ", " "\"foo\"" ": " main_foo ", " "\"values\"" ": " main_values ", " "\"field\"" ": " main_field "" "}" +main_prop_0 ::= "\"a\"" +main_prop_1 ::= "\"a\\n\\r\\\"\"" +main_prop_2 ::= ("\"a\"") | ("\"b\"") | ("\"c\"") +main_prop_3 ::= ("1") | ("\"a\"") | ("true") +main_prop_4 ::= ("\"foo\"") | ("\"bar\"") +main ::= "{" "" "\"bars\"" ": " main_prop_0 ", " "\"str_values\"" ": " main_prop_1 ", " "\"foo\"" ": " main_prop_2 ", " "\"values\"" ": " main_prop_3 ", " "\"field\"" ": " main_prop_4 "" "}" """ schema = MainModel.model_json_schema() @@ -251,9 +248,9 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_opt_bool ::= basic_boolean | basic_null -main_size ::= basic_number | basic_null -main ::= "{" "" ("\"num\"" ": " basic_integer ", ")? ("\"opt_bool\"" ": " main_opt_bool ", ")? "\"size\"" ": " main_size (", " "\"name\"" ": " basic_string)? "" "}" +main_prop_1 ::= basic_boolean | basic_null +main_prop_2 ::= basic_number | basic_null +main ::= "{" "" ("\"num\"" ": " basic_integer ", ")? ("\"opt_bool\"" ": " main_prop_1 ", ")? "\"size\"" ": " main_prop_2 (", " "\"name\"" ": " basic_string)? "" "}" """ schema = MainModel.model_json_schema() @@ -286,9 +283,9 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_sub_1 ::= "" | ", " "\"num\"" ": " basic_number "" -main_sub_0 ::= main_sub_1 | ", " "\"state\"" ": " basic_boolean main_sub_1 -main ::= ("{" "" (("\"size\"" ": " basic_integer main_sub_0) | ("\"state\"" ": " basic_boolean main_sub_1) | ("\"num\"" ": " basic_number "")) "" "}") | "{}" +main_part_1 ::= "" | ", " "\"num\"" ": " basic_number "" +main_part_0 ::= main_part_1 | ", " "\"state\"" ": " basic_boolean main_part_1 +main ::= ("{" "" (("\"size\"" ": " basic_integer main_part_0) | ("\"state\"" ": " basic_boolean main_part_1) | ("\"num\"" ": " basic_number "")) "" "}") | "{}" """ schema = MainModel.model_json_schema() @@ -310,10 +307,10 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_sub_2 ::= (", " basic_string ": " basic_any)* -main_sub_1 ::= main_sub_2 | ", " "\"num\"" ": " basic_number main_sub_2 -main_sub_0 ::= main_sub_1 | ", " "\"state\"" ": " basic_boolean main_sub_1 -main ::= ("{" "" (("\"size\"" ": " basic_integer main_sub_0) | ("\"state\"" ": " basic_boolean main_sub_1) | ("\"num\"" ": " basic_number main_sub_2) | basic_string ": " basic_any main_sub_2) "" "}") | "{}" +main_part_2 ::= (", " basic_string ": " basic_any)* +main_part_1 ::= main_part_2 | ", " "\"num\"" ": " basic_number main_part_2 +main_part_0 ::= main_part_1 | ", " "\"state\"" ": " basic_boolean main_part_1 +main ::= ("{" "" (("\"size\"" ": " basic_integer main_part_0) | ("\"state\"" ": " basic_boolean main_part_1) | ("\"num\"" ": " basic_number main_part_2) | basic_string ": " basic_any main_part_2) "" "}") | "{}" """ check_schema_with_grammar(schema, 
ebnf_grammar_non_strict, strict_mode=False) @@ -376,12 +373,12 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_foo_size ::= basic_number | basic_null -main_foo ::= "{" "" "\"count\"" ": " basic_integer (", " "\"size\"" ": " main_foo_size)? "" "}" -main_bars_item_sub_0 ::= "" | ", " "\"banana\"" ": " basic_string "" -main_bars_item ::= ("{" "" (("\"apple\"" ": " basic_string main_bars_item_sub_0) | ("\"banana\"" ": " basic_string "")) "" "}") | "{}" -main_bars ::= ("[" "" main_bars_item (", " main_bars_item)* "" "]") | "[]" -main ::= "{" "" "\"foo\"" ": " main_foo ", " "\"bars\"" ": " main_bars "" "}" +main_prop_0_prop_1 ::= basic_number | basic_null +main_prop_0 ::= "{" "" "\"count\"" ": " basic_integer (", " "\"size\"" ": " main_prop_0_prop_1)? "" "}" +main_prop_1_items_part_0 ::= "" | ", " "\"banana\"" ": " basic_string "" +main_prop_1_items ::= ("{" "" (("\"apple\"" ": " basic_string main_prop_1_items_part_0) | ("\"banana\"" ": " basic_string "")) "" "}") | "{}" +main_prop_1 ::= ("[" "" main_prop_1_items (", " main_prop_1_items)* "" "]") | "[]" +main ::= "{" "" "\"foo\"" ": " main_prop_0 ", " "\"bars\"" ": " main_prop_1 "" "}" """ schema = MainModel.model_json_schema() @@ -412,9 +409,9 @@ class Dog(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_0 ::= "{" "" "\"name\"" ": " basic_string ", " "\"color\"" ": " basic_string "" "}" -main_1 ::= "{" "" "\"name\"" ": " basic_string ", " "\"breed\"" ": " basic_string "" "}" -main ::= main_0 | main_1 +main_case_0 ::= "{" "" "\"name\"" ": " basic_string ", " "\"color\"" ": " basic_string "" "}" +main_case_1 ::= "{" "" "\"name\"" ": " basic_string ", " "\"breed\"" ": " basic_string "" "}" +main ::= main_case_0 | main_case_1 """ check_schema_with_grammar(model_schema, ebnf_grammar) @@ -450,6 +447,32 @@ class MainModel(BaseModel): instance_str = json.dumps(instance.model_dump(mode="json", round_trip=True, by_alias=True)) check_schema_with_json(MainModel.model_json_schema(by_alias=True), instance_str) + # property name contains space + class MainModelSpace(BaseModel): + test: Literal["abc"] = Field(..., alias="name 1") + + ebnf_grammar_space = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
+basic_string ::= ["] basic_string_sub ["] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" +main_prop_0 ::= "\"abc\"" +main ::= "{" "" "\"name 1\"" ": " main_prop_0 "" "}" +""" + + check_schema_with_grammar(MainModelSpace.model_json_schema(), ebnf_grammar_space) + + instance_space = MainModelSpace(**{"name 1": "abc"}) + instance_space_str = json.dumps( + instance_space.model_dump(mode="json", round_trip=True, by_alias=True) + ) + check_schema_with_json(MainModelSpace.model_json_schema(by_alias=True), instance_space_str) + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index a1a2791bf7..9bece30578 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -3,8 +3,7 @@ import asyncio from typing import List -from mlc_llm.serve import AsyncThreadedEngine, GenerationConfig, KVCacheConfig -from mlc_llm.serve.engine import ModelInfo +from mlc_llm.serve import AsyncLLMEngine, GenerationConfig prompts = [ "What is the meaning of life?", @@ -21,32 +20,33 @@ async def test_engine_generate(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - async_engine = AsyncThreadedEngine(model, kv_cache_config) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + async_engine = AsyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) num_requests = 10 max_tokens = 256 - generation_cfg = GenerationConfig(max_tokens=max_tokens, n=3) + generation_cfg = GenerationConfig(max_tokens=max_tokens, n=7) output_texts: List[List[str]] = [ ["" for _ in range(generation_cfg.n)] for _ in range(num_requests) ] async def generate_task( - async_engine: AsyncThreadedEngine, + async_engine: AsyncLLMEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, ): print(f"generate task for request {request_id}") rid = int(request_id) - async for delta_outputs in async_engine.generate( + async for delta_outputs in async_engine._generate( prompt, generation_cfg, request_id=request_id ): assert len(delta_outputs) == generation_cfg.n @@ -76,5 +76,215 @@ async def generate_task( del async_engine +async def test_chat_completion(): + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + async_engine = AsyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) + + num_requests = 2 + max_tokens = 32 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + async def generate_task(prompt: str, request_id: str): + print(f"generate chat completion task for request {request_id}") + rid = int(request_id) + async for response in await async_engine.chat.completions.create( + messages=[{"role": "user", "content": prompt}], + model=model, + max_tokens=max_tokens, + n=n, + request_id=request_id, + 
stream=True, + ): + for choice in response.choices: + assert choice.delta.role == "assistant" + output_texts[rid][choice.index] += choice.delta.content + + tasks = [ + asyncio.create_task(generate_task(prompts[i], request_id=str(i))) + for i in range(num_requests) + ] + + await asyncio.gather(*tasks) + + # Print output. + print("Chat completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + async_engine.terminate() + del async_engine + + +async def test_chat_completion_non_stream(): + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + async_engine = AsyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) + + num_requests = 2 + max_tokens = 32 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + async def generate_task(prompt: str, request_id: str): + print(f"generate chat completion task for request {request_id}") + rid = int(request_id) + response = await async_engine.chat.completions.create( + messages=[{"role": "user", "content": prompt}], + model=model, + max_tokens=max_tokens, + n=n, + request_id=request_id, + ) + for choice in response.choices: + assert choice.message.role == "assistant" + output_texts[rid][choice.index] += choice.message.content + + tasks = [ + asyncio.create_task(generate_task(prompts[i], request_id=str(i))) + for i in range(num_requests) + ] + + await asyncio.gather(*tasks) + + # Print output. + print("Chat completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + async_engine.terminate() + del async_engine + + +async def test_completion(): + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + async_engine = AsyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) + + num_requests = 2 + max_tokens = 128 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + async def generate_task(prompt: str, request_id: str): + print(f"generate completion task for request {request_id}") + rid = int(request_id) + async for response in await async_engine.completions.create( + prompt=prompt, + model=model, + max_tokens=max_tokens, + n=n, + ignore_eos=True, + request_id=request_id, + stream=True, + ): + for choice in response.choices: + output_texts[rid][choice.index] += choice.text + + tasks = [ + asyncio.create_task(generate_task(prompts[i], request_id=str(i))) + for i in range(num_requests) + ] + + await asyncio.gather(*tasks) + + # Print output. 
+ print("Completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + async_engine.terminate() + del async_engine + + +async def test_completion_non_stream(): + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + async_engine = AsyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) + + num_requests = 2 + max_tokens = 128 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + async def generate_task(prompt: str, request_id: str): + print(f"generate completion task for request {request_id}") + rid = int(request_id) + response = await async_engine.completions.create( + prompt=prompt, + model=model, + max_tokens=max_tokens, + n=n, + ignore_eos=True, + request_id=request_id, + ) + for choice in response.choices: + output_texts[rid][choice.index] += choice.text + + tasks = [ + asyncio.create_task(generate_task(prompts[i], request_id=str(i))) + for i in range(num_requests) + ] + + await asyncio.gather(*tasks) + + # Print output. + print("Completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + async_engine.terminate() + del async_engine + + if __name__ == "__main__": asyncio.run(test_engine_generate()) + asyncio.run(test_chat_completion()) + asyncio.run(test_chat_completion_non_stream()) + asyncio.run(test_completion()) + asyncio.run(test_completion_non_stream()) diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py index 10ed7a4729..6915224f81 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -1,8 +1,9 @@ # pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +# pylint: disable=too-many-arguments,too-many-locals import asyncio from typing import List +<<<<<<< HEAD from mlc_llm.serve import ( AsyncThreadedEngine, EngineMode, @@ -10,6 +11,9 @@ KVCacheConfig, ) from mlc_llm.serve.engine import ModelInfo +======= +from mlc_llm.serve import AsyncLLMEngine, GenerationConfig, SpeculativeMode +>>>>>>> upstream/main prompts = [ "What is the meaning of life?", @@ -26,19 +30,20 @@ async def test_engine_generate(): - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + small_model_lib_path = ( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - llm = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + async_engine = AsyncLLMEngine( + 
model=model, + model_lib_path=model_lib_path, + mode="server", + additional_models=[small_model + ":" + small_model_lib_path], + speculative_mode=SpeculativeMode.SMALL_DRAFT, ) - kv_cache_config = KVCacheConfig(page_size=16) - engine_mode = EngineMode(enable_speculative=True) - # Create engine - async_engine = AsyncThreadedEngine([llm, ssm], kv_cache_config, engine_mode) num_requests = 10 max_tokens = 256 @@ -49,14 +54,18 @@ async def test_engine_generate(): ] async def generate_task( - async_engine: AsyncThreadedEngine, + async_engine: AsyncLLMEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, ): print(f"generate task for request {request_id}") rid = int(request_id) +<<<<<<< HEAD async for delta_outputs in async_engine.generate( +======= + async for delta_outputs in async_engine._generate( +>>>>>>> upstream/main prompt, generation_cfg, request_id=request_id ): assert len(delta_outputs) == generation_cfg.n diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index 9f56f507ca..330bd4cf82 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -1,18 +1,8 @@ # pylint: disable=chained-comparison,line-too-long,missing-docstring, # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable -from typing import Callable, List, Optional +from typing import List -import numpy as np - -from mlc_llm.serve import ( - Engine, - GenerationConfig, - KVCacheConfig, - Request, - RequestStreamOutput, - data, -) -from mlc_llm.serve.engine import ModelInfo +from mlc_llm.serve import GenerationConfig, LLMEngine prompts = [ "What is the meaning of life?", @@ -28,361 +18,207 @@ ] -def create_requests( - num_requests: int, - stop_token_id: Optional[int] = None, - temperature: float = 0.8, - repetition_penalty: float = 1.0, - max_tokens_low: int = 256, - max_tokens_high: int = 257, -) -> List[Request]: - assert num_requests >= 0 and num_requests <= len(prompts) - - stop_token_ids = [stop_token_id] if stop_token_id is not None else [] - requests = [] - for req_id, prompt in zip(range(num_requests), prompts): - max_tokens = np.random.randint(max_tokens_low, max_tokens_high) - requests.append( - Request( - request_id=str(req_id), - inputs=data.TextData(prompt), - generation_config=GenerationConfig( - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - ), - ) - ) - return requests - - -def test_engine_basic(): - """Test engine **without continuous batching**. - - - Add all requests to the engine altogether in the beginning. - - All requests have the same max_tokens. This means all requests - will end together. - - Engine keeps running `step` for estimated number of steps (number of - requests + max_tokens - 1). Then check the output of each request. - """ - - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", +def test_engine_generate(): + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = LLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, ) - kv_cache_config = KVCacheConfig(page_size=16) - # Hyperparameters for tests (you can try different combinations). 
- num_requests = 10 # [4, 8, 10] - temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.0 # [1.0, 1.01] - max_tokens: int = 256 # [32, 128, 256] - np.random.seed(0) + num_requests = 10 + max_tokens = 256 + generation_cfg = GenerationConfig(max_tokens=max_tokens, n=7) + + output_texts: List[List[str]] = [ + ["" for _ in range(generation_cfg.n)] for _ in range(num_requests) + ] + for rid in range(num_requests): + print(f"generating for request {rid}") + for delta_outputs in engine._generate(prompts[rid], generation_cfg, request_id=str(rid)): + assert len(delta_outputs) == generation_cfg.n + for i, delta_output in enumerate(delta_outputs): + output_texts[rid][i] += delta_output.delta_text + + # Print output. + print("All finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") - # Output list - outputs = [[] for _ in range(num_requests)] + engine.terminate() + del engine - # Define the callback function for request generation results - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - outputs[int(request_id)] += stream_outputs[0].delta_token_ids +def test_chat_completion(): # Create engine - engine = Engine(model, kv_cache_config, request_stream_callback=fcallback) - - # Create requests - requests = create_requests( - num_requests, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens, - max_tokens_high=max_tokens + 1, + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = LLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, ) - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max_tokens - 1 - # Run steps - for step in range(num_steps): - engine.step() - - for req_id, output in enumerate(outputs): - print(f"Prompt {req_id}: {requests[req_id].inputs[0]}") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") - - -def test_engine_continuous_batching_1(): - """Test engine **with continuous batching**. + num_requests = 2 + max_tokens = 64 + n = 2 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + for rid in range(num_requests): + print(f"chat completion for request {rid}") + for response in engine.chat.completions.create( + messages=[{"role": "user", "content": prompts[rid]}], + model=model, + max_tokens=max_tokens, + n=n, + request_id=str(rid), + stream=True, + ): + for choice in response.choices: + assert choice.delta.role == "assistant" + output_texts[rid][choice.index] += choice.delta.content + + # Print output. + print("Chat completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") - - Add all requests to the engine altogether in the beginning. - - All requests have a random maximum generation length. So each - request keeps generating until reaching the maximum length. 
- - Engine keeps running `step` for estimated number of steps (number of - requests + the maximum max_tokens - 1). Then check the output - of each request. - """ + engine.terminate() + del engine - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16) - - # Hyperparameters for tests (you can try different combinations) - num_requests = 10 # [4, 8, 10] - temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.00 # [1.0, 1.01] - max_tokens_low = 128 - max_tokens_high = 384 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - finish_time = [None] * num_requests - - # Define the callback class for request generation results - class CallbackTimer: - timer: int = -1 - - def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - if stream_outputs[0].finish_reason is not None: - print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - finish_time[int(request_id)] = self.timer - - return fcallback - - def step(self) -> None: - self.timer += 1 +def test_chat_completion_non_stream(): # Create engine - timer = CallbackTimer() - engine = Engine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) - - # Create requests - requests = create_requests( - num_requests, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens_low, - max_tokens_high=max_tokens_high, + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = LLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, ) - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1 - # Run steps - for step in range(num_steps): - timer.step() - assert timer.timer == step - engine.step() - - for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): - print(f"Prompt {req_id}: {request.inputs[0]}") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") - assert fin_time == request.generation_config.max_tokens - 1 - - -def test_engine_continuous_batching_2(): - """Test engine **with continuous batching**. - - - Add all requests to the engine altogether in the beginning. - - All requests have the stop token. So each request keeps generating - until having the stop token or reaching the maximum length. - - Engine keeps running `step` for estimated number of steps (number of - requests + the maximum max_tokens - 1). Then check the output - of each request. 
- """ - - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16) - - # Hyperparameters for tests (you can try different combinations) - num_requests = 10 # [4, 8, 10] - temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.00 # [1.0, 1.01] - stop_token_id = 2 - max_tokens = 512 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - finish_time = [None] * num_requests - - # Define the callback class for request generation results - class CallbackTimer: - timer: int = -1 - - def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - if stream_outputs[0].finish_reason is not None: - print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - finish_time[int(request_id)] = self.timer - - return fcallback - - def step(self) -> None: - self.timer += 1 + num_requests = 2 + max_tokens = 64 + n = 2 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + for rid in range(num_requests): + print(f"chat completion for request {rid}") + response = engine.chat.completions.create( + messages=[{"role": "user", "content": prompts[rid]}], + model=model, + max_tokens=max_tokens, + n=n, + request_id=str(rid), + ) + for choice in response.choices: + assert choice.message.role == "assistant" + output_texts[rid][choice.index] += choice.message.content - # Create engine - timer = CallbackTimer() - engine = Engine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) - - # Create requests - requests = create_requests( - num_requests, - stop_token_id=stop_token_id, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens, - max_tokens_high=max_tokens + 1, - ) + # Print output. + print("Chat completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + engine.terminate() + del engine - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max_tokens - 1 - # Run steps - for step in range(num_steps): - timer.step() - assert timer.timer == step - engine.step() - - for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): - print(f"Prompt {req_id}: {request.inputs[0]}") - if fin_time < num_requests + max_tokens - 2: - print(f"Request {req_id} ends early on the stop token") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") - - -def test_engine_continuous_batching_3(): - """Test engine **with continuous batching**. - - - Add requests randomly between time [0, 200). - - All requests have a random maximum generation length. So each - request keeps generating until reaching the maximum length. - - Engine keeps running `step` until all requests finish. - Then check the output of each request. 
- """ - - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16) - - # Hyperparameters for tests (you can try different combinations) - num_requests = 10 # [4, 8, 10] - temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.00 # [1.0, 1.01] - stop_token_id = 2 - max_tokens_low = 64 - max_tokens_high = 192 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - finish_time = [None] * num_requests - - # Define the callback class for request generation results - class CallbackTimer: - timer: int = -1 - finished_requests: int = 0 - - def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - if stream_outputs[0].finish_reason is not None: - print(f"Request {request_id} finished at step {self.timer}.") - self.finished_requests += 1 - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - finish_time[int(request_id)] = self.timer - - return fcallback - - def step(self) -> None: - self.timer += 1 - - def all_finished(self) -> bool: - return self.finished_requests == num_requests +def test_completion(): # Create engine - timer = CallbackTimer() - engine = Engine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) - - # Create requests - requests = create_requests( - num_requests, - stop_token_id=stop_token_id, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens_low, - max_tokens_high=max_tokens_high, + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = LLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, ) - # Assign the time to add requests to engine - request_add_time = [np.random.randint(0, 200) for _ in range(num_requests)] - - # Run steps - while not timer.all_finished(): - timer.step() - - # Add requests to engine - for req_id, add_time in enumerate(request_add_time): - if add_time == timer.timer: - print(f"add request {req_id} at step {timer.timer}") - engine.add_request(requests[req_id]) - - engine.step() + num_requests = 2 + max_tokens = 128 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + for rid in range(num_requests): + print(f"completion for request {rid}") + for response in engine.completions.create( + prompt=prompts[rid], + model=model, + max_tokens=max_tokens, + n=n, + ignore_eos=True, + request_id=str(rid), + stream=True, + ): + for choice in response.choices: + output_texts[rid][choice.index] += choice.text + + # Print output. 
+ print("Completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") - for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): - print(f"Prompt {req_id}: {request.inputs[0]}") - print(f"Finish time: {fin_time}") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + engine.terminate() + del engine -def test_engine_generate(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) +def test_completion_non_stream(): # Create engine - engine = Engine(model, kv_cache_config) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = LLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) - num_requests = 10 - max_tokens = 256 + num_requests = 2 + max_tokens = 128 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + for rid in range(num_requests): + print(f"completion for request {rid}") + response = engine.completions.create( + prompt=prompts[rid], + model=model, + max_tokens=max_tokens, + n=n, + ignore_eos=True, + request_id=str(rid), + ) + for choice in response.choices: + output_texts[rid][choice.index] += choice.text - # Generate output. - output_texts, _ = engine.generate( - prompts[:num_requests], GenerationConfig(max_tokens=max_tokens) - ) + # Print output. 
+ print("Completion all finished") for req_id, outputs in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") if len(outputs) == 1: @@ -391,10 +227,13 @@ def test_engine_generate(): for i, output in enumerate(outputs): print(f"Output {req_id}({i}):{output}\n") + engine.terminate() + del engine + if __name__ == "__main__": - test_engine_basic() - test_engine_continuous_batching_1() - test_engine_continuous_batching_2() - test_engine_continuous_batching_3() test_engine_generate() + test_chat_completion() + test_chat_completion_non_stream() + test_completion() + test_completion_non_stream() diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index 45926002ae..7f2a33b230 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -7,10 +7,9 @@ import pytest from pydantic import BaseModel -from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig -from mlc_llm.serve.async_engine import AsyncThreadedEngine +from mlc_llm.serve import AsyncLLMEngine, GenerationConfig from mlc_llm.serve.config import ResponseFormat -from mlc_llm.serve.engine import ModelInfo +from mlc_llm.serve.sync_engine import SyncLLMEngine prompts_list = [ "Generate a JSON string containing 20 objects:", @@ -22,11 +21,8 @@ def test_batch_generation_with_grammar(): - # Initialize model loading info and KV cache config - model = ModelInfo(model_path, model_lib_path=model_lib_path) - kv_cache_config = KVCacheConfig(page_size=16) # Create engine - engine = Engine(model, kv_cache_config) + engine = SyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompt_len = len(prompts_list) prompts = prompts_list * 3 @@ -72,11 +68,8 @@ def test_batch_generation_with_grammar(): def test_batch_generation_with_schema(): - # Initialize model loading info and KV cache config - model = ModelInfo(model_path, model_lib_path=model_lib_path) - kv_cache_config = KVCacheConfig(page_size=16) # Create engine - engine = Engine(model, kv_cache_config) + engine = SyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompt = ( "Generate a json containing three fields: an integer field named size, a " @@ -127,11 +120,8 @@ class Schema(BaseModel): async def run_async_engine(): - # Initialize model loading info and KV cache config - model = ModelInfo(model_path, model_lib_path=model_lib_path) - kv_cache_config = KVCacheConfig(page_size=16) # Create engine - async_engine = AsyncThreadedEngine(model, kv_cache_config, enable_tracing=True) + async_engine = AsyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompts = prompts_list * 20 @@ -152,14 +142,14 @@ async def run_async_engine(): ] async def generate_task( - async_engine: AsyncThreadedEngine, + async_engine: AsyncLLMEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, ): print(f"Start generation task for request {request_id}") rid = int(request_id) - async for delta_outputs in async_engine.generate( + async for delta_outputs in async_engine._generate( prompt, generation_cfg, request_id=request_id ): assert len(delta_outputs) == generation_cfg.n @@ -185,8 +175,6 @@ async def generate_task( for i, output in enumerate(outputs): print(f"Output {req_id}({i}):{output}\n") - print(async_engine.state.trace_recorder.dump_json(), file=open("tmpfiles/tmp.json", "w")) - async_engine.terminate() diff --git a/tests/python/serve/test_serve_engine_image.py 
b/tests/python/serve/test_serve_engine_image.py index 5b23a245f9..ff64e7235b 100644 --- a/tests/python/serve/test_serve_engine_image.py +++ b/tests/python/serve/test_serve_engine_image.py @@ -1,33 +1,38 @@ -from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig, data -from mlc_llm.serve.engine import ModelInfo -from mlc_llm.serve.entrypoints.entrypoint_utils import get_image_from_url +import json +from pathlib import Path +from mlc_llm.serve import GenerationConfig, data +from mlc_llm.serve.sync_engine import SyncLLMEngine -def get_test_image(): - return get_image_from_url("https://llava-vl.github.io/static/images/view.jpg") + +def get_test_image(config) -> data.ImageData: + return data.ImageData.from_url("https://llava-vl.github.io/static/images/view.jpg", config) def test_engine_generate(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/llava-1.5-7b-hf-q4f16_1-MLC/params", - model_lib_path="dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - engine = Engine(model, kv_cache_config) - + model = "dist/llava-1.5-7b-hf-q4f16_1-MLC/params" + model_lib_path = "dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so" + engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) max_tokens = 256 + with open(Path(model) / "mlc-chat-config.json", "r", encoding="utf-8") as file: + model_config = json.load(file) + prompts = [ [ data.TextData("USER: "), - data.ImageData(get_test_image(), 576), + get_test_image(model_config), data.TextData("\nWhat does this image represent? ASSISTANT:"), ], [ data.TextData("USER: "), - data.ImageData(get_test_image(), 576), + get_test_image(model_config), data.TextData("\nIs there a dog in this image? ASSISTANT:"), ], [data.TextData("USER: What is the meaning of life? ASSISTANT:")], diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index 828146afc9..60be02ce1a 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -1,19 +1,17 @@ # pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +# pylint: disable=too-many-arguments,too-many-locals from typing import Callable, List, Optional import numpy as np from mlc_llm.serve import ( - Engine, - EngineMode, GenerationConfig, - KVCacheConfig, Request, RequestStreamOutput, + SpeculativeMode, data, ) -from mlc_llm.serve.engine import ModelInfo +from mlc_llm.serve.sync_engine import SyncLLMEngine prompts = [ "What is the meaning of life?", @@ -68,17 +66,73 @@ def test_engine_basic(): requests + max_tokens - 1). Then check the output of each request. """ - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", + # Hyperparameters for tests (you can try different combinations). 
+ num_requests = len(prompts) # [4, 8, 10] + temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.0 # [1.0, 1.01] + max_tokens: int = 256 # [32, 128, 256] + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + + # Define the callback function for request generation results + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + small_model_lib_path = ( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" + ) + engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + speculative_mode=SpeculativeMode.SMALL_DRAFT, + request_stream_callback=fcallback, ) - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens, + max_tokens_high=max_tokens + 1, ) - kv_cache_config = KVCacheConfig(page_size=16) - engine_mode = EngineMode(enable_speculative=True) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max_tokens - 1 + # Run steps + for step in range(num_steps): + engine.step() + + for req_id, output in enumerate(outputs): + print(f"Prompt {req_id}: {requests[req_id].inputs[0]}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + + +def test_engine_eagle_basic(): + """Test engine **without continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have the same max_tokens. This means all requests + will end together. + - Engine keeps running `step` for estimated number of steps (number of + requests + max_tokens - 1). Then check the output of each request. + - Use Eagle model as speculative model + """ # Hyperparameters for tests (you can try different combinations). num_requests = len(prompts) # [4, 8, 10] @@ -98,7 +152,22 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine - engine = Engine([model, ssm], kv_cache_config, engine_mode, fcallback) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + small_model = "dist/Eagle-llama2-7b-chat-q0f16-MLC" + small_model_lib_path = ( + "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" + ) + engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + speculative_mode=SpeculativeMode.EAGLE, + spec_draft_length=2, + request_stream_callback=fcallback, + ) # Create requests requests = create_requests( @@ -134,17 +203,91 @@ def test_engine_continuous_batching_1(): of each request. 
""" - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", + # Hyperparameters for tests (you can try different combinations) + num_requests = len(prompts) # [4, 8, 10] + temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.00 # [1.0, 1.01] + max_tokens_low = 128 + max_tokens_high = 384 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + finish_time = [None] * num_requests + + # Define the callback class for request generation results + class CallbackTimer: + timer: int = -1 + + def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + if stream_outputs[0].finish_reason is not None: + print(f"Request {request_id} finished at step {self.timer}.") + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + finish_time[int(request_id)] = self.timer + + return fcallback + + def step(self) -> None: + self.timer += 1 + + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + small_model_lib_path = ( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + timer = CallbackTimer() + engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + speculative_mode=SpeculativeMode.SMALL_DRAFT, + request_stream_callback=timer.callback_getter(), ) - kv_cache_config = KVCacheConfig(page_size=16) - engine_mode = EngineMode(enable_speculative=True) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens_low, + max_tokens_high=max_tokens_high, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1 + # Run steps + for step in range(num_steps): + timer.step() + assert timer.timer == step + engine.step() + + for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): + print(f"Prompt {req_id}: {request.inputs[0]}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + # assert fin_time == request.generation_config.max_tokens - 1 + + +def test_engine_eagle_continuous_batching_1(): + """Test engine **with continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have a random maximum generation length. So each + request keeps generating until reaching the maximum length. + - Engine keeps running `step` for estimated number of steps (number of + requests + the maximum max_tokens - 1). Then check the output + of each request. 
+ """ # Hyperparameters for tests (you can try different combinations) num_requests = len(prompts) # [4, 8, 10] @@ -178,8 +321,22 @@ def step(self) -> None: self.timer += 1 # Create engine + model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" + small_model = "dist/Eagle-llama2-7b-chat-q4f16_1-MLC" + small_model_lib_path = ( + "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" + ) timer = CallbackTimer() - engine = Engine([model, ssm], kv_cache_config, engine_mode, timer.callback_getter()) + engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + speculative_mode=SpeculativeMode.EAGLE, + request_stream_callback=timer.callback_getter(), + ) # Create requests requests = create_requests( @@ -208,19 +365,54 @@ def step(self) -> None: def test_engine_generate(): - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + small_model_lib_path = ( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" + ) + engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + speculative_mode=SpeculativeMode.SMALL_DRAFT, ) - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + + num_requests = 10 + max_tokens = 256 + + # Generate output. 
+ output_texts, _ = engine.generate( + prompts[:num_requests], GenerationConfig(max_tokens=max_tokens, n=3) ) - kv_cache_config = KVCacheConfig(page_size=16) - engine_mode = EngineMode(enable_speculative=True) + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + +def test_engine_eagle_generate(): # Create engine - engine = Engine([model, ssm], kv_cache_config, engine_mode) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + small_model = "dist/Eagle-llama2-7b-chat-q4f16_1-MLC" + small_model_lib_path = ( + "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" + ) + engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + speculative_mode=SpeculativeMode.EAGLE, + ) num_requests = 10 max_tokens = 256 @@ -241,13 +433,6 @@ def test_engine_generate(): def test_engine_efficiency(): """Test engine speculative decoding efficiency.""" - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-13b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16) - # Hyperparameters for tests (you can try different combinations). num_requests = 1 # [4, 8, 10] temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] @@ -266,7 +451,15 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine - engine = Engine(model, kv_cache_config, request_stream_callback=fcallback) + model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC" + model_lib_path = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" + engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + request_stream_callback=fcallback, + ) # Create requests requests = create_requests( @@ -303,22 +496,80 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): def test_engine_spec_efficiency(): """Test engine speculative decoding efficiency.""" - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", + # Hyperparameters for tests (you can try different combinations). 
+ num_requests = 1 # [4, 8, 10] + temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.0 # [1.0, 1.01] + max_tokens: int = 512 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + + # Define the callback function for request generation results + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + + # Create engine + model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC" + model_lib_path = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" + small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + small_model_lib_path = ( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) # If Flashinfer allows head_dim < 128, we can test this model - # ssm = ModelInfo( - # "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC", - # model_lib_path="dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so", + # small_model = "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC" + # small_model_lib_path = ( + # "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so" # ) - model = ModelInfo( - "dist/Llama-2-13b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so", + spec_engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + spec_draft_length=6, + speculative_mode=SpeculativeMode.SMALL_DRAFT, + request_stream_callback=fcallback, + ) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens, + max_tokens_high=max_tokens + 1, ) - kv_cache_config = KVCacheConfig(page_size=16) - engine_mode = EngineMode(enable_speculative=True, spec_draft_length=6) + + # Add all requests to engine + for request in requests: + spec_engine.add_request(request) + + num_steps = num_requests + max_tokens - 1 + # Run steps + for step in range(num_steps): + spec_engine.step() + + for eg, name in zip([spec_engine], ["Speculative Decoding"]): + stats = eg.stats() + print("engine name:", name) + if name == "Speculative Decoding": + print("total draft tokens:", stats["total_draft_tokens"]) + print("total accepted tokens:", stats["total_accepted_tokens"]) + print( + "Accept rate:", + stats["total_accepted_tokens"] / (1e-10 + stats["total_draft_tokens"]), + ) + print("engine total decode time:", stats["engine_total_decode_time"]) + print() + + +def test_engine_eagle_spec_efficiency(): + """Test engine speculative decoding efficiency.""" # Hyperparameters for tests (you can try different combinations). 
num_requests = 1 # [4, 8, 10] @@ -338,7 +589,22 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine - spec_engine = Engine([model, ssm], kv_cache_config, engine_mode, fcallback) + model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" + small_model = "dist/Eagle-llama2-7b-chat-q0f16-MLC" + small_model_lib_path = ( + "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" + ) + spec_engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + spec_draft_length=6, + speculative_mode=SpeculativeMode.EAGLE, + request_stream_callback=fcallback, + ) # Create requests requests = create_requests( @@ -374,7 +640,11 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): if __name__ == "__main__": test_engine_basic() + test_engine_eagle_basic() test_engine_continuous_batching_1() + test_engine_eagle_continuous_batching_1() test_engine_generate() + test_engine_eagle_generate() test_engine_efficiency() test_engine_spec_efficiency() + test_engine_eagle_spec_efficiency() diff --git a/tests/python/serve/test_serve_sync_engine.py b/tests/python/serve/test_serve_sync_engine.py new file mode 100644 index 0000000000..c5d521b02d --- /dev/null +++ b/tests/python/serve/test_serve_sync_engine.py @@ -0,0 +1,396 @@ +# pylint: disable=chained-comparison,line-too-long,missing-docstring, +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +from typing import Callable, List, Optional + +import numpy as np + +from mlc_llm.serve import GenerationConfig, Request, RequestStreamOutput, data +from mlc_llm.serve.sync_engine import SyncLLMEngine + +prompts = [ + "What is the meaning of life?", + "Introduce the history of Pittsburgh to me. Please elaborate in detail.", + "Write a three-day Seattle travel plan. Please elaborate in detail.", + "What is Alaska famous of? Please elaborate in detail.", + "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", + "What are the necessary components to assemble a desktop computer? Please elaborate in detail.", + "Why is Vitamin D important to human beings? Please elaborate in detail.", + "Where is milk tea originated from? Please elaborate in detail.", + "Where is the southernmost place in United States? Please elaborate in detail.", + "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? 
Please elaborate in detail.", +] + + +def create_requests( + num_requests: int, + stop_token_id: Optional[int] = None, + temperature: float = 0.8, + repetition_penalty: float = 1.0, + max_tokens_low: int = 256, + max_tokens_high: int = 257, +) -> List[Request]: + assert num_requests >= 0 and num_requests <= len(prompts) + + stop_token_ids = [stop_token_id] if stop_token_id is not None else [] + requests = [] + for req_id, prompt in zip(range(num_requests), prompts): + max_tokens = np.random.randint(max_tokens_low, max_tokens_high) + requests.append( + Request( + request_id=str(req_id), + inputs=data.TextData(prompt), + generation_config=GenerationConfig( + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + ), + ) + ) + return requests + + +def test_engine_basic(): + """Test engine **without continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have the same max_tokens. This means all requests + will end together. + - Engine keeps running `step` for estimated number of steps (number of + requests + max_tokens - 1). Then check the output of each request. + """ + + # Hyperparameters for tests (you can try different combinations). + num_requests = 10 # [4, 8, 10] + temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.0 # [1.0, 1.01] + max_tokens: int = 256 # [32, 128, 256] + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + + # Define the callback function for request generation results + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + request_stream_callback=fcallback, + ) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens, + max_tokens_high=max_tokens + 1, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max_tokens - 1 + # Run steps + for step in range(num_steps): + engine.step() + + for req_id, output in enumerate(outputs): + print(f"Prompt {req_id}: {requests[req_id].inputs[0]}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + + +def test_engine_continuous_batching_1(): + """Test engine **with continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have a random maximum generation length. So each + request keeps generating until reaching the maximum length. + - Engine keeps running `step` for estimated number of steps (number of + requests + the maximum max_tokens - 1). Then check the output + of each request. 
+ """ + + # Hyperparameters for tests (you can try different combinations) + num_requests = 10 # [4, 8, 10] + temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.00 # [1.0, 1.01] + max_tokens_low = 128 + max_tokens_high = 384 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + finish_time = [None] * num_requests + + # Define the callback class for request generation results + class CallbackTimer: + timer: int = -1 + + def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + if stream_outputs[0].finish_reason is not None: + print(f"Request {request_id} finished at step {self.timer}.") + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + finish_time[int(request_id)] = self.timer + + return fcallback + + def step(self) -> None: + self.timer += 1 + + # Create engine + timer = CallbackTimer() + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + request_stream_callback=timer.callback_getter(), + ) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens_low, + max_tokens_high=max_tokens_high, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1 + # Run steps + for step in range(num_steps): + timer.step() + assert timer.timer == step + engine.step() + + for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): + print(f"Prompt {req_id}: {request.inputs[0]}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + assert ( + fin_time == request.generation_config.max_tokens - 1 + ), f"finish time = {fin_time}, max tokens = {request.generation_config.max_tokens - 1}" + + +def test_engine_continuous_batching_2(): + """Test engine **with continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have the stop token. So each request keeps generating + until having the stop token or reaching the maximum length. + - Engine keeps running `step` for estimated number of steps (number of + requests + the maximum max_tokens - 1). Then check the output + of each request. 
+ """ + + # Hyperparameters for tests (you can try different combinations) + num_requests = 10 # [4, 8, 10] + temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.00 # [1.0, 1.01] + stop_token_id = 2 + max_tokens = 512 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + finish_time = [None] * num_requests + + # Define the callback class for request generation results + class CallbackTimer: + timer: int = -1 + + def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + if stream_outputs[0].finish_reason is not None: + print(f"Request {request_id} finished at step {self.timer}.") + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + finish_time[int(request_id)] = self.timer + + return fcallback + + def step(self) -> None: + self.timer += 1 + + # Create engine + timer = CallbackTimer() + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + request_stream_callback=timer.callback_getter(), + ) + + # Create requests + requests = create_requests( + num_requests, + stop_token_id=stop_token_id, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens, + max_tokens_high=max_tokens + 1, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max_tokens - 1 + # Run steps + for step in range(num_steps): + timer.step() + assert timer.timer == step + engine.step() + + for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): + print(f"Prompt {req_id}: {request.inputs[0]}") + if fin_time < num_requests + max_tokens - 2: + print(f"Request {req_id} ends early on the stop token") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + + +def test_engine_continuous_batching_3(): + """Test engine **with continuous batching**. + + - Add requests randomly between time [0, 200). + - All requests have a random maximum generation length. So each + request keeps generating until reaching the maximum length. + - Engine keeps running `step` until all requests finish. + Then check the output of each request. 
+ """ + + # Hyperparameters for tests (you can try different combinations) + num_requests = 10 # [4, 8, 10] + temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.00 # [1.0, 1.01] + stop_token_id = 2 + max_tokens_low = 64 + max_tokens_high = 192 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + finish_time = [None] * num_requests + + # Define the callback class for request generation results + class CallbackTimer: + timer: int = -1 + finished_requests: int = 0 + + def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + if stream_outputs[0].finish_reason is not None: + print(f"Request {request_id} finished at step {self.timer}.") + self.finished_requests += 1 + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + finish_time[int(request_id)] = self.timer + + return fcallback + + def step(self) -> None: + self.timer += 1 + + def all_finished(self) -> bool: + return self.finished_requests == num_requests + + # Create engine + timer = CallbackTimer() + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + request_stream_callback=timer.callback_getter(), + ) + + # Create requests + requests = create_requests( + num_requests, + stop_token_id=stop_token_id, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens_low, + max_tokens_high=max_tokens_high, + ) + + # Assign the time to add requests to engine + request_add_time = [np.random.randint(0, 200) for _ in range(num_requests)] + + # Run steps + while not timer.all_finished(): + timer.step() + + # Add requests to engine + for req_id, add_time in enumerate(request_add_time): + if add_time == timer.timer: + print(f"add request {req_id} at step {timer.timer}") + engine.add_request(requests[req_id]) + + engine.step() + + for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): + print(f"Prompt {req_id}: {request.inputs[0]}") + print(f"Finish time: {fin_time}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + + +def test_engine_generate(): + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) + + num_requests = 10 + max_tokens = 256 + + # Generate output. + output_texts, _ = engine.generate( + prompts[:num_requests], GenerationConfig(max_tokens=max_tokens, n=7) + ) + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + +if __name__ == "__main__": + test_engine_basic() + test_engine_continuous_batching_1() + test_engine_continuous_batching_2() + test_engine_continuous_batching_3() + test_engine_generate()