From 7ab58f70a289043fd951ecf1ff7a6714da0d95da Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 19 Mar 2024 14:33:25 -0500 Subject: [PATCH] Upstream sync 2024 03 18 (#134) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SUMMARY: * upstream merge (sync) up to `93348d9458af7517bb8c114611d438a1b4a2c3be` * some minor changes related to `ruff` and `yapf` NOTES: skipped flaky lora gemma test TEST PLAN: ran nightly, passed all except gemma running now on remote push --------- Signed-off-by: Tao He Signed-off-by: Yuan Tang Signed-off-by: Sherlock113 Co-authored-by: Ronen Schaffer Co-authored-by: Mustafa Eyceoz Co-authored-by: Roy Co-authored-by: Woosuk Kwon Co-authored-by: Massimiliano Pronesti Co-authored-by: 44670 <44670@users.noreply.github.com> Co-authored-by: zhaoyang-star Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Jared Moore <27744679+jlcmoore@users.noreply.github.com> Co-authored-by: Philipp Moritz Co-authored-by: Cade Daniel Co-authored-by: 张大成 <1345739055@qq.com> Co-authored-by: zhangdacheng Co-authored-by: Jingru Co-authored-by: Dylan Hawk <51147702+dylanwhawk@users.noreply.github.com> Co-authored-by: Tao He Co-authored-by: Ganesh Jagadeesan Co-authored-by: Allen.Dou Co-authored-by: Liangfu Chen Co-authored-by: CHU Tianxiang Co-authored-by: Jae-Won Chung Co-authored-by: Seonghyeon Co-authored-by: Billy Cao Co-authored-by: Nick Hill Co-authored-by: felixzhu555 <79335195+felixzhu555@users.noreply.github.com> Co-authored-by: br3no Co-authored-by: simon-mo Co-authored-by: Sherry <503147114@qq.com> Co-authored-by: Yuan Tang Co-authored-by: Huarong Co-authored-by: huohuarong Co-authored-by: Robert Shaw <114415538+rib-2@users.noreply.github.com> Co-authored-by: alexm Co-authored-by: zixiao Co-authored-by: cloudhan Co-authored-by: Sage Moore Co-authored-by: ElizaWszola Co-authored-by: Michael Goin Co-authored-by: Jason Cox Co-authored-by: Zhuohan Li Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> Co-authored-by: TianYu GUO Co-authored-by: Jialun Lyu <43287111+pian13131@users.noreply.github.com> Co-authored-by: ttbachyinsda Co-authored-by: guofangze Co-authored-by: Antoni Baum Co-authored-by: Avnish Narayan Co-authored-by: Chen Wang Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Co-authored-by: lcskrishna Co-authored-by: SangBin Cho Co-authored-by: Chujie Zheng Co-authored-by: TechxGenus Co-authored-by: Michael Goin Co-authored-by: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com> Co-authored-by: whyiug Co-authored-by: Terry <149540247+tterrysun@users.noreply.github.com> Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com> Co-authored-by: kliuae <17350011+kliuae@users.noreply.github.com> Co-authored-by: DAIZHENWEI <32122197+DAIZHENWEI@users.noreply.github.com> Co-authored-by: Sherlock Xu <65327072+Sherlock113@users.noreply.github.com> Co-authored-by: Bo-Wen Wang <1849994161@qq.com> Co-authored-by: Ronan McGovern <78278410+RonanKMcGovern@users.noreply.github.com> Co-authored-by: Hui Liu <96135754+hliuca@users.noreply.github.com> Co-authored-by: 陈序 Co-authored-by: Or Sharir Co-authored-by: youkaichao Co-authored-by: Thomas Parnell Co-authored-by: Dan Clark <44146800+declark1@users.noreply.github.com> Co-authored-by: Daniel Clark Co-authored-by: youkaichao Co-authored-by: Enrique Shockwave <33002121+qeternity@users.noreply.github.com> Co-authored-by: 
akhoroshev Co-authored-by: Dinghow Yang Co-authored-by: Junda Chen <32371474+GindaChen@users.noreply.github.com> Co-authored-by: Yang Fan Co-authored-by: laneeee <55518470+laneeeee@users.noreply.github.com> --- .buildkite/test-pipeline.yaml | 15 +- .buildkite/test-template.j2 | 3 + .github/ISSUE_TEMPLATE/100-documentation.yml | 2 +- .../ISSUE_TEMPLATE/500-feature request.yml | 2 +- .github/PULL_REQUEST_TEMPLATE.md | 60 ++++++ .github/workflows/ruff.yml | 4 +- CONTRIBUTING.md | 26 +-- benchmarks/backend_request_func.py | 21 +- benchmarks/benchmark_serving.py | 6 +- collect_env.py | 185 ++++++++++-------- csrc/moe_align_block_size_kernels.cu | 42 ++-- csrc/punica/bgmv/generator.py | 2 +- docs/source/dev/kernel/paged_attention.rst | 2 +- examples/multilora_inference.py | 69 ++++--- examples/offline_inference_with_prefix.py | 7 +- examples/template_baichuan.jinja | 29 +-- examples/template_chatglm.jinja | 18 ++ examples/template_chatglm2.jinja | 18 ++ examples/template_falcon.jinja | 15 ++ examples/template_falcon_180b.jinja | 17 ++ pyproject.toml | 2 - requirements-dev.txt | 1 + setup.py | 64 ++++-- tests/async_engine/test_api_server.py | 16 +- tests/conftest.py | 11 ++ tests/entrypoints/test_openai_server.py | 50 +++++ tests/kernels/test_prefix_prefill.py | 11 +- tests/lora/test_gemma.py | 2 +- tests/lora/test_tokenizer.py | 69 ------- tests/lora/test_tokenizer_group.py | 53 +++++ tests/models/test_marlin.py | 1 + tests/test_cache_block_hashing.py | 8 +- tests/tokenization/__init__.py | 0 tests/tokenization/test_cached_tokenizer.py | 20 ++ .../test_detokenize.py | 0 tests/tokenization/test_tokenizer_group.py | 100 ++++++++++ vllm/__init__.py | 1 + vllm/config.py | 58 ++++++ vllm/core/scheduler.py | 2 +- vllm/engine/arg_utils.py | 65 ++++-- vllm/engine/async_llm_engine.py | 3 +- vllm/engine/llm_engine.py | 19 +- vllm/entrypoints/api_server.py | 11 +- vllm/entrypoints/openai/api_server.py | 15 +- vllm/entrypoints/openai/protocol.py | 9 + vllm/entrypoints/openai/serving_chat.py | 4 +- vllm/entrypoints/openai/serving_completion.py | 5 +- vllm/model_executor/guided_decoding.py | 54 +++-- .../guided_logits_processors.py | 112 +++++++---- vllm/model_executor/input_metadata.py | 49 ++--- .../layers/quantization/marlin.py | 2 +- vllm/model_executor/models/qwen2.py | 14 +- .../parallel_utils/communication_op.py | 2 +- vllm/transformers_utils/tokenizer.py | 99 ++++------ .../tokenizer_group/__init__.py | 32 +++ .../tokenizer_group/base_tokenizer_group.py | 48 +++++ .../tokenizer_group/ray_tokenizer_group.py | 166 ++++++++++++++++ .../tokenizer_group/tokenizer_group.py | 80 ++++++++ vllm/utils.py | 26 ++- vllm/worker/model_runner.py | 37 +--- 60 files changed, 1351 insertions(+), 513 deletions(-) create mode 100644 .github/PULL_REQUEST_TEMPLATE.md create mode 100644 examples/template_chatglm.jinja create mode 100644 examples/template_chatglm2.jinja create mode 100644 examples/template_falcon.jinja create mode 100644 examples/template_falcon_180b.jinja delete mode 100644 tests/lora/test_tokenizer.py create mode 100644 tests/lora/test_tokenizer_group.py create mode 100644 tests/tokenization/__init__.py create mode 100644 tests/tokenization/test_cached_tokenizer.py rename tests/{engine => tokenization}/test_detokenize.py (100%) create mode 100644 tests/tokenization/test_tokenizer_group.py create mode 100644 vllm/transformers_utils/tokenizer_group/__init__.py create mode 100644 vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py create mode 100644 
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py create mode 100644 vllm/transformers_utils/tokenizer_group/tokenizer_group.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 42a1eacb6de57..2c7dd9f304b9d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -13,7 +13,7 @@ steps: - label: Basic Correctness Test command: pytest -v -s --forked basic_correctness - + - label: Core Test command: pytest -v -s core @@ -28,14 +28,14 @@ steps: num_gpus: 2 # only support 1 or 2 for now. - label: Engine Test - command: pytest -v -s engine test_sequence.py + command: pytest -v -s engine tokenization test_sequence.py - label: Entrypoints Test command: pytest -v -s entrypoints -- label: Kernels Test - command: pytest -v -s kernels - soft_fail: true +- label: Kernels Test %N + command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 4 - label: Models Test commands: @@ -55,8 +55,9 @@ steps: - label: Speculative decoding tests command: pytest -v -s spec_decode -- label: LoRA Test - command: pytest -v -s lora --forked +- label: LoRA Test %N + command: pytest -v -s lora --forked --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 4 - label: Metrics Test command: pytest -v -s metrics diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 7c1cf2b5a9b39..b5853a2f39383 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -20,6 +20,9 @@ steps: agents: queue: kubernetes soft_fail: {{ step.soft_fail or false }} + {% if step.parallelism %} + parallelism: {{ step.parallelism }} + {% endif %} retry: automatic: - exit_status: -1 # Agent was lost diff --git a/.github/ISSUE_TEMPLATE/100-documentation.yml b/.github/ISSUE_TEMPLATE/100-documentation.yml index 7ef052a525963..501c0aa48b887 100644 --- a/.github/ISSUE_TEMPLATE/100-documentation.yml +++ b/.github/ISSUE_TEMPLATE/100-documentation.yml @@ -1,7 +1,7 @@ name: 📚 Documentation description: Report an issue related to https://docs.vllm.ai/ title: "[Doc]: " -labels: ["doc"] +labels: ["documentation"] body: - type: textarea diff --git a/.github/ISSUE_TEMPLATE/500-feature request.yml b/.github/ISSUE_TEMPLATE/500-feature request.yml index 0dd5a3e5d14de..47a90628c76ce 100644 --- a/.github/ISSUE_TEMPLATE/500-feature request.yml +++ b/.github/ISSUE_TEMPLATE/500-feature request.yml @@ -1,7 +1,7 @@ name: 🚀 Feature request description: Submit a proposal/request for a new vllm feature title: "[Feature]: " -labels: ["feature"] +labels: ["feature request"] body: - type: markdown diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000000..46fda7eeef55e --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,60 @@ +
+<details>
+<summary><b> PR Checklist (Click to expand. Please read before submitting.) </b></summary>
+
+<p>Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.</p>
+
+<h3>PR Title and Classification</h3>
+<p>Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:</p>
+<ul>
+    <li><code>[Bugfix]</code> for bug fixes.</li>
+    <li><code>[CI/Build]</code> for build or continuous integration improvements.</li>
+    <li><code>[Doc]</code> for documentation fixes and improvements.</li>
+    <li><code>[Model]</code> for adding a new model or improving an existing model. Model name should appear in the title.</li>
+    <li><code>[Frontend]</code> For changes on the vLLM frontend (e.g., OpenAI API server, <code>LLM</code> class, etc.)</li>
+    <li><code>[Kernel]</code> for changes affecting CUDA kernels or other compute kernels.</li>
+    <li><code>[Core]</code> for changes in the core vLLM logic (e.g., <code>LLMEngine</code>, <code>AsyncLLMEngine</code>, <code>Scheduler</code>, etc.)</li>
+    <li><code>[Hardware][Vendor]</code> for hardware-specific changes. Vendor name should appear in the prefix (e.g., <code>[Hardware][AMD]</code>).</li>
+    <li><code>[Misc]</code> for PRs that do not fit the above categories. Please use this sparingly.</li>
+</ul>
+<p><strong>Note:</strong> If the PR spans more than one category, please include all relevant prefixes.</p>
+
+<h3>Code Quality</h3>
+
+<p>The PR need to meet the following code quality standards:</p>
+
+<ul>
+    <li>We adhere to <a href="https://google.github.io/styleguide/pyguide.html">Google Python style guide</a> and <a href="https://google.github.io/styleguide/cppguide.html">Google C++ style guide</a>.</li>
+    <li>Pass all linter checks. Please use <code>format.sh</code> to format your code.</li>
+    <li>The code need to be well-documented to ensure future contributors can easily understand the code.</li>
+    <li>Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.</li>
+    <li>Please add documentation to <code>docs/source/</code> if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.</li>
+</ul>
+
+<h3>Notes for Large Changes</h3>
+<p>Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with <code>rfc-required</code> and might not go through the PR.</p>
+
+<h3>What to Expect for the Reviews</h3>
+
+<p>The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:</p>
+
+<ul>
+    <li>After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.</li>
+    <li>After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.</li>
+    <li>After the review, the reviewer will put an <code>action-required</code> label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.</li>
+    <li>Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.</li>
+</ul>
+
+<h3>Thank You</h3>
+
+<p>Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!</p>
+
+</details>

+ + +
+ +--- + +Please provide a brief explanation of the motivation behind the PR and the changes it introduces. This helps reviewers understand the context and rationale for the contribution. If possible, please link existing issues this PR will resolve. + + diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 8f8f5ee3cc70c..cd16cecf21546 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -28,7 +28,7 @@ jobs: pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1 - name: Analysing the code with ruff run: | - ruff vllm tests + ruff . - name: Spelling check with codespell run: | - codespell --toml pyproject.toml \ No newline at end of file + codespell --toml pyproject.toml \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 93a4de73faa89..8db5e569b6aec 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -45,31 +45,9 @@ pytest tests/ If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it. If not, please file a new issue, providing as much relevant information as possible. -### Coding Style Guide +### Pull Requests & Code Reviews -In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html). - -We include a formatting script [`format.sh`](./format.sh) to format the code. - -### Pull Requests - -When submitting a pull request: - -1. Make sure your code has been rebased on top of the latest commit on the main branch. -2. Ensure code is properly formatted by running [`format.sh`](./format.sh). -3. Include a detailed description of the changes in the pull request. -Explain why you made the changes you did. -If your pull request fixes an open issue, please include a reference to it in the description. - -### Code Reviews - -All submissions, including submissions by project members, require a code review. -To make the review process as smooth as possible, please: - -1. Keep your changes as concise as possible. -If your pull request involves multiple unrelated changes, consider splitting it into separate pull requests. -2. Respond to all comments within a reasonable time frame. -If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion. +Please check the PR checklist in the [PR template](.github/PULL_REQUEST_TEMPLATE.md) for detailed guide for contribution. ### Thank You diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 8782f5546b21e..0745d3129a4f0 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -68,7 +68,7 @@ async def async_request_tgi( output.ttft = ttft output.latency = time.perf_counter() - st - body = data.decode("utf-8").lstrip("data:") # noqa + body = remove_prefix(data.decode("utf-8"), "data:") output.generated_text = json.loads(body)["generated_text"] output.success = True else: @@ -114,7 +114,7 @@ async def async_request_vllm( output.ttft = ttft output.latency = time.perf_counter() - st - # When streaming, '\0' is appended to the end of the response. + # When streaming, '\0' is appended to the end of response. 
body = data.decode("utf-8").strip("\0") output.generated_text = json.loads( body)["text"][0][len(request_func_input.prompt):] @@ -162,7 +162,7 @@ async def async_request_trt_llm( output.ttft = ttft output.latency = time.perf_counter() - st - body = data.decode("utf-8").lstrip("data:") # noqa + body = remove_prefix(data.decode("utf-8"), "data:") output.generated_text = json.loads(body)["text_output"] output.success = True @@ -196,7 +196,8 @@ async def async_request_deepspeed_mii( output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len - # DeepSpeed-MII doesn't support streaming as of Jan 28 2024, will use 0 as placeholder. + # DeepSpeed-MII doesn't support streaming as of Jan 28 2024, + # will use 0 as placeholder. # https://github.com/microsoft/DeepSpeed-MII/pull/311 output.ttft = 0 @@ -259,7 +260,7 @@ async def async_request_openai_completions( if not chunk: continue - chunk = chunk.decode("utf-8").lstrip("data: ") # noqa + chunk = remove_prefix(chunk.decode("utf-8"), "data: ") if chunk == "[DONE]": latency = time.perf_counter() - st else: @@ -326,7 +327,7 @@ async def async_request_openai_chat_completions( if not chunk: continue - chunk = chunk.decode("utf-8").lstrip("data: ") + chunk = remove_prefix(chunk.decode("utf-8"), "data: ") if chunk == "[DONE]": latency = time.perf_counter() - st else: @@ -348,6 +349,14 @@ async def async_request_openai_chat_completions( return output +# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix) +# introduced in Python 3.9 +def remove_prefix(text: str, prefix: str) -> str: + if text.startswith(prefix): + return text[len(prefix):] + return text + + ASYNC_REQUEST_FUNCS = { "tgi": async_request_tgi, "vllm": async_request_vllm, diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 7699304769653..a097dd372b14f 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -295,7 +295,9 @@ def main(args: argparse.Namespace): # Save to file base_model_id = model_id.split("/")[-1] - file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" + file_name = ( + f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" + ) with open(file_name, "w") as outfile: json.dump(result_json, outfile) @@ -343,7 +345,7 @@ def main(args: argparse.Namespace): "--tokenizer", type=str, help= - "Name or path of the tokenizer, if not using the default model tokenizer.", + "Name or path of the tokenizer, if not using the default tokenizer.", ) parser.add_argument( "--best-of", diff --git a/collect_env.py b/collect_env.py index 3c914795222ee..edcbfe73b38d0 100644 --- a/collect_env.py +++ b/collect_env.py @@ -1,7 +1,4 @@ -# flake8: noqa -# UPSTREAM SYNC: noqa is required for passing ruff. -# This file has been modified by Neural Magic - +# ruff: noqa # code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py # Unlike the rest of the PyTorch this file must be python2 compliant. 
@@ -15,7 +12,6 @@ import os from collections import namedtuple - try: import torch TORCH_AVAILABLE = True @@ -23,38 +19,40 @@ TORCH_AVAILABLE = False # System Environment Information -SystemEnv = namedtuple('SystemEnv', [ - 'torch_version', - 'is_debug_build', - 'cuda_compiled_version', - 'gcc_version', - 'clang_version', - 'cmake_version', - 'os', - 'libc_version', - 'python_version', - 'python_platform', - 'is_cuda_available', - 'cuda_runtime_version', - 'cuda_module_loading', - 'nvidia_driver_version', - 'nvidia_gpu_models', - 'cudnn_version', - 'pip_version', # 'pip' or 'pip3' - 'pip_packages', - 'conda_packages', - 'hip_compiled_version', - 'hip_runtime_version', - 'miopen_runtime_version', - 'caching_allocator_config', - 'is_xnnpack_available', - 'cpu_info', - 'rocm_version', # vllm specific field - 'neuron_sdk_version', # vllm specific field - 'vllm_version', # vllm specific field - 'vllm_build_flags', # vllm specific field - 'gpu_topo', # vllm specific field -]) +SystemEnv = namedtuple( + 'SystemEnv', + [ + 'torch_version', + 'is_debug_build', + 'cuda_compiled_version', + 'gcc_version', + 'clang_version', + 'cmake_version', + 'os', + 'libc_version', + 'python_version', + 'python_platform', + 'is_cuda_available', + 'cuda_runtime_version', + 'cuda_module_loading', + 'nvidia_driver_version', + 'nvidia_gpu_models', + 'cudnn_version', + 'pip_version', # 'pip' or 'pip3' + 'pip_packages', + 'conda_packages', + 'hip_compiled_version', + 'hip_runtime_version', + 'miopen_runtime_version', + 'caching_allocator_config', + 'is_xnnpack_available', + 'cpu_info', + 'rocm_version', # vllm specific field + 'neuron_sdk_version', # vllm specific field + 'vllm_version', # vllm specific field + 'vllm_build_flags', # vllm specific field + 'gpu_topo', # vllm specific field + ]) DEFAULT_CONDA_PATTERNS = { "torch", @@ -81,8 +79,10 @@ def run(command): """Return (return-code, stdout, stderr).""" shell = True if type(command) is str else False - p = subprocess.Popen(command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, shell=shell) + p = subprocess.Popen(command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=shell) raw_output, raw_err = p.communicate() rc = p.returncode if get_platform() == 'win32': @@ -112,6 +112,7 @@ def run_and_parse_first_match(run_lambda, command, regex): return None return match.group(1) + def run_and_return_first_line(run_lambda, command): """Run command using run_lambda and returns first line if output is not empty.""" rc, out, _ = run_lambda(command) @@ -128,22 +129,23 @@ def get_conda_packages(run_lambda, patterns=None): if out is None: return out - return "\n".join( - line - for line in out.splitlines() - if not line.startswith("#") - and any(name in line for name in patterns) - ) + return "\n".join(line for line in out.splitlines() + if not line.startswith("#") and any(name in line + for name in patterns)) + def get_gcc_version(run_lambda): return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') + def get_clang_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)') + return run_and_parse_first_match(run_lambda, 'clang --version', + r'clang version (.*)') def get_cmake_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)') + return run_and_parse_first_match(run_lambda, 'cmake --version', + r'cmake (.*)') def get_nvidia_driver_version(run_lambda): @@ -152,11 +154,13 @@ def get_nvidia_driver_version(run_lambda): return 
run_and_parse_first_match(run_lambda, cmd, r'com[.]nvidia[.]CUDA [(](.*?)[)]') smi = get_nvidia_smi() - return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ') + return run_and_parse_first_match(run_lambda, smi, + r'Driver Version: (.*?) ') def get_gpu_info(run_lambda): - if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None): + if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr( + torch.version, 'hip') and torch.version.hip is not None): if TORCH_AVAILABLE and torch.cuda.is_available(): if torch.version.hip is not None: prop = torch.cuda.get_device_properties(0) @@ -178,7 +182,8 @@ def get_gpu_info(run_lambda): def get_running_cuda_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)') + return run_and_parse_first_match(run_lambda, 'nvcc --version', + r'release .+ V(.*)') def get_cudnn_version(run_lambda): @@ -223,8 +228,10 @@ def get_nvidia_smi(): smi = 'nvidia-smi' if get_platform() == 'win32': system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files') - legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi) + program_files_root = os.environ.get('PROGRAMFILES', + 'C:\\Program Files') + legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', + 'NVSMI', smi) new_path = os.path.join(system_root, 'System32', smi) smis = [new_path, legacy_path] for candidate_smi in smis: @@ -236,7 +243,8 @@ def get_nvidia_smi(): def get_rocm_version(run_lambda): """Returns the ROCm version if available, otherwise 'N/A'.""" - return run_and_parse_first_match(run_lambda, 'hipcc --version', r'HIP version: (\S+)') + return run_and_parse_first_match(run_lambda, 'hipcc --version', + r'HIP version: (\S+)') def get_neuron_sdk_version(run_lambda): @@ -346,13 +354,16 @@ def get_gpu_topo(run_lambda): # ProcessorType=3 # Revision=27142 + def get_cpu_info(run_lambda): rc, out, err = 0, '', '' if get_platform() == 'linux': rc, out, err = run_lambda('lscpu') elif get_platform() == 'win32': - rc, out, err = run_lambda('wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ - CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE') + rc, out, err = run_lambda( + 'wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ + CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE' + ) elif get_platform() == 'darwin': rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") cpu_info = 'None' @@ -377,18 +388,22 @@ def get_platform(): def get_mac_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') + return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', + r'(.*)') def get_windows_version(run_lambda): system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic') findstr_cmd = os.path.join(system_root, 'System32', 'findstr') - return run_and_read_all(run_lambda, '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd)) + return run_and_read_all( + run_lambda, + '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd)) def get_lsb_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') + return run_and_parse_first_match(run_lambda, 'lsb_release -a', + r'Description:\t(.*)') def 
check_release_file(run_lambda): @@ -447,11 +462,8 @@ def get_pip_packages(run_lambda, patterns=None): # But here it is invoked as `python -mpip` def run_with_pip(pip): out = run_and_read_all(run_lambda, pip + ["list", "--format=freeze"]) - return "\n".join( - line - for line in out.splitlines() - if any(name in line for name in patterns) - ) + return "\n".join(line for line in out.splitlines() + if any(name in line for name in patterns)) pip_version = 'pip3' if sys.version[0] == '3' else 'pip' out = run_with_pip([sys.executable, '-mpip']) @@ -476,10 +488,12 @@ def get_cuda_module_loading_config(): def is_xnnpack_available(): if TORCH_AVAILABLE: import torch.backends.xnnpack - return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] + return str( + torch.backends.xnnpack.enabled) # type: ignore[attr-defined] else: return "N/A" + def get_env_info(): run_lambda = run pip_version, pip_list_output = get_pip_packages(run_lambda) @@ -489,9 +503,11 @@ def get_env_info(): debug_mode_str = str(torch.version.debug) cuda_available_str = str(torch.cuda.is_available()) cuda_version_str = torch.version.cuda - if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version + if not hasattr(torch.version, + 'hip') or torch.version.hip is None: # cuda version hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' else: # HIP version + def get_version_or_na(cfg, prefix): _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] return _lst[0] if _lst else 'N/A' @@ -518,7 +534,9 @@ def get_version_or_na(cfg, prefix): return SystemEnv( torch_version=version_str, is_debug_build=debug_mode_str, - python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1), + python_version='{} ({}-bit runtime)'.format( + sys_version, + sys.maxsize.bit_length() + 1), python_platform=get_python_platform(), is_cuda_available=cuda_available_str, cuda_compiled_version=cuda_version_str, @@ -548,6 +566,7 @@ def get_version_or_na(cfg, prefix): gpu_topo=gpu_topo, ) + env_info_fmt = """ PyTorch version: {torch_version} Is debug build: {is_debug_build} @@ -592,6 +611,7 @@ def get_version_or_na(cfg, prefix): def pretty_str(envinfo): + def replace_nones(dct, replacement='Could not collect'): for key in dct.keys(): if dct[key] is not None: @@ -636,9 +656,10 @@ def maybe_start_on_next_line(string): 'nvidia_driver_version', ] all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] - all_dynamic_cuda_fields_missing = all( - mutable_dict[field] is None for field in dynamic_cuda_fields) - if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: + all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None + for field in dynamic_cuda_fields) + if TORCH_AVAILABLE and not torch.cuda.is_available( + ) and all_dynamic_cuda_fields_missing: for field in all_cuda_fields: mutable_dict[field] = 'No CUDA' if envinfo.cuda_compiled_version is None: @@ -651,17 +672,19 @@ def maybe_start_on_next_line(string): mutable_dict = replace_nones(mutable_dict) # If either of these are '', replace with 'No relevant packages' - mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages']) - mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages']) + mutable_dict['pip_packages'] = replace_if_empty( + mutable_dict['pip_packages']) + mutable_dict['conda_packages'] = replace_if_empty( + mutable_dict['conda_packages']) # Tag conda and pip packages with a prefix # If they were previously None, they'll show up as ie 
'[conda] Could not collect' if mutable_dict['pip_packages']: - mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'], - '[{}] '.format(envinfo.pip_version)) + mutable_dict['pip_packages'] = prepend( + mutable_dict['pip_packages'], '[{}] '.format(envinfo.pip_version)) if mutable_dict['conda_packages']: - mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'], - '[conda] ') + mutable_dict['conda_packages'] = prepend( + mutable_dict['conda_packages'], '[conda] ') mutable_dict['cpu_info'] = envinfo.cpu_info return env_info_fmt.format(**mutable_dict) @@ -675,18 +698,22 @@ def main(): output = get_pretty_env_info() print(output) - if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'): + if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr( + torch.utils, '_crash_handler'): minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR if sys.platform == "linux" and os.path.exists(minidump_dir): - dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)] + dumps = [ + os.path.join(minidump_dir, dump) + for dump in os.listdir(minidump_dir) + ] latest = max(dumps, key=os.path.getctime) ctime = os.path.getctime(latest) - creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S') + creation_time = datetime.datetime.fromtimestamp(ctime).strftime( + '%Y-%m-%d %H:%M:%S') msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ "if this is related to your bug please include it when you file a report ***" print(msg, file=sys.stderr) - if __name__ == '__main__': main() diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu index de6a0ec0a972c..138615a4bfba0 100644 --- a/csrc/moe_align_block_size_kernels.cu +++ b/csrc/moe_align_block_size_kernels.cu @@ -7,10 +7,17 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -const static size_t NUM_MAX_EXPERTS = 64; #define CEILDIV(x,y) (((x) + (y) - 1) / (y)) namespace vllm { + +namespace { +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; +} +} + template __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, int32_t *sorted_token_ids, @@ -21,10 +28,14 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, size_t numel) { const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread; - __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS]; - __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1]; + + extern __shared__ int32_t shared_mem[]; + + int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) + int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) + for (int i = 0; i < num_experts; ++i) { - tokens_cnts[threadIdx.x + 1][i] = 0; + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; } /** @@ -33,15 +44,15 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, * to expert expert_index. */ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[threadIdx.x + 1][topk_ids[i]]; + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; } __syncthreads(); // For each expert we accumulate the token counts from the different threads. 
- tokens_cnts[0][threadIdx.x] = 0; + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x]; + tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; } __syncthreads(); @@ -50,7 +61,7 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, if (threadIdx.x == 0) { cumsum[0] = 0; for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size; + cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; } *total_tokens_post_pad = cumsum[num_experts]; } @@ -78,9 +89,9 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, * stores the indices of the tokens processed by the expert with expert_id within * the current thread's token shard. */ - int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id]; + int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[threadIdx.x][expert_id]; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; } } } @@ -93,11 +104,16 @@ void moe_align_block_size( torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - assert(num_experts <= NUM_MAX_EXPERTS); VLLM_DISPATCH_INTEGRAL_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - vllm::moe_align_block_size_kernel<<<1, num_experts, 0, stream>>>( - topk_ids.data_ptr(), + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors + const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); + + // set dynamic shared mem + auto kernel = vllm::moe_align_block_size_kernel; + AT_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem)); + kernel<<<1, num_experts, shared_mem, stream>>>( + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py index 7ceaf9e6892a5..c347d4f2ab9f4 100644 --- a/csrc/punica/bgmv/generator.py +++ b/csrc/punica/bgmv/generator.py @@ -10,7 +10,7 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype}) -""".lstrip() # noqa: E501 (UPSTREAM SYNC nm-automation) +""".lstrip() # noqa: E501 for input_dtype in DTYPES: for output_dtype in DTYPES: diff --git a/docs/source/dev/kernel/paged_attention.rst b/docs/source/dev/kernel/paged_attention.rst index 6fcadeeec27b6..ba4f7a2718158 100644 --- a/docs/source/dev/kernel/paged_attention.rst +++ b/docs/source/dev/kernel/paged_attention.rst @@ -447,7 +447,7 @@ Value a whole block of value tokens. And each ``accs`` in each thread contains 8 elements that accumulated at 8 different head positions. For the thread 0, the ``accs`` variable will have 8 elements, which - are 0th, 16th … 112th elements of a value head that are accumulated + are 0th, 32th … 224th elements of a value head that are accumulated from all assigned 8 tokens. 
LV diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py index 7b1d580a9a7f6..9b85452c232f8 100644 --- a/examples/multilora_inference.py +++ b/examples/multilora_inference.py @@ -1,7 +1,8 @@ # flake8: noqa # UPSTREAM SYNC: noqa is required for passing ruff run on nm-automation """ -This example shows how to use the multi-LoRA functionality for offline inference. +This example shows how to use the multi-LoRA functionality +for offline inference. Requires HuggingFace credentials for access to Llama2. """ @@ -18,7 +19,7 @@ def create_test_prompts( lora_path: str ) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]: """Create a list of test prompts with their sampling parameters. - + 2 requests for base model, 4 requests for the LoRA. We define 2 different LoRA adapters (using the same model for demo purposes). Since we also set `max_loras=1`, the expectation is that the requests @@ -36,36 +37,40 @@ def create_test_prompts( top_k=5, presence_penalty=0.2, max_tokens=128), None), - ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", - SamplingParams(temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=128, - stop_token_ids=[32003]), - LoRARequest("sql-lora", 1, lora_path)), - ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", - SamplingParams(n=3, - best_of=3, - use_beam_search=True, - temperature=0, - max_tokens=128, - stop_token_ids=[32003]), - LoRARequest("sql-lora", 1, lora_path)), - ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", - SamplingParams(temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=128, - stop_token_ids=[32003]), - LoRARequest("sql-lora2", 2, lora_path)), - ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", - SamplingParams(n=3, - best_of=3, - use_beam_search=True, - temperature=0, - max_tokens=128, - stop_token_ids=[32003]), - LoRARequest("sql-lora", 1, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? 
[/user] [assistant]", # noqa: E501 + SamplingParams(n=3, + best_of=3, + use_beam_search=True, + temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora2", 2, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 + SamplingParams(n=3, + best_of=3, + use_beam_search=True, + temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), ] diff --git a/examples/offline_inference_with_prefix.py b/examples/offline_inference_with_prefix.py index 2c6c6aa63944d..15d66a4350a47 100644 --- a/examples/offline_inference_with_prefix.py +++ b/examples/offline_inference_with_prefix.py @@ -40,9 +40,10 @@ print("-" * 80) -# The llm.generate call will batch all prompts and send the batch at once if resources allow. -# The prefix will only be cached after the first batch is processed, so we need to call generate once -# to calculate the prefix and cache it. +# The llm.generate call will batch all prompts and send the batch at once +# if resources allow. The prefix will only be cached after the first batch +# is processed, so we need to call generate once to calculate the prefix +# and cache it. 
outputs = llm.generate(generating_prompts[0], sampling_params) # Subsequent batches can leverage the cached prefix diff --git a/examples/template_baichuan.jinja b/examples/template_baichuan.jinja index a1812a6c09ab1..42a8d9270a4c6 100644 --- a/examples/template_baichuan.jinja +++ b/examples/template_baichuan.jinja @@ -1,22 +1,13 @@ {{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} -{% for message in messages %} -{% if message['role'] == 'user' %} - -{{ message['content']|trim -}} -{% if not loop.last %} - - -{% endif %} -{% elif message['role'] == 'assistant' %} - -{{ message['content']|trim -}} -{% if not loop.last %} - - -{% endif %} -{% endif %} -{% endfor %} -{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %} - +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{- '' + message['content'] -}} + {%- elif message['role'] == 'assistant' -%} + {{- '' + message['content'] -}} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%} + {{- '' -}} {% endif %} \ No newline at end of file diff --git a/examples/template_chatglm.jinja b/examples/template_chatglm.jinja new file mode 100644 index 0000000000000..bf26f27274ef4 --- /dev/null +++ b/examples/template_chatglm.jinja @@ -0,0 +1,18 @@ +{%- set counter = namespace(index=0) -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{- '[Round ' + counter.index|string + ']\n问:' + message['content'] -}} + {%- set counter.index = counter.index + 1 -%} + {%- endif -%} + {%- if message['role'] == 'assistant' -%} + {{- '\n答:' + message['content'] -}} + {%- if (loop.last and add_generation_prompt) or not loop.last -%} + {{- '\n' -}} + {%- endif -%} + {%- endif -%} +{%- endfor -%} + + +{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%} + {{- '\n答:' -}} +{%- endif -%} \ No newline at end of file diff --git a/examples/template_chatglm2.jinja b/examples/template_chatglm2.jinja new file mode 100644 index 0000000000000..c155b7c23f640 --- /dev/null +++ b/examples/template_chatglm2.jinja @@ -0,0 +1,18 @@ +{%- set counter = namespace(index=1) -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{- '[Round ' + counter.index|string + ']\n\n问:' + message['content'] -}} + {%- set counter.index = counter.index + 1 -%} + {%- endif -%} + {%- if message['role'] == 'assistant' -%} + {{- '\n\n答:' + message['content'] -}} + {%- if (loop.last and add_generation_prompt) or not loop.last -%} + {{- '\n\n' -}} + {%- endif -%} + {%- endif -%} +{%- endfor -%} + + +{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%} + {{- '\n\n答:' -}} +{%- endif -%} \ No newline at end of file diff --git a/examples/template_falcon.jinja b/examples/template_falcon.jinja new file mode 100644 index 0000000000000..01cf0e2670d0f --- /dev/null +++ b/examples/template_falcon.jinja @@ -0,0 +1,15 @@ +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{- 'User: ' + message['content'] -}} + {%- elif message['role'] == 'assistant' -%} + {{- 'Assistant: ' + message['content'] -}} + {%- endif -%} + {%- if (loop.last and add_generation_prompt) or not loop.last -%} + {{- '\n' -}} + {%- endif -%} +{%- endfor -%} + + +{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%} + {{- 'Assistant:' -}} +{% endif %} \ No newline at end of file diff --git a/examples/template_falcon_180b.jinja 
b/examples/template_falcon_180b.jinja new file mode 100644 index 0000000000000..f08f7395b7fd7 --- /dev/null +++ b/examples/template_falcon_180b.jinja @@ -0,0 +1,17 @@ +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {{- 'System: ' + message['content'] -}} + {%- elif message['role'] == 'user' -%} + {{- 'User: ' + message['content'] -}} + {%- elif message['role'] == 'assistant' -%} + {{- 'Falcon: ' + message['content'] -}} + {%- endif -%} + {%- if (loop.last and add_generation_prompt) or not loop.last -%} + {{- '\n' -}} + {%- endif -%} +{%- endfor -%} + + +{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%} + {{- 'Falcon:' -}} +{% endif %} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d6fa5d7a035ff..e0a01215ef997 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,8 +33,6 @@ ignore = [ "F405", "F403", # lambda expression assignment "E731", - # .strip() with multi-character strings - "B005", # Loop control variable not used within loop body "B007", ] diff --git a/requirements-dev.txt b/requirements-dev.txt index 18de4b7420de2..00fa132b14c21 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -18,6 +18,7 @@ pytest pytest-forked pytest-asyncio pytest-rerunfailures +pytest-shard httpx einops # required for MPT openai diff --git a/setup.py b/setup.py index 6c1b4a91134d0..ef34e462072bc 100644 --- a/setup.py +++ b/setup.py @@ -13,12 +13,22 @@ from packaging.version import parse, Version import setuptools +import sys import torch import torch.utils.cpp_extension as torch_cpp_ext -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME +from torch.utils.cpp_extension import ( + BuildExtension, + CUDAExtension, + CUDA_HOME, + ROCM_HOME, +) ROOT_DIR = os.path.dirname(__file__) +# vLLM only supports Linux platform +assert sys.platform.startswith( + "linux"), "vLLM only supports Linux platform (including WSL)." + # If you are developing the C++ backend of vLLM, consider building vLLM with # `python setup.py develop` since it will give you incremental builds. # The downside is that this method is deprecated, see @@ -32,6 +42,10 @@ # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) +def _is_cuda() -> bool: + return torch.version.cuda is not None + + def _is_hip() -> bool: return torch.version.hip is not None @@ -40,15 +54,11 @@ def _is_neuron() -> bool: torch_neuronx_installed = True try: subprocess.run(["neuron-ls"], capture_output=True, check=True) - except (FileNotFoundError, PermissionError): + except (FileNotFoundError, PermissionError, subprocess.CalledProcessError): torch_neuronx_installed = False return torch_neuronx_installed -def _is_cuda() -> bool: - return (torch.version.cuda is not None) and not _is_neuron() - - # Compiler flags. CXX_FLAGS = ["-g", "-O2", "-std=c++17"] # TODO(woosuk): Should we use -O3? @@ -56,9 +66,8 @@ def _is_cuda() -> bool: if _is_hip(): if ROCM_HOME is None: - raise RuntimeError( - "Cannot find ROCM_HOME. ROCm must be available to build the package." - ) + raise RuntimeError("Cannot find ROCM_HOME. 
" + "ROCm must be available to build the package.") NVCC_FLAGS += ["-DUSE_ROCM"] NVCC_FLAGS += ["-U__HIP_NO_HALF_CONVERSIONS__"] NVCC_FLAGS += ["-U__HIP_NO_HALF_OPERATORS__"] @@ -143,7 +152,8 @@ def get_pytorch_rocm_arch() -> Set[str]: """ env_arch_list = os.environ.get("PYTORCH_ROCM_ARCH", None) - # If we don't have PYTORCH_ROCM_ARCH specified pull the list from rocm_agent_enumerator + # If we don't have PYTORCH_ROCM_ARCH specified pull the list from + # rocm_agent_enumerator if env_arch_list is None: command = "rocm_agent_enumerator" env_arch_list = (subprocess.check_output( @@ -254,11 +264,11 @@ def get_torch_arch_list() -> Set[str]: "CUDA 11.1 or higher is required for compute capability 8.6.") if nvcc_cuda_version < Version("11.8"): if any(cc.startswith("8.9") for cc in compute_capabilities): - # CUDA 11.8 is required to generate the code targeting compute capability 8.9. - # However, GPUs with compute capability 8.9 can also run the code generated by - # the previous versions of CUDA 11 and targeting compute capability 8.0. - # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 - # instead of 8.9. + # CUDA 11.8 is required to generate the code targeting compute + # capability 8.9. However, GPUs with compute capability 8.9 can + # also run the code generated by the previous versions of CUDA 11 + # and targeting compute capability 8.0. Therefore, if CUDA 11.8 + # is not available, we target compute capability 8.0 instead of 8.9. warnings.warn( "CUDA 11.8 or higher is required for compute capability 8.9. " "Targeting compute capability 8.0 instead.", @@ -394,7 +404,12 @@ def find_version(filepath: str) -> str: def get_vllm_version() -> str: version = find_version(get_path("vllm", "__init__.py")) - if _is_hip(): + if _is_cuda(): + cuda_version = str(nvcc_cuda_version) + if cuda_version != MAIN_CUDA_VERSION: + cuda_version_str = cuda_version.replace(".", "")[:3] + version += f"+cu{cuda_version_str}" + elif _is_hip(): # Get the HIP version hipcc_version = get_hipcc_rocm_version() if hipcc_version != MAIN_CUDA_VERSION: @@ -407,10 +422,7 @@ def get_vllm_version() -> str: neuron_version_str = neuron_version.replace(".", "")[:3] version += f"+neuron{neuron_version_str}" else: - cuda_version = str(nvcc_cuda_version) - if cuda_version != MAIN_CUDA_VERSION: - cuda_version_str = cuda_version.replace(".", "")[:3] - version += f"+cu{cuda_version_str}" + raise RuntimeError("Unknown runtime environment") return version @@ -426,7 +438,16 @@ def read_readme() -> str: def get_requirements() -> List[str]: """Get Python package dependencies from requirements.txt.""" - if _is_hip(): + if _is_cuda(): + with open(get_path("requirements.txt")) as f: + requirements = f.read().strip().split("\n") + if nvcc_cuda_version <= Version("11.8"): + # replace cupy-cuda12x with cupy-cuda11x for cuda 11.x + for i in range(len(requirements)): + if requirements[i].startswith("cupy-cuda12x"): + requirements[i] = "cupy-cuda11x" + break + elif _is_hip(): with open(get_path("requirements-rocm.txt")) as f: requirements = f.read().strip().split("\n") elif _is_neuron(): @@ -444,6 +465,7 @@ def get_requirements() -> List[str]: return requirements +# UPSTREAM SYNC: accept current _sparsity_deps = ["nm-magic-wand"] diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index ed9017c1e3e9d..248bfbc8ab5c0 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -25,23 +25,21 @@ def _query_server_long(prompt: str) -> dict: @pytest.fixture 
-def api_server(): +def api_server(tokenizer_pool_size: int): script_path = Path(__file__).parent.joinpath( "api_server_async_engine.py").absolute() uvicorn_process = subprocess.Popen([ - sys.executable, - "-u", - str(script_path), - "--model", - "facebook/opt-125m", - "--host", - "127.0.0.1", + sys.executable, "-u", + str(script_path), "--model", "facebook/opt-125m", "--host", + "127.0.0.1", "--tokenizer-pool-size", + str(tokenizer_pool_size) ]) yield uvicorn_process.terminate() -def test_api_server(api_server): +@pytest.mark.parametrize("tokenizer_pool_size", [0, 2]) +def test_api_server(api_server, tokenizer_pool_size: int): """ Run the API server and test it. diff --git a/tests/conftest.py b/tests/conftest.py index d2b6223ba9da6..1e7c8f971698f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ from vllm import LLM, SamplingParams from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.config import TokenizerPoolConfig _TEST_DIR = os.path.dirname(__file__) _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] @@ -391,3 +392,13 @@ def generate_greedy_logprobs( @pytest.fixture def vllm_runner_nm(): return VllmRunnerNm + + +def get_tokenizer_pool_config(tokenizer_group_type): + if tokenizer_group_type is None: + return None + if tokenizer_group_type == "ray": + return TokenizerPoolConfig(pool_size=1, + pool_type="ray", + extra_config={}) + raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 4a2b89befd93f..77c2f54e0daff 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -662,5 +662,55 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA)) +async def test_response_format_json_object(server, client: openai.AsyncOpenAI): + resp = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": + "user", + "content": ('what is 1+1? 
please respond with a JSON object, ' + 'the format is {"result": 2}') + }], + response_format={"type": "json_object"}) + + content = resp.choices[0].message.content + loaded = json.loads(content) + assert loaded == {"result": 2}, loaded + + +async def test_guided_grammar(server, client: openai.AsyncOpenAI): + simple_sql_grammar = """ +start: select_statement + +select_statement: "SELECT" column "from" table "where" condition + +column: "col_1" | "col_2" +table: "table_1" | "table_2" +condition: column "=" number + +number: "1" | "2" +""" + + completion = await client.completions.create( + model=MODEL_NAME, + prompt=("Generate a sql state that select col_1 from " + "table_1 where it is equals to 1"), + temperature=1.0, + max_tokens=500, + extra_body=dict(guided_grammar=simple_sql_grammar)) + + content = completion.choices[0].text + + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(simple_sql_grammar) + parser.parse(content) + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") + + assert content.strip() == ground_truth + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 204cc325f7da8..d41428e0a9ad3 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -18,7 +18,7 @@ @pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("num_queries_per_kv", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -38,6 +38,13 @@ def test_contexted_kv_attention( if torch.cuda.is_available(): torch.cuda.manual_seed(0) torch.set_default_device(device) + + # Need this, otherwise when we capture the graph the process for GPU 1 would + # run on both GPU0 and GPU1 and things would hang + # + # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 + torch.cuda.set_device(device) + MAX_SEQ_LEN = 1024 MAX_CTX_LEN = 1024 BS = 10 @@ -175,5 +182,5 @@ def test_contexted_kv_attention( torch.cuda.synchronize() end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") - output_ref = output_ref.squeeze(0, 2) + output_ref = output_ref.reshape(output.shape) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index 838f56f7fd05c..799133d1bf2f8 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -27,7 +27,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str: return generated_texts -@pytest.mark.skip(reason="high likelihood sproadic failure in GHA") +@pytest.mark.skip("Skipping for upstream sync") def test_gemma_lora(gemma_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, diff --git a/tests/lora/test_tokenizer.py b/tests/lora/test_tokenizer.py deleted file mode 100644 index 6c4c91fce8127..0000000000000 --- a/tests/lora/test_tokenizer.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest -from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import TokenizerGroup, get_lora_tokenizer - - -@pytest.mark.asyncio -async def test_transformers_tokenizer(): - reference_tokenizer = 
AutoTokenizer.from_pretrained("gpt2") - tokenizer = TokenizerGroup( - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - ) - assert reference_tokenizer.encode("prompt") == tokenizer.encode( - request_id="request_id", prompt="prompt", lora_request=None) - assert reference_tokenizer.encode( - "prompt") == await tokenizer.encode_async(request_id="request_id", - prompt="prompt", - lora_request=None) - assert isinstance(tokenizer.get_lora_tokenizer(None), - PreTrainedTokenizerBase) - assert tokenizer.get_lora_tokenizer( - None) == await tokenizer.get_lora_tokenizer_async(None) - - -@pytest.mark.asyncio -async def test_transformers_tokenizer_lora(sql_lora_files): - reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) - tokenizer = TokenizerGroup( - tokenizer_id="gpt2", - enable_lora=True, - max_num_seqs=1, - max_input_length=None, - ) - lora_request = LoRARequest("1", 1, sql_lora_files) - assert reference_tokenizer.encode("prompt") == tokenizer.encode( - request_id="request_id", prompt="prompt", lora_request=lora_request) - assert reference_tokenizer.encode( - "prompt") == await tokenizer.encode_async(request_id="request_id", - prompt="prompt", - lora_request=lora_request) - assert isinstance(tokenizer.get_lora_tokenizer(None), - PreTrainedTokenizerBase) - assert tokenizer.get_lora_tokenizer( - None) == await tokenizer.get_lora_tokenizer_async(None) - - assert isinstance(tokenizer.get_lora_tokenizer(lora_request), - PreTrainedTokenizerBase) - assert tokenizer.get_lora_tokenizer( - lora_request) != tokenizer.get_lora_tokenizer(None) - assert tokenizer.get_lora_tokenizer( - lora_request) == await tokenizer.get_lora_tokenizer_async(lora_request) - - -def test_get_lora_tokenizer(sql_lora_files, tmpdir): - lora_request = None - tokenizer = get_lora_tokenizer(lora_request) - assert not tokenizer - - lora_request = LoRARequest("1", 1, sql_lora_files) - tokenizer = get_lora_tokenizer(lora_request) - assert tokenizer.get_added_vocab() - - lora_request = LoRARequest("1", 1, str(tmpdir)) - tokenizer = get_lora_tokenizer(lora_request) - assert not tokenizer diff --git a/tests/lora/test_tokenizer_group.py b/tests/lora/test_tokenizer_group.py new file mode 100644 index 0000000000000..5fec3f179925a --- /dev/null +++ b/tests/lora/test_tokenizer_group.py @@ -0,0 +1,53 @@ +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer_group import get_tokenizer_group +from vllm.transformers_utils.tokenizer import get_lora_tokenizer +from ..conftest import get_tokenizer_pool_config + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"]) +async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): + reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) + tokenizer_group = get_tokenizer_group( + get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_id="gpt2", + enable_lora=True, + max_num_seqs=1, + max_input_length=None, + ) + lora_request = LoRARequest("1", 1, sql_lora_files) + assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( + request_id="request_id", prompt="prompt", lora_request=lora_request) + assert reference_tokenizer.encode( + "prompt") == await tokenizer_group.encode_async( + request_id="request_id", + prompt="prompt", + lora_request=lora_request) + assert isinstance(tokenizer_group.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert 
tokenizer_group.get_lora_tokenizer( + None) == await tokenizer_group.get_lora_tokenizer_async(None) + + assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request), + PreTrainedTokenizerBase) + assert tokenizer_group.get_lora_tokenizer( + lora_request) != tokenizer_group.get_lora_tokenizer(None) + assert tokenizer_group.get_lora_tokenizer( + lora_request) == await tokenizer_group.get_lora_tokenizer_async( + lora_request) + + +def test_get_lora_tokenizer(sql_lora_files, tmpdir): + lora_request = None + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer + + lora_request = LoRARequest("1", 1, sql_lora_files) + tokenizer = get_lora_tokenizer(lora_request) + assert tokenizer.get_added_vocab() + + lora_request = LoRARequest("1", 1, str(tmpdir)) + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 7c0382dfa7b34..3767c8331e010 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -1,3 +1,4 @@ +# UPSTREAM SYNC: Use the current downstream version. """Compare the outputs of a GPTQ model to a Marlin model. Note: GPTQ and Marlin do not have bitwise correctness. diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index fb541f38f3489..498472530973b 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -7,7 +7,7 @@ import pytest from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import TokenizerGroup +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.sequence import Sequence # Make two prefixes with different first blocks. @@ -66,8 +66,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, hashes.append([]) prompts = [prefix + prompt for prompt in sample_prompts] - seq_id = 0 - for prompt in prompts: + # UPSTREAM SYNC: seq_id in enumerate needed to pass ruff + for seq_id, prompt in enumerate(prompts): hashes[-1].append([]) prompt_token_ids = tokenizer.encode(prompt) seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, @@ -77,8 +77,6 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, for idx in range(num_blocks): hashes[-1][-1].append(seq.hash_of_block(idx)) - seq_id += 1 - # Check that hashes made with two prefixes with different first blocks are # different everywhere. 
for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])): diff --git a/tests/tokenization/__init__.py b/tests/tokenization/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tokenization/test_cached_tokenizer.py b/tests/tokenization/test_cached_tokenizer.py new file mode 100644 index 0000000000000..181e800325128 --- /dev/null +++ b/tests/tokenization/test_cached_tokenizer.py @@ -0,0 +1,20 @@ +from copy import deepcopy +from vllm.transformers_utils.tokenizer import get_cached_tokenizer +from transformers import AutoTokenizer + + +def test_cached_tokenizer(): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + reference_tokenizer.add_special_tokens({"cls_token": ""}) + reference_tokenizer.add_special_tokens( + {"additional_special_tokens": [""]}) + cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer)) + + assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode( + "prompt") + assert set(reference_tokenizer.all_special_ids) == set( + cached_tokenizer.all_special_ids) + assert set(reference_tokenizer.all_special_tokens) == set( + cached_tokenizer.all_special_tokens) + assert set(reference_tokenizer.all_special_tokens_extended) == set( + cached_tokenizer.all_special_tokens_extended) diff --git a/tests/engine/test_detokenize.py b/tests/tokenization/test_detokenize.py similarity index 100% rename from tests/engine/test_detokenize.py rename to tests/tokenization/test_detokenize.py diff --git a/tests/tokenization/test_tokenizer_group.py b/tests/tokenization/test_tokenizer_group.py new file mode 100644 index 0000000000000..d0788ee87563d --- /dev/null +++ b/tests/tokenization/test_tokenizer_group.py @@ -0,0 +1,100 @@ +import os +import pytest +import asyncio +from unittest.mock import patch + +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from vllm.transformers_utils.tokenizer_group import get_tokenizer_group +from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( + RayTokenizerGroupPool) +from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( + TokenizerGroup) +from ..conftest import get_tokenizer_pool_config + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"]) +async def test_tokenizer_group(tokenizer_group_type): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer_group = get_tokenizer_group( + get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None, + ) + assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( + request_id="request_id", prompt="prompt", lora_request=None) + assert reference_tokenizer.encode( + "prompt") == await tokenizer_group.encode_async( + request_id="request_id", prompt="prompt", lora_request=None) + assert isinstance(tokenizer_group.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer_group.get_lora_tokenizer( + None) == await tokenizer_group.get_lora_tokenizer_async(None) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) +async def test_tokenizer_group_pool(tokenizer_group_type): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer_group_pool = get_tokenizer_group( + get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None, + ) + # Send multiple requests to the tokenizer group pool + # (more than the pool size) + # and check that 
all requests are processed correctly. + num_requests = tokenizer_group_pool.pool_size * 5 + requests = [ + tokenizer_group_pool.encode_async(request_id=str(i), + prompt=f"prompt {i}", + lora_request=None) + for i in range(num_requests) + ] + results = await asyncio.gather(*requests) + expected_results = [ + reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests) + ] + assert results == expected_results + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) +async def test_tokenizer_group_ray_pool_env_var_propagation( + tokenizer_group_type): + """Test that env vars from caller process are propagated to + tokenizer Ray actors.""" + env_var = "MY_ENV_VAR" + + class EnvVarCheckerTokenizerGroup(TokenizerGroup): + + def ping(self): + assert os.environ.get(env_var) == "1" + return super().ping() + + class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool): + _worker_cls = EnvVarCheckerTokenizerGroup + + tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) + tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config( + tokenizer_pool_config, + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None) + with pytest.raises(AssertionError): + tokenizer_pool.ping() + + with patch.dict(os.environ, {env_var: "1"}): + tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) + tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config( + tokenizer_pool_config, + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None) + tokenizer_pool.ping() diff --git a/vllm/__init__.py b/vllm/__init__.py index 501b2dee6e789..101f1c7f55005 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -8,6 +8,7 @@ from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams +# UPSTREAM SYNC: use the current downstream. __version__ = "0.1.0" __all__ = [ diff --git a/vllm/config.py b/vllm/config.py index 56fe6b522e7ee..f1c3c912d4f8a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -5,6 +5,7 @@ import os from packaging.version import Version +import json import torch from transformers import PretrainedConfig @@ -201,6 +202,7 @@ def _verify_quantization(self) -> None: if hf_quant_config is not None: hf_quant_method = str(hf_quant_config["quant_method"]).lower() + # UPSTREAM SYNC: Accept current downstream. # If the GPTQ model is serialized in marlin format, use marlin. marlin_format_flag = "is_marlin_format" if (hf_quant_method == "gptq" @@ -415,6 +417,58 @@ def verify_with_parallel_config( logger.warning("Possibly too large swap space. " + msg) +@dataclass +class TokenizerPoolConfig: + """Configuration for the tokenizer pool. + + Args: + pool_size: Number of tokenizer workers in the pool. + pool_type: Type of the pool. + extra_config: Additional config for the pool. + The way the config will be used depends on the + pool type. + """ + pool_size: int + pool_type: str + extra_config: dict + + def __post_init__(self): + if self.pool_type not in ("ray", ): + raise ValueError(f"Unknown pool type: {self.pool_type}") + if not isinstance(self.extra_config, dict): + raise ValueError("extra_config must be a dictionary.") + + @classmethod + def create_config( + cls, tokenizer_pool_size: int, tokenizer_pool_type: str, + tokenizer_pool_extra_config: Optional[Union[str, dict]] + ) -> Optional["TokenizerPoolConfig"]: + """Create a TokenizerPoolConfig from the given parameters. + + If tokenizer_pool_size is 0, return None. 
+ + Args: + tokenizer_pool_size: Number of tokenizer workers in the pool. + tokenizer_pool_type: Type of the pool. + tokenizer_pool_extra_config: Additional config for the pool. + The way the config will be used depends on the + pool type. This can be a JSON string (will be parsed). + """ + if tokenizer_pool_size: + if isinstance(tokenizer_pool_extra_config, str): + tokenizer_pool_extra_config_parsed = json.loads( + tokenizer_pool_extra_config) + else: + tokenizer_pool_extra_config_parsed = ( + tokenizer_pool_extra_config or {}) + tokenizer_pool_config = cls(tokenizer_pool_size, + tokenizer_pool_type, + tokenizer_pool_extra_config_parsed) + else: + tokenizer_pool_config = None + return tokenizer_pool_config + + class ParallelConfig: """Configuration for the distributed execution. @@ -429,6 +483,8 @@ class ParallelConfig: parallel and large models. disable_custom_all_reduce: Disable the custom all-reduce kernel and fall back to NCCL. + tokenizer_pool_config: Config for the tokenizer pool. + If None, will use synchronous tokenization. ray_workers_use_nsight: Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. """ @@ -440,6 +496,7 @@ def __init__( worker_use_ray: bool, max_parallel_loading_workers: Optional[int] = None, disable_custom_all_reduce: bool = False, + tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, ray_workers_use_nsight: bool = False, placement_group: Optional["PlacementGroup"] = None, ) -> None: @@ -456,6 +513,7 @@ def __init__( self.worker_use_ray = worker_use_ray self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_custom_all_reduce = disable_custom_all_reduce + self.tokenizer_pool_config = tokenizer_pool_config self.ray_workers_use_nsight = ray_workers_use_nsight self.placement_group = placement_group diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 9255f91be55cb..c3f93a2928df5 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -160,7 +160,7 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_copy: Dict[int, List[int]] = {} # Fix the current time. - now = time.monotonic() + now = time.time() # Join waiting sequences if possible. if not self.swapped: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f4ca26b816964..34ba15c1c16c4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -6,7 +6,8 @@ from typing import Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig, LoRAConfig) + ParallelConfig, SchedulerConfig, LoRAConfig, + TokenizerPoolConfig) @dataclass @@ -43,6 +44,9 @@ class EngineArgs: enforce_eager: bool = False max_context_len_to_capture: int = 8192 disable_custom_all_reduce: bool = False + tokenizer_pool_size: int = 0 + tokenizer_pool_type: str = "ray" + tokenizer_pool_extra_config: Optional[dict] = None enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 @@ -262,6 +266,25 @@ def add_cli_args( action='store_true', default=EngineArgs.disable_custom_all_reduce, help='See ParallelConfig') + parser.add_argument('--tokenizer-pool-size', + type=int, + default=EngineArgs.tokenizer_pool_size, + help='Size of tokenizer pool to use for ' + 'asynchronous tokenization. If 0, will ' + 'use synchronous tokenization.') + parser.add_argument('--tokenizer-pool-type', + type=str, + default=EngineArgs.tokenizer_pool_type, + help='Type of tokenizer pool to use for ' + 'asynchronous tokenization. 
Ignored ' + 'if tokenizer_pool_size is 0.') + parser.add_argument('--tokenizer-pool-extra-config', + type=str, + default=EngineArgs.tokenizer_pool_extra_config, + help='Extra config for tokenizer pool. ' + 'This should be a JSON string that will be ' + 'parsed into a dictionary. Ignored if ' + 'tokenizer_pool_size is 0.') # LoRA related configs parser.add_argument('--enable-lora', action='store_true', @@ -316,23 +339,37 @@ def create_engine_configs( DeviceConfig, Optional[LoRAConfig]]: device_config = DeviceConfig(self.device) model_config = ModelConfig( - self.model, self.tokenizer, self.tokenizer_mode, - self.trust_remote_code, self.download_dir, self.load_format, - self.dtype, self.seed, self.revision, self.code_revision, - self.tokenizer_revision, self.max_model_len, self.quantization, - self.sparsity, self.enforce_eager, self.max_context_len_to_capture, + self.model, + self.tokenizer, + self.tokenizer_mode, + self.trust_remote_code, + self.download_dir, + self.load_format, + self.dtype, + self.seed, + self.revision, + self.code_revision, + self.tokenizer_revision, + self.max_model_len, + self.quantization, + # UPSTREAM SYNC: Accept current. + self.sparsity, + self.enforce_eager, + self.max_context_len_to_capture, self.max_logprobs) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, - model_config.get_sliding_window(), - self.enable_prefix_caching) - parallel_config = ParallelConfig(self.pipeline_parallel_size, - self.tensor_parallel_size, - self.worker_use_ray, - self.max_parallel_loading_workers, - self.disable_custom_all_reduce, - self.ray_workers_use_nsight) + model_config.get_sliding_window()) + parallel_config = ParallelConfig( + self.pipeline_parallel_size, self.tensor_parallel_size, + self.worker_use_ray, self.max_parallel_loading_workers, + self.disable_custom_all_reduce, + TokenizerPoolConfig.create_config( + self.tokenizer_pool_size, + self.tokenizer_pool_type, + self.tokenizer_pool_extra_config, + ), self.ray_workers_use_nsight) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 0cee604c14d45..8bcd1e0ede6e5 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -604,8 +604,7 @@ async def generate( >>> ... """ # Preprocess the request. - # This should not be used for logging, as it is monotonic time. - arrival_time = time.monotonic() + arrival_time = time.time() try: stream = await self.add_request( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cb1c7d48d154a..b4b354fe52e1e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -17,8 +17,9 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) -from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - TokenizerGroup) +from vllm.transformers_utils.tokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, + get_tokenizer_group) from vllm.utils import Counter logger = init_logger(__name__) @@ -103,6 +104,10 @@ def __init__( parallel_config, scheduler_config, device_config, lora_config) + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() + # Create the scheduler. 
# NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. @@ -153,6 +158,7 @@ def get_tokenizer_for_seq(self, def _init_tokenizer(self, **tokenizer_init_kwargs): init_kwargs = dict( + tokenizer_id=self.model_config.tokenizer, enable_lora=bool(self.lora_config), max_num_seqs=self.scheduler_config.max_num_seqs, max_input_length=None, @@ -160,8 +166,9 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer: TokenizerGroup = TokenizerGroup( - self.model_config.tokenizer, **init_kwargs) + + self.tokenizer: BaseTokenizerGroup = get_tokenizer_group( + self.parallel_config.tokenizer_pool_config, **init_kwargs) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -245,7 +252,7 @@ def add_request( raise ValueError(f"Cannot request more than " f"{max_logprobs} logprobs.") if arrival_time is None: - arrival_time = time.monotonic() + arrival_time = time.time() prompt_token_ids = self.encode_request( request_id=request_id, prompt=prompt, @@ -629,7 +636,7 @@ def do_log_stats(self) -> None: def _get_stats(self, scheduler_outputs: Optional[SchedulerOutputs]) -> Stats: """Get Stats to be Logged to Prometheus.""" - now = time.monotonic() + now = time.time() # KV Cache Usage in %. num_total_gpu = self.cache_config.num_gpu_blocks diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 5130586e036b2..ba93b1beb2aa4 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -8,6 +8,7 @@ import argparse import json +import ssl from typing import AsyncGenerator from fastapi import FastAPI, Request @@ -86,10 +87,12 @@ async def stream_results() -> AsyncGenerator[bytes, None]: type=str, default=None, help="The CA certificates file") - parser.add_argument("--ssl-cert-reqs", - type=int, - default=0, - help="Whether client certificate is required") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)" + ) parser.add_argument( "--root-path", type=str, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 00407bc0e809c..e0626ca4e9da1 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -5,6 +5,7 @@ import os import importlib import inspect +import ssl from prometheus_client import make_asgi_app import fastapi @@ -124,6 +125,16 @@ def parse_args(): type=str, default=None, help="The file path to the SSL cert file") + parser.add_argument("--ssl-ca-certs", + type=str, + default=None, + help="The CA certificates file") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)" + ) parser.add_argument( "--root-path", type=str, @@ -262,4 +273,6 @@ async def authentication(request: Request, call_next): log_level=args.uvicorn_log_level, timeout_keep_alive=TIMEOUT_KEEP_ALIVE, ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile) + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 26499b8d7a66f..9421880411611 100644 --- 
a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -55,6 +55,11 @@ class UsageInfo(BaseModel): completion_tokens: Optional[int] = 0 +class ResponseFormat(BaseModel): + # type must be "json_object" or "text" + type: str = Literal["text", "json_object"] + + class ChatCompletionRequest(BaseModel): model: str messages: List[Dict[str, str]] @@ -89,6 +94,8 @@ class ChatCompletionRequest(BaseModel): guided_json: Optional[Union[str, dict, BaseModel]] = None guided_regex: Optional[str] = None guided_choice: Optional[List[str]] = None + guided_grammar: Optional[str] = None + response_format: Optional[ResponseFormat] = None def to_sampling_params(self) -> SamplingParams: if self.logprobs and not self.top_logprobs: @@ -183,6 +190,8 @@ class CompletionRequest(BaseModel): guided_json: Optional[Union[str, dict, BaseModel]] = None guided_regex: Optional[str] = None guided_choice: Optional[List[str]] = None + guided_grammar: Optional[str] = None + response_format: Optional[ResponseFormat] = None def to_sampling_params(self): echo_without_generation = self.echo and self.max_tokens == 0 diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d2fb9ca001b15..bfdfe39f210ed 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -103,7 +103,7 @@ async def chat_completion_stream_generator( ) -> Union[ErrorResponse, AsyncGenerator[str, None]]: model_name = request.model - created_time = int(time.monotonic()) + created_time = int(time.time()) chunk_object_type = "chat.completion.chunk" first_iteration = True @@ -244,7 +244,7 @@ async def chat_completion_full_generator( request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = request.model - created_time = int(time.monotonic()) + created_time = int(time.time()) final_res: RequestOutput = None async for res in result_generator: diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index b78f053800f3c..5f2be878a7b76 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -118,7 +118,7 @@ async def create_completion(self, request: CompletionRequest, model_name = request.model request_id = f"cmpl-{random_uuid()}" - created_time = int(time.monotonic()) + created_time = int(time.time()) # Schedule the request and get the result generator. 
generators = [] @@ -309,10 +309,7 @@ async def completion_stream_generator( except ValueError as e: # TODO: Use a vllm-specific Validation Error data = self.create_streaming_error_response(str(e)) - print("yield", f"data: {data}\n\n") yield f"data: {data}\n\n" - - print("yield", "data: [DONE]\n\n") yield "data: [DONE]\n\n" def request_output_to_completion_response( diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index 00984460d79a6..bd09cf9cb6ee3 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -6,19 +6,50 @@ from json import dumps as json_dumps from re import escape as regex_escape from typing import Union, Tuple + from pydantic import BaseModel +from transformers import PreTrainedTokenizerBase from vllm.entrypoints.openai.protocol import (CompletionRequest, ChatCompletionRequest) from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor, - RegexLogitsProcessor) + RegexLogitsProcessor, + CFGLogitsProcessor) class GuidedDecodingMode(Enum): JSON = "json" REGEX = "regex" CHOICE = "choice" + GRAMMAR = "grammar" + + +# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark +# the main difference is that we changed the start: value to +# start: object | array, so we are denying scalar values as the root of the +# JSON. Starting with scalars as the root seems to cause llama to generate +# without stop. +JSON_GRAMMAR = r""" +?start: object | array + +?value: object +| array +| UNESCAPED_STRING +| SIGNED_NUMBER -> number +| "true" -> true +| "false" -> false +| "null" -> null + +array : "[" [value ("," value)*] "]" +object : "{" [pair ("," pair)*] "}" +pair : UNESCAPED_STRING ":" value + +%import common.UNESCAPED_STRING +%import common.SIGNED_NUMBER +%import common.WS +%ignore WS +""" global_thread_pool = None # used for generating logits processor fsm @@ -57,9 +88,6 @@ def _get_guide_and_mode( ) -> Tuple[str, GuidedDecodingMode]: if request.guided_json: - if not isinstance(request.guided_json, (str, dict, BaseModel)): - raise TypeError("JSON schema must be str, dict, or BaseModel") - json = request.guided_json if isinstance(json, dict): # turn dict into hashable string @@ -69,33 +97,33 @@ def _get_guide_and_mode( # with the same fields will get hashed the same json = str(json.__signature__) return json, GuidedDecodingMode.JSON - elif request.guided_regex: - if not isinstance(request.guided_regex, str): - raise TypeError("Regex must be string") return request.guided_regex, GuidedDecodingMode.REGEX - elif request.guided_choice: - if not isinstance(request.guided_choice, list): - raise TypeError("Choices must be a list") - # choice just uses regex choices = [ regex_escape(str(choice)) for choice in request.guided_choice ] choices_regex = "(" + "|".join(choices) + ")" return choices_regex, GuidedDecodingMode.CHOICE - + elif request.guided_grammar: + return request.guided_grammar, GuidedDecodingMode.GRAMMAR + elif (request.response_format is not None + and request.response_format.type == "json_object"): + return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR else: return None, None @lru_cache(maxsize=32) -def _get_cached_logits_processor(guide: str, tokenizer, +def _get_cached_logits_processor(guide: str, + tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode): if mode == GuidedDecodingMode.JSON: return JSONLogitsProcessor(guide, tokenizer) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: return RegexLogitsProcessor(guide, tokenizer) + 
elif mode == GuidedDecodingMode.GRAMMAR: + return CFGLogitsProcessor(guide, tokenizer) else: raise ValueError(f"Unknown guided decoding mode {mode}") diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_logits_processors.py index 76d41aa37dd7b..2cd1ae1571065 100644 --- a/vllm/model_executor/guided_logits_processors.py +++ b/vllm/model_executor/guided_logits_processors.py @@ -16,30 +16,60 @@ import json import math from collections import defaultdict -from typing import Union, DefaultDict, Dict, List, Optional +from typing import Union, DefaultDict, Dict, List, Optional, Callable import torch from pydantic import BaseModel -from outlines.fsm.fsm import RegexFSM +from transformers import PreTrainedTokenizerBase +from outlines.fsm.fsm import RegexFSM, CFGFSM from outlines.fsm.json_schema import build_regex_from_schema -class RegexLogitsProcessor: +class BaseLogitsProcessor: - def __init__(self, regex_string: str, tokenizer): - """Compile the FSM that drives the regex-structured generation. + def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase): + """Adapt vLLM's tokenizer to use to compile the FSM. - Parameters - ---------- - regex_string - A string that represents a regular expression - tokenizer - The model's tokenizer + The API of Outlines tokenizers is slightly different to that of + `transformers`. The decoder of outlines, returns a list whereas + the decode of vLLM returns an str. To sync the vLLM decoder with + outlines internal api, the decoder should be adapted. In addition + we need to handle the missing spaces to Llama's tokenizer to be + able to compile FSMs for this model. """ - tokenizer = self.adapt_tokenizer(tokenizer) - fsm = RegexFSM(regex_string, tokenizer) - self.fsm = fsm + if getattr(tokenizer, "_outlines_adapted", False): + return tokenizer + + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + def change_decoder( + decoder: Callable[[List[int]], str] + ) -> Callable[[List[int]], List[str]]: + """Sync vLLM's decoder with the outlines by returning list.""" + + def new_decoder(inp_tokens: List[int]) -> List[str]: + return [decoder(inp_tokens)] + + return new_decoder + + tokenizer.convert_token_to_string = convert_token_to_string + tokenizer.decode = change_decoder(tokenizer.decode) + setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 + + return tokenizer def init_state(self): """Initialize the FSM states.""" @@ -69,38 +99,30 @@ def __call__(self, input_ids: List[int], return scores - def adapt_tokenizer(self, tokenizer): - """Adapt vLLM's tokenizer to use to compile the FSM. - - The API of Outlines tokenizers is slightly different to that of - `transformers`. In addition we need to handle the missing spaces to - Llama's tokenizer to be able to compile FSMs for this model. 
- - """ - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) - - def convert_token_to_string(token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE - string = tokenizer.convert_tokens_to_string([token]) +class RegexLogitsProcessor(BaseLogitsProcessor): - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - - return string + def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): + """Compile the FSM that drives the regex-structured generation. - tokenizer.convert_token_to_string = convert_token_to_string + Parameters + ---------- + regex_string + A string that represents a regular expression + tokenizer + The model's tokenizer - return tokenizer + """ + tokenizer = self.adapt_tokenizer(tokenizer) + fsm = RegexFSM(regex_string, tokenizer) + self.fsm = fsm class JSONLogitsProcessor(RegexLogitsProcessor): def __init__(self, schema: Union[str, Dict, BaseModel], - tokenizer, + tokenizer: PreTrainedTokenizerBase, whitespace_pattern: Optional[str] = None): """Compile the FSM that drives the JSON-guided generation. @@ -130,3 +152,21 @@ def __init__(self, f"the JSON Schema specification") regex_string = build_regex_from_schema(schema_str, whitespace_pattern) super().__init__(regex_string, tokenizer) + + +class CFGLogitsProcessor(BaseLogitsProcessor): + + def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): + """Compile the FSM that drives the context free grammar generation. + + Parameters + ---------- + cfg + A string that represents a context-free grammar + tokenizer + The model's tokenizer + + """ + tokenizer = self.adapt_tokenizer(tokenizer) + fsm = CFGFSM(cfg, tokenizer) + self.fsm = fsm diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index f0a88ac8e27f8..ebba0ba0a261a 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -1,8 +1,10 @@ +from dataclasses import dataclass from typing import Optional import torch +@dataclass class InputMetadata: """Metadata for input sequences. Used in PagedAttention. @@ -15,40 +17,17 @@ class InputMetadata: kv_cache_dtype: Data type to store kv cache. """ - def __init__( - self, - is_prompt: bool, - slot_mapping: torch.Tensor, - prompt_lens: Optional[torch.Tensor], - max_seq_len: Optional[int], - start_loc: Optional[torch.Tensor], - max_context_len: Optional[int], - context_lens: Optional[torch.Tensor], - block_tables: Optional[torch.Tensor], - use_cuda_graph: bool, - kv_cache_dtype: str, - ) -> None: - self.is_prompt = is_prompt - self.prompt_lens = prompt_lens - self.max_seq_len = max_seq_len - self.start_loc = start_loc - self.max_context_len = max_context_len - self.slot_mapping = slot_mapping - self.context_lens = context_lens - self.block_tables = block_tables - self.use_cuda_graph = use_cuda_graph - self.kv_cache_dtype = kv_cache_dtype + is_prompt: bool + slot_mapping: torch.Tensor + prompt_lens: Optional[torch.Tensor] + max_seq_len: Optional[int] + start_loc: Optional[torch.Tensor] + max_context_len: Optional[int] + context_lens: Optional[torch.Tensor] + block_tables: Optional[torch.Tensor] + use_cuda_graph: bool + kv_cache_dtype: str - # Set during the execution of the first attention op. - # FIXME(woosuk): This is a hack. 
+ def __post_init__(self): + # will not appear in the __repr__ and __init__ self.attn_bias = None - - def __repr__(self) -> str: - return ("InputMetadata(" - f"is_prompt={self.is_prompt}, " - f"max_context_len={self.max_context_len}, " - f"slot_mapping={self.slot_mapping}, " - f"context_lens={self.context_lens}, " - f"block_tables={self.block_tables}, " - f"use_cuda_graph={self.use_cuda_graph}, " - f"kv_cache_dtype={self.kv_cache_dtype})") diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 0c4f20d9e3a58..48e44445a4a20 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -47,7 +47,7 @@ def __init__( self.perm_len = 1024 def __repr__(self) -> str: - return f"MarlinConfig(group_size={self.group_size}" + return f"MarlinConfig(group_size={self.group_size})" @classmethod def get_name(cls) -> str: diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 3e4f843e649b4..12e0feddcb7f1 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -299,7 +299,11 @@ def __init__( self.config = config self.linear_method = linear_method self.model = Qwen2Model(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + + if not config.tie_word_embeddings: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size) + self.sampler = Sampler(config.vocab_size) def forward( @@ -318,7 +322,11 @@ def sample( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, + if self.config.tie_word_embeddings: + lm_head_weight = self.model.embed_tokens.weight + else: + lm_head_weight = self.lm_head.weight + next_tokens = self.sampler(lm_head_weight, hidden_states, sampling_metadata) return next_tokens @@ -340,6 +348,8 @@ def load_weights(self, model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 521b6b8a383b0..6f00fd001d956 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -177,7 +177,7 @@ def broadcast_tensor_dict( for key, value in metadata_list: if isinstance(value, TensorMetadata): tensor = tensor_dict[key] - torch.distributed.broadcast(tensor, src=src) + torch.distributed.broadcast(tensor, src=src, group=group) else: recv_metadata_list = [None] torch.distributed.broadcast_object_list(recv_metadata_list, diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 2600ea2642da2..f7a1a19a89bcf 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -5,12 +5,48 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import make_async, LRUCache +from vllm.utils import make_async from vllm.transformers_utils.tokenizers import * logger = init_logger(__name__) +def get_cached_tokenizer( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Get 
tokenizer with cached properties. + + This will patch the tokenizer object in place. + + By default, transformers will recompute multiple tokenizer properties + each time they are called, leading to a significant slowdown. This + function caches these properties for faster access.""" + + tokenizer_all_special_ids = set(tokenizer.all_special_ids) + tokenizer_all_special_tokens_extended = ( + tokenizer.all_special_tokens_extended) + tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) + + class CachedTokenizer(tokenizer.__class__): + + @property + def all_special_ids(self): + return tokenizer_all_special_ids + + @property + def all_special_tokens(self): + return tokenizer_all_special_tokens + + @property + def all_special_tokens_extended(self): + return tokenizer_all_special_tokens_extended + + CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" + + tokenizer.__class__ = CachedTokenizer + return tokenizer + + def get_tokenizer( tokenizer_name: str, *args, @@ -64,7 +100,7 @@ def get_tokenizer( logger.warning( "Using a slow tokenizer. This might cause a significant " "slowdown. Consider using a fast tokenizer instead.") - return tokenizer + return get_cached_tokenizer(tokenizer) def get_lora_tokenizer(lora_request: LoRARequest, *args, @@ -88,65 +124,6 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args, get_lora_tokenizer_async = make_async(get_lora_tokenizer) -class TokenizerGroup: - """A group of tokenizers that can be used for LoRA adapters.""" - - def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int], **tokenizer_config): - self.tokenizer_id = tokenizer_id - self.tokenizer_config = tokenizer_config - self.enable_lora = enable_lora - self.max_input_length = max_input_length - self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) - if enable_lora: - self.lora_tokenizers = LRUCache(capacity=max_num_seqs) - else: - self.lora_tokenizers = None - - def encode(self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: - tokenizer = self.get_lora_tokenizer(lora_request) - return tokenizer.encode(prompt) - - async def encode_async( - self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: - tokenizer = await self.get_lora_tokenizer_async(lora_request) - return tokenizer.encode(prompt) - - def get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest] = None - ) -> "PreTrainedTokenizer": - if not lora_request or not self.enable_lora: - return self.tokenizer - if lora_request.lora_int_id not in self.lora_tokenizers: - tokenizer = (get_lora_tokenizer( - lora_request, **self.tokenizer_config) or self.tokenizer) - self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) - return tokenizer - else: - return self.lora_tokenizers.get(lora_request.lora_int_id) - - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest] = None - ) -> "PreTrainedTokenizer": - if not lora_request or not self.enable_lora: - return self.tokenizer - if lora_request.lora_int_id not in self.lora_tokenizers: - tokenizer = (await get_lora_tokenizer_async( - lora_request, **self.tokenizer_config) or self.tokenizer) - self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) - return tokenizer - else: - return self.lora_tokenizers.get(lora_request.lora_int_id) - - def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, 
PreTrainedTokenizerFast], output_tokens: List[str], diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py new file mode 100644 index 0000000000000..adc8d9b90ddb6 --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -0,0 +1,32 @@ +from typing import Optional +from vllm.config import TokenizerPoolConfig +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) +from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( + TokenizerGroup) +from vllm.engine.ray_utils import ray + +if ray: + from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( + RayTokenizerGroupPool) +else: + RayTokenizerGroupPool = None + + +def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], + **init_kwargs) -> BaseTokenizerGroup: + if tokenizer_pool_config is None: + return TokenizerGroup(**init_kwargs) + if tokenizer_pool_config.pool_type == "ray": + if RayTokenizerGroupPool is None: + raise ImportError( + "RayTokenizerGroupPool is not available. Please install " + "the ray package to use the Ray tokenizer group pool.") + return RayTokenizerGroupPool.from_config(tokenizer_pool_config, + **init_kwargs) + else: + raise ValueError( + f"Unknown pool type: {tokenizer_pool_config.pool_type}") + + +__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"] diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py new file mode 100644 index 0000000000000..99518a606fabe --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from transformers import PreTrainedTokenizer + +from vllm.lora.request import LoRARequest + + +class BaseTokenizerGroup(ABC): + """A group of tokenizers that can be used for LoRA adapters.""" + + @abstractmethod + def ping(self) -> bool: + """Check if the tokenizer group is alive.""" + pass + + @abstractmethod + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + pass + + @abstractmethod + def encode(self, prompt: str, request_id: Optional[str], + lora_request: Optional[LoRARequest]) -> List[int]: + """Encode a prompt using the tokenizer group.""" + pass + + @abstractmethod + async def encode_async(self, prompt: str, request_id: Optional[str], + lora_request: Optional[LoRARequest]) -> List[int]: + """Encode a prompt using the tokenizer group.""" + pass + + @abstractmethod + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + """Get a tokenizer for a LoRA request.""" + pass + + @abstractmethod + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + """Get a tokenizer for a LoRA request.""" + pass diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py new file mode 100644 index 0000000000000..e048ec05bece7 --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -0,0 +1,166 @@ +import asyncio +import os +from typing import List, Optional + +from transformers import PreTrainedTokenizer + +from vllm.config import TokenizerPoolConfig +from vllm.lora.request import LoRARequest +from vllm.engine.ray_utils import ray +from 
vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) +from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( + TokenizerGroup) +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + + +class RayTokenizerGroupPool(BaseTokenizerGroup): + """A Ray-based pool of TokenizerGroups for async tokenization.""" + + # Class to use for workers making up the pool. + _worker_cls = TokenizerGroup + + @classmethod + def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig, + **init_kwargs) -> "RayTokenizerGroupPool": + ray_actor_options = (tokenizer_pool_config.extra_config or { + "num_cpus": 0 + }) + ray_actor_options.setdefault( + "scheduling_strategy", + NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), soft=True)) + + # Carry over the env vars to the actors. + # This is necessary for API keys and such. + ray_actor_options.setdefault("runtime_env", {}) + _carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"]) + + init_kwargs["num_actors"] = tokenizer_pool_config.pool_size + init_kwargs["ray_actor_options"] = ray_actor_options + + return cls(**init_kwargs) + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], num_actors: int, + ray_actor_options: dict, **tokenizer_config): + # Store a local copy of the TokenizerGroup for quick access + # to underlying HF tokenizers. + self._local_tokenizer_group = self._worker_cls( + tokenizer_id=tokenizer_id, + enable_lora=enable_lora, + max_num_seqs=max_num_seqs, + max_input_length=max_input_length, + ) + + ray_tokenizer_group_cls = ray.remote( + self._worker_cls).options(**ray_actor_options) + self.tokenizer_actors = [ + ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora, + max_num_seqs, max_input_length, + **tokenizer_config) + for _ in range(num_actors) + ] + self._idle_actors: Optional[asyncio.Queue] = None + + @property + def pool_size(self) -> int: + return len(self.tokenizer_actors) + + def ping(self): + return ray.get( + [actor.ping.remote() for actor in self.tokenizer_actors]) + + def _ensure_queue_initialized(self): + if self._idle_actors is None: + self._idle_actors = asyncio.Queue() + for actor in self.tokenizer_actors: + self._idle_actors.put_nowait(actor) + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + """Encode a prompt using the tokenizer group. + + We pick an idle actor and use it to encode the prompt. + The actor is then put back in the queue for future use. + This is blocking. + """ + self._ensure_queue_initialized() + + if self._idle_actors.empty(): + raise RuntimeError("No idle actors available.") + actor = self._idle_actors.get_nowait() + try: + ret = ray.get( + actor.encode.remote(request_id=request_id, + prompt=prompt, + lora_request=lora_request)) + finally: + # Put the actor back in the queue. + # This is done in a finally block to ensure that the actor is + # always put back in the queue, even if an exception/cancellation + # is raised. + self._idle_actors.put_nowait(actor) + return ret + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + """Encode a prompt using the tokenizer group. + + We pick an idle actor and use it to encode the prompt. + If there are no idle actors, we wait until one becomes + available. + The actor is then put back in the queue for future use. 
+ This is non-blocking. + """ + self._ensure_queue_initialized() + + actor = await self._idle_actors.get() + try: + ret = await actor.encode.remote(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + finally: + # Put the actor back in the queue. + # This is done in a finally block to ensure that the actor is + # always put back in the queue, even if an exception/cancellation + # is raised. + self._idle_actors.put_nowait(actor) + return ret + + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + return self._local_tokenizer_group.get_max_input_len(lora_request) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": + return self._local_tokenizer_group.get_lora_tokenizer(lora_request) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": + return await self._local_tokenizer_group.get_lora_tokenizer_async( + lora_request) + + +def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None: + """Copy over all current process environment variables to the runtime_env. + + The variables in runtime_env will take precedence over the current process + environment variables. + + runtime_env will be modified in place.""" + env_vars = os.environ.copy() + runtime_env.setdefault("env_vars", {}) + env_vars.update(runtime_env["env_vars"]) + runtime_env["env_vars"] = env_vars diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py new file mode 100644 index 0000000000000..3af1334cb5ede --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -0,0 +1,80 @@ +from typing import List, Optional + +from transformers import PreTrainedTokenizer + +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import (get_lora_tokenizer, + get_lora_tokenizer_async) +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) +from vllm.utils import LRUCache +from vllm.transformers_utils.tokenizer import get_tokenizer + + +class TokenizerGroup(BaseTokenizerGroup): + """A group of tokenizers that can be used for LoRA adapters.""" + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], **tokenizer_config): + self.tokenizer_id = tokenizer_id + self.tokenizer_config = tokenizer_config + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + if enable_lora: + self.lora_tokenizers = LRUCache(capacity=max_num_seqs) + else: + self.lora_tokenizers = None + + def ping(self) -> bool: + """Check if the tokenizer group is alive.""" + return True + + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + return self.max_input_length + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = self.get_lora_tokenizer(lora_request) + return tokenizer.encode(prompt) + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = await self.get_lora_tokenizer_async(lora_request) + return 
tokenizer.encode(prompt) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (get_lora_tokenizer( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (await get_lora_tokenizer_async( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) diff --git a/vllm/utils.py b/vllm/utils.py index fe6fd27962cd3..d4a8c962c3bfc 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -21,6 +21,7 @@ from typing import Any, Hashable, Optional from vllm.logger import init_logger +import warnings T = TypeVar("T") logger = init_logger(__name__) @@ -172,16 +173,35 @@ def _async_wrapper(*args, **kwargs) -> asyncio.Future: def get_ip() -> str: + host_ip = os.environ.get("HOST_IP") + if host_ip: + return host_ip + + # IP is not set, try to get it from the network interface + # try ipv4 s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable return s.getsockname()[0] - except OSError: - # try ipv6 + except Exception: + pass + + # try ipv6 + try: s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) - s.connect(("dns.google", 80)) + # Google's public DNS server, see + # https://developers.google.com/speed/public-dns/docs/using#addresses + s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable return s.getsockname()[0] + except Exception: + pass + + warnings.warn( + "Failed to get the IP address, using 0.0.0.0 by default." 
+ "The value can be set by the environment variable HOST_IP.", + stacklevel=2) + return "0.0.0.0" def get_distributed_init_method(ip: str, port: int) -> str: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7eac576e3f0fe..1ef783da6d08e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,4 +1,5 @@ import contextlib +import dataclasses import time from typing import Dict, List, Optional, Tuple, Set, Union @@ -521,45 +522,27 @@ def prepare_input_tensors( metadata_dict = { "input_tokens": input_tokens, "input_positions": input_positions, - "is_prompt": input_metadata.is_prompt, - "slot_mapping": input_metadata.slot_mapping, - "prompt_lens": input_metadata.prompt_lens, - "max_seq_len": input_metadata.max_seq_len, - "start_loc": input_metadata.start_loc, - "max_context_len": input_metadata.max_context_len, - "context_lens": input_metadata.context_lens, - "block_tables": input_metadata.block_tables, - "use_cuda_graph": input_metadata.use_cuda_graph, - "kv_cache_dtype": input_metadata.kv_cache_dtype, "selected_token_indices": sampling_metadata.selected_token_indices, "lora_requests": lora_requests, "lora_mapping": lora_mapping, } + metadata_dict.update(dataclasses.asdict(input_metadata)) broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) - input_tokens = metadata_dict["input_tokens"] - input_positions = metadata_dict["input_positions"] - lora_mapping = metadata_dict["lora_mapping"] - lora_requests = metadata_dict["lora_requests"] - input_metadata = InputMetadata( - is_prompt=metadata_dict["is_prompt"], - slot_mapping=metadata_dict["slot_mapping"], - prompt_lens=metadata_dict["prompt_lens"], - max_seq_len=metadata_dict["max_seq_len"], - start_loc=metadata_dict["start_loc"], - max_context_len=metadata_dict["max_context_len"], - context_lens=metadata_dict["context_lens"], - block_tables=metadata_dict["block_tables"], - use_cuda_graph=metadata_dict["use_cuda_graph"], - kv_cache_dtype=metadata_dict["kv_cache_dtype"], - ) + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + selected_token_indices = metadata_dict.pop( + "selected_token_indices") + lora_mapping = metadata_dict.pop("lora_mapping") + lora_requests = metadata_dict.pop("lora_requests") + input_metadata = InputMetadata(**metadata_dict) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, prompt_lens=None, - selected_token_indices=metadata_dict["selected_token_indices"], + selected_token_indices=selected_token_indices, categorized_sample_indices=None, generators=None, perform_sampling=False,