diff --git a/README.md b/README.md index c6e6a3c7379db..8ea4d029dc64f 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Easy, fast, and cheap LLM serving for everyone - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai). --- - +## About vLLM is a fast and easy-to-use library for LLM inference and serving. vLLM is fast with: @@ -54,6 +54,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) +- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.) - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) - GPT-2 (`gpt2`, `gpt2-xl`, etc.) - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index e1dc711778ffb..486ebe1d464c8 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -43,8 +43,8 @@ __global__ void rotary_embedding_kernel( scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const int rot_dim, - const int query_stride, - const int key_stride, + const int64_t query_stride, + const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size) { @@ -60,7 +60,7 @@ __global__ void rotary_embedding_kernel( const int nq = num_heads * embed_dim; for (int i = threadIdx.x; i < nq; i += blockDim.x) { const int head_idx = i / embed_dim; - const int token_head = token_idx * query_stride + head_idx * head_size; + const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int rot_offset = i % embed_dim; apply_rotary_embedding(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); @@ -69,7 +69,7 @@ __global__ void rotary_embedding_kernel( const int nk = num_kv_heads * embed_dim; for (int i = threadIdx.x; i < nk; i += blockDim.x) { const int head_idx = i / embed_dim; - const int token_head = token_idx * key_stride + head_idx * head_size; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int rot_offset = i % embed_dim; apply_rotary_embedding(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); @@ -89,8 +89,8 @@ void rotary_embedding( int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; int num_kv_heads = key.size(-1) / head_size; - int query_stride = query.stride(-2); - int key_stride = key.stride(-2); + int64_t query_stride = query.stride(-2); + int64_t key_stride = key.stride(-2); dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 6d070c658f153..eb0d75f1293c4 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -28,6 +28,7 @@ namespace gptq { #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) #if defined(USE_ROCM) +#include __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, hipblasOperation_t transA, hipblasOperation_t transB, @@ -520,12 +521,21 @@ __global__ void gemm_half_q_half_alt_kernel( zeros_tmp[tmp_k] = zero; } for (int 
m = 0; m < b_end; m++) { +#ifndef USE_ROCM res2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); +#endif res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); +#ifndef USE_ROCM res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); +#else + res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); +#endif } i += width; k += 4; diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 6fb072a0c3c9f..181c970e0b2a7 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -116,6 +116,7 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from - `ROCm `_ - `Pytorch `_ +- `hipBLAS `_ 1. Install `flash attention for ROCm `_ diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index e7a2d0a6f0d03..911c3d8f9a4ab 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -42,6 +42,10 @@ You can install vLLM using pip: $ pip uninstall torch -y $ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118 + $ # Re-install xFormers with CUDA 11.8. + $ pip uninstall xformers -y + $ pip install --upgrade xformers --index-url https://download.pytorch.org/whl/cu118 + .. _build_from_source: diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index a70c22e9af11a..d89b795149501 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -89,9 +89,11 @@ Below, you can find an explanation of every engine argument for vLLM: CPU swap space size (GiB) per GPU. -.. option:: --gpu-memory-utilization +.. option:: --gpu-memory-utilization - The percentage of GPU memory to be used for the model executor. + The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. + For example, a value of 0.5 would imply 50% GPU memory utilization. + If unspecified, will use the default value of 0.9. .. option:: --max-num-batched-tokens diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c95b158e871fe..361ad5f5a22bd 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -23,6 +23,9 @@ Alongside each architecture, we include some popular models that use it. * - :code:`ChatGLMModel` - ChatGLM - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. + * - :code:`DeciLMForCausalLM` + - DeciLM + - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. * - :code:`BloomForCausalLM` - BLOOM, BLOOMZ, BLOOMChat - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. @@ -90,7 +93,7 @@ Alternatively, you can raise an issue on our `GitHub `_ instead of HuggingFace Hub, set an environment variable: + To use models from `ModelScope `_ instead of HuggingFace Hub, set an environment variable: .. 
code-block:: shell diff --git a/docs/source/serving/serving_with_langchain.rst b/docs/source/serving/serving_with_langchain.rst index 8ae75d7a80d24..2e1ce688290ad 100644 --- a/docs/source/serving/serving_with_langchain.rst +++ b/docs/source/serving/serving_with_langchain.rst @@ -28,4 +28,4 @@ To run inference on a single or multiple GPUs, use ``VLLM`` class from ``langcha print(llm("What is the capital of France ?")) -Please refer to this `Tutorial `_ for more details. \ No newline at end of file +Please refer to this `Tutorial `_ for more details. \ No newline at end of file diff --git a/setup.py b/setup.py index 45a18776798fb..811d494e7a01f 100644 --- a/setup.py +++ b/setup.py @@ -219,13 +219,13 @@ def get_torch_arch_list() -> Set[str]: "csrc/activation_kernels.cu", "csrc/layernorm_kernels.cu", "csrc/quantization/squeezellm/quant_cuda_kernel.cu", + "csrc/quantization/gptq/q_gemm.cu", "csrc/cuda_utils_kernels.cu", "csrc/pybind.cpp", ] if _is_cuda(): vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") - vllm_extension_sources.append("csrc/quantization/gptq/q_gemm.cu") vllm_extension = CUDAExtension( name="vllm._C", diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index d90ba37b27bb9..f1891f78491c2 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -44,13 +44,14 @@ def test_api_server(api_server): """ with Pool(32) as pool: # Wait until the server is ready - prompts = ["Hello world"] * 1 + prompts = ["warm up"] * 1 result = None while not result: try: - for _ in pool.map(_query_server, prompts): + for r in pool.map(_query_server, prompts): + result = r break - except Exception: + except requests.exceptions.ConnectionError: time.sleep(1) # Actual tests start here @@ -63,13 +64,14 @@ def test_api_server(api_server): assert num_aborted_requests == 0 # Try with 100 prompts - prompts = ["Hello world"] * 100 + prompts = ["test prompt"] * 100 for result in pool.map(_query_server, prompts): assert result # Cancel requests + prompts = ["canceled requests"] * 100 pool.map_async(_query_server, prompts) - time.sleep(0.01) + time.sleep(0.001) pool.terminate() pool.join() @@ -81,6 +83,6 @@ def test_api_server(api_server): # check that server still runs after cancellations with Pool(32) as pool: # Try with 100 prompts - prompts = ["Hello world"] * 100 + prompts = ["test prompt after canceled"] * 100 for result in pool.map(_query_server, prompts): assert result diff --git a/tests/conftest.py b/tests/conftest.py index 16c04e01d703c..8d6afdbd00358 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,8 +8,9 @@ from vllm import LLM, SamplingParams from vllm.transformers_utils.tokenizer import get_tokenizer -_TEST_PROMPTS = ["prompts/example.txt"] -_LONG_PROMPTS = ["prompts/summary.txt"] +_TEST_DIR = os.path.dirname(__file__) +_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] +_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] def _read_prompts(filename: str) -> str: @@ -24,7 +25,7 @@ def _read_prompts(filename: str) -> str: def example_prompts() -> List[str]: prompts = [] for filename in _TEST_PROMPTS: - prompts += _read_prompts(os.path.join("tests", filename)) + prompts += _read_prompts(filename) return prompts @@ -32,7 +33,7 @@ def example_prompts() -> List[str]: def example_long_prompts() -> List[str]: prompts = [] for filename in _LONG_PROMPTS: - prompts += _read_prompts(os.path.join("tests", filename)) + prompts += _read_prompts(filename) return prompts 
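Editor's note on the test changes above: tests/async_engine/test_api_server.py now treats only requests.exceptions.ConnectionError as "server not ready yet" while warming up, and tests/conftest.py resolves prompt files relative to the test directory so the suite no longer depends on the working directory. A minimal sketch of that readiness-polling pattern follows; the URL, endpoint, and timeout are illustrative assumptions, not values taken from the patch.

import time

import requests


def wait_until_ready(url: str = "http://localhost:8000/health",
                     timeout_s: float = 60.0) -> None:
    # Poll until the server accepts connections; only a refused connection
    # counts as "not ready", mirroring the narrowed except clause in the test.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            requests.get(url, timeout=1)
            return  # any HTTP response means the socket is up
        except requests.exceptions.ConnectionError:
            time.sleep(1)
    raise TimeoutError(f"server at {url} did not come up within {timeout_s}s")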
diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 733c7395811ef..b9895b3e71794 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -8,7 +8,7 @@ import torch from vllm.config import ParallelConfig -from vllm.engine.ray_utils import get_open_port +from vllm.utils import get_open_port from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather, diff --git a/tests/models/test_models.py b/tests/models/test_models.py index e65c424c601a2..518eae201ed32 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -8,6 +8,7 @@ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", "mistralai/Mistral-7B-v0.1", + "Deci/DeciLM-7b", "tiiuae/falcon-7b", "gpt2", "bigcode/tiny_starcoder_py", diff --git a/vllm/__init__.py b/vllm/__init__.py index 6ad0d3b6d7e99..72464acd616a2 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -9,7 +9,7 @@ from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams -__version__ = "0.2.5" +__version__ = "0.2.6" __all__ = [ "LLM", diff --git a/vllm/config.py b/vllm/config.py index a2b2050240f57..ff9a1308a5c88 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -112,24 +112,20 @@ def _verify_load_format(self) -> None: supported_load_format = [ "auto", "pt", "safetensors", "npcache", "dummy" ] - rocm_not_supported_load_format = ["safetensors"] + rocm_not_supported_load_format = [] if load_format not in supported_load_format: raise ValueError( f"Unknown load format: {self.load_format}. Must be one of " "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") - if is_hip(): - if load_format in ["safetensors"]: - rocm_supported_load_format = [ - f for f in supported_load_format - if (f not in rocm_not_supported_load_format) - ] - raise ValueError( - f"load format \'{load_format}\' is not supported in ROCm. " - f"Supported load format are " - f"{rocm_supported_load_format}") - # Force ROCm to load from pt weights if nothing specific is set - if load_format == "auto": - load_format = "pt" + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in supported_load_format + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format \'{load_format}\' is not supported in ROCm. " + f"Supported load format are " + f"{rocm_supported_load_format}") # TODO: Remove this check once HF updates the pt weights of Mixtral. architectures = getattr(self.hf_config, "architectures", []) @@ -149,7 +145,7 @@ def _verify_tokenizer_mode(self) -> None: def _verify_quantization(self) -> None: supported_quantization = ["awq", "gptq", "squeezellm"] - rocm_not_supported_quantization = ["awq", "gptq"] + rocm_not_supported_quantization = ["awq"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -185,10 +181,11 @@ def _verify_cuda_graph(self) -> None: self.max_context_len_to_capture = self.max_model_len self.max_context_len_to_capture = min(self.max_context_len_to_capture, self.max_model_len) - if self.quantization == "gptq" and not self.enforce_eager: + if (self.quantization in ["gptq", "squeezellm"] + and not self.enforce_eager): # Related issue: https://github.com/vllm-project/vllm/issues/2147 - logger.warning("GPTQ does not support CUDA graph yet. Disabling " - "CUDA graph.") + logger.warning(f"{self.quantization} does not support CUDA graph " + "yet. 
Disabling CUDA graph.") self.enforce_eager = True def verify_with_parallel_config( diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 8b26319b88cd3..3bde005997bde 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -103,7 +103,7 @@ def __init__( def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - seq = seq_group.get_seqs()[0] + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = len(seq.logical_token_blocks) if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, @@ -122,7 +122,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: def allocate(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # prompt. - seq = seq_group.get_seqs()[0] + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] # Allocate new physical token blocks that will store the prompt tokens. block_table: BlockTable = [] @@ -137,7 +137,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: block_table.append(block) # Assign the block table for each sequence. - for seq in seq_group.get_seqs(): + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() def can_append_slot(self, seq_group: SequenceGroup) -> bool: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ca28bbdc2fb95..398585a88fb52 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -139,15 +139,17 @@ def _schedule(self) -> SchedulerOutputs: while self.waiting: seq_group = self.waiting[0] - assert seq_group.num_seqs() == 1, ( + waiting_seqs = seq_group.get_seqs( + status=SequenceStatus.WAITING) + assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") - num_prompt_tokens = seq_group.get_seqs()[0].get_len() + num_prompt_tokens = waiting_seqs[0].get_len() if num_prompt_tokens > self.prompt_limit: logger.warning( f"Input prompt ({num_prompt_tokens} tokens) is too long" f" and exceeds limit of {self.prompt_limit}") - for seq in seq_group.get_seqs(): + for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) self.waiting.pop(0) @@ -161,7 +163,7 @@ def _schedule(self) -> SchedulerOutputs: logger.warning( f"Input prompt ({num_prompt_tokens} tokens) is too long" f" and exceeds the capacity of block_manager") - for seq in seq_group.get_seqs(): + for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) self.waiting.pop(0) @@ -317,7 +319,7 @@ def free_finished_seq_groups(self) -> None: def _allocate(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) - for seq in seq_group.get_seqs(): + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING def _append_slot( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7a571ceefbc85..7e58069e2c22d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -156,11 +156,13 @@ def add_cli_args( type=int, default=EngineArgs.swap_space, help='CPU swap space size (GiB) per GPU') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=EngineArgs.gpu_memory_utilization, - help='the percentage of GPU memory to be used for' - 'the model executor') + 
parser.add_argument( + '--gpu-memory-utilization', + type=float, + default=EngineArgs.gpu_memory_utilization, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' + 'If unspecified, will use the default value of 0.9.') parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, diff --git a/vllm/engine/async_llava_engine.py b/vllm/engine/async_llava_engine.py index a2b7c1d5b3e31..4c75be541d203 100644 --- a/vllm/engine/async_llava_engine.py +++ b/vllm/engine/async_llava_engine.py @@ -28,21 +28,19 @@ async def step_async(self) -> List[RequestOutput]: future when we merge the execute_llava_model function to the execute_model. """ - seq_group_metadata_list, scheduler_outputs, ignored = self._schedule() - if scheduler_outputs.is_empty(): - return ignored + seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() # Execute the model. - output = await self._run_workers_async( + output = (await self._run_workers_async( "execute_model", seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_copy=scheduler_outputs.blocks_to_copy, runner_method="execute_llava_model", - ) + )) if not scheduler_outputs.is_empty() else [] - return self._process_model_outputs(output, scheduler_outputs) + ignored + return self._process_model_outputs(output, scheduler_outputs) class AsyncLLaVAEngine(AsyncLLMEngine): diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index d854a20b8b95a..611da51f61931 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -183,20 +183,18 @@ async def step_async(self) -> List[RequestOutput]: and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. """ - seq_group_metadata_list, scheduler_outputs, ignored = self._schedule() - if scheduler_outputs.is_empty(): - return ignored + seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() # Execute the model. - output = await self._run_workers_async( + output = (await self._run_workers_async( "execute_model", seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_copy=scheduler_outputs.blocks_to_copy, - ) + )) if not scheduler_outputs.is_empty() else [] - return self._process_model_outputs(output, scheduler_outputs) + ignored + return self._process_model_outputs(output, scheduler_outputs) async def _run_workers_async( self, diff --git a/vllm/engine/llava_engine.py b/vllm/engine/llava_engine.py index d201db4cd7b68..6e99979be9c8f 100644 --- a/vllm/engine/llava_engine.py +++ b/vllm/engine/llava_engine.py @@ -103,9 +103,7 @@ def step(self) -> List[RequestOutput]: and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. """ - seq_group_metadata_list, scheduler_outputs, ignored = self._schedule() - if scheduler_outputs.is_empty(): - return ignored + seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() # Execute the model. 
output = self._run_workers( @@ -115,6 +113,6 @@ def step(self) -> List[RequestOutput]: blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_copy=scheduler_outputs.blocks_to_copy, runner_method="execute_llava_model", - ) + ) if not scheduler_outputs.is_empty() else [] return self._process_model_outputs(output, scheduler_outputs) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d91ab1430735c..43bf9747ee184 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,4 +1,5 @@ import copy +import os import time from functools import partial from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union @@ -13,8 +14,7 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, - SequenceGroupMetadata, SequenceGroupOutput, - SequenceOutput, SequenceStatus) + SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, get_tokenizer) from vllm.utils import Counter @@ -105,6 +105,10 @@ def __init__( # Create the parallel GPU workers. if self.parallel_config.worker_use_ray: + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" self._init_workers_ray(placement_group) else: self._init_workers(distributed_init_method) @@ -227,6 +231,14 @@ def _init_cache(self) -> None: raise ValueError("No available memory for the cache blocks. " "Try increasing `gpu_memory_utilization` when " "initializing the engine.") + max_seq_len = self.cache_config.block_size * num_gpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks @@ -315,16 +327,6 @@ def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" return self.scheduler.has_unfinished_seqs() - def _schedule( - self - ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, - List[RequestOutput]]: - seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - return seq_group_metadata_list, scheduler_outputs, [ - RequestOutput.from_seq_group(seq_group) - for seq_group in scheduler_outputs.ignored_seq_groups - ] - def _check_beam_search_early_stopping( self, early_stopping: Union[bool, str], @@ -573,9 +575,7 @@ def step(self) -> List[RequestOutput]: and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. """ - seq_group_metadata_list, scheduler_outputs, ignored = self._schedule() - if scheduler_outputs.is_empty(): - return ignored + seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() # Execute the model. 
output = self._run_workers( @@ -584,7 +584,7 @@ def step(self) -> List[RequestOutput]: blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_copy=scheduler_outputs.blocks_to_copy, - ) + ) if not scheduler_outputs.is_empty() else [] return self._process_model_outputs(output, scheduler_outputs) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index fb29837da8cf0..6910b3265dfd2 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -73,6 +73,8 @@ async def stream_results() -> AsyncGenerator[bytes, None]: parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() @@ -83,4 +85,6 @@ async def stream_results() -> AsyncGenerator[bytes, None]: host=args.host, port=args.port, log_level="debug", - timeout_keep_alive=TIMEOUT_KEEP_ALIVE) + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0f131ce6f4dc0..be5f4190e633f 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -80,6 +80,14 @@ def parse_args(): default="assistant", help="The role name to return if " "`request.add_generation_prompt=true`.") + parser.add_argument("--ssl-keyfile", + type=str, + default=None, + help="The file path to the SSL key file") + parser.add_argument("--ssl-certfile", + type=str, + default=None, + help="The file path to the SSL cert file") parser = AsyncEngineArgs.add_cli_args(parser) return parser.parse_args() @@ -744,4 +752,6 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: host=args.host, port=args.port, log_level="info", - timeout_keep_alive=TIMEOUT_KEEP_ALIVE) + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 13da9aa38af03..f9d95fa7548fd 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -6,13 +6,11 @@ from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_gather) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) -_SAMPLING_EPS = 1e-5 - class Sampler(nn.Module): """Samples the next tokens from the model's outputs. @@ -47,40 +45,34 @@ def forward( logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size) + _, vocab_size = logits.shape + # Apply logits processors (if any). logits = _apply_logits_processors(logits, sampling_metadata) + + # Prepare sampling tensors with pinned memory to avoid blocking. + (sampling_tensors, do_penalties, do_top_p_top_k, + do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype) + # Apply presence and frequency penalties. 
- presence_penalties, frequency_penalties, repetition_penalties = ( - _get_penalties(sampling_metadata)) - assert len(presence_penalties) == logits.shape[0] - assert len(frequency_penalties) == logits.shape[0] - assert len(repetition_penalties) == logits.shape[0] - logits = _apply_penalties(logits, sampling_metadata, - presence_penalties, frequency_penalties, - repetition_penalties) + if do_penalties: + logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) # Apply temperature scaling. - temperatures = _get_temperatures(sampling_metadata) - assert len(temperatures) == logits.shape[0] - if any(t != 1.0 for t in temperatures): - t = torch.tensor(temperatures, - dtype=logits.dtype, - device=logits.device) - # Use in-place division to avoid creating a new tensor. - logits.div_(t.unsqueeze(dim=1)) - - # Apply top-p and top-k truncation. - top_ps, top_ks, min_ps = _get_top_p_top_k_min_p( - sampling_metadata, self.vocab_size) - assert len(top_ps) == len(top_ks) == logits.shape[0] - do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps) - do_top_k = any(k != self.vocab_size for k in top_ks) - if do_top_p or do_top_k: - logits = _apply_top_p_top_k(logits, top_ps, top_ks) - - do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps) + # Use in-place division to avoid creating a new tensor. + logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) + + if do_top_p_top_k: + logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps, + sampling_tensors.top_ks) + if do_min_p: - logits = _apply_min_p(logits, min_ps) + logits = _apply_min_p(logits, sampling_tensors.min_ps) # We use float32 for probabilities and log probabilities. # Compute the probabilities. @@ -120,32 +112,6 @@ def _prune_hidden_states( sampling_metadata.selected_token_indices) -def _get_penalties( - sampling_metadata: SamplingMetadata -) -> Tuple[List[float], List[float], List[float]]: - # Collect the presence and frequency penalties. - presence_penalties: List[float] = [] - frequency_penalties: List[float] = [] - repetition_penalties: List[float] = [] - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group - p = sampling_params.presence_penalty - f = sampling_params.frequency_penalty - r = sampling_params.repetition_penalty - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - # NOTE: We do not apply presence and frequency penalties for the - # prompt token positions where we don't sample new tokens. 
- prompt_len = sampling_metadata.prompt_lens[i] - presence_penalties += [0] * (prompt_len - 1) - frequency_penalties += [0] * (prompt_len - 1) - repetition_penalties += [1] * (prompt_len - 1) - presence_penalties += [p] * len(seq_ids) - frequency_penalties += [f] * len(seq_ids) - repetition_penalties += [r] * len(seq_ids) - return presence_penalties, frequency_penalties, repetition_penalties - - def _get_prompt_and_output_tokens( sampling_metadata: SamplingMetadata, ) -> Tuple[List[List[int]], List[List[int]]]: @@ -168,25 +134,16 @@ def _get_prompt_and_output_tokens( def _get_bin_counts_and_mask( - logits: torch.Tensor, - tokens: List[List[int]], + tokens: torch.Tensor, vocab_size: int, num_seqs: int, ) -> Tuple[torch.Tensor, torch.Tensor]: - max_len = max(len(tokens) for tokens in tokens) - padded_tokens = [ - tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens - ] - tokens_tensor = torch.tensor(padded_tokens, - dtype=torch.long, - device=logits.device) - # Compute the bin counts for the tokens. # vocab_size + 1 for padding. bin_counts = torch.zeros((num_seqs, vocab_size + 1), dtype=torch.long, - device=logits.device) - bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor)) + device=tokens.device) + bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) bin_counts = bin_counts[:, :vocab_size] mask = bin_counts > 0 @@ -217,45 +174,16 @@ def _apply_logits_processors( return logits -def _apply_penalties( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - presence_penalties: List[float], - frequency_penalties: List[float], - repetition_penalties: List[float], -) -> torch.Tensor: +def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: num_seqs, vocab_size = logits.shape - for i in range(num_seqs): - p = presence_penalties[i] - f = frequency_penalties[i] - r = repetition_penalties[i] - if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs( - r - 1.0) < _SAMPLING_EPS: - continue - break - else: - # Return early if all sequences have zero penalties. - return logits - - prompt_tokens, output_tokens = ( - _get_prompt_and_output_tokens(sampling_metadata)) - assert len(prompt_tokens) == logits.shape[0] - assert len(output_tokens) == logits.shape[0] - - prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask( - logits, prompt_tokens, vocab_size, num_seqs) + _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, + num_seqs) output_bin_counts, output_mask = _get_bin_counts_and_mask( - logits, output_tokens, vocab_size, num_seqs) - - repetition_penalties = torch.tensor(repetition_penalties, - dtype=logits.dtype, - device=logits.device) - frequency_penalties = torch.tensor(frequency_penalties, - dtype=logits.dtype, - device=logits.device) - presence_penalties = torch.tensor(presence_penalties, - dtype=logits.dtype, - device=logits.device) + output_tokens_tensor, vocab_size, num_seqs) repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) repetition_penalties[~(prompt_mask | output_mask)] = 1.0 @@ -264,109 +192,65 @@ def _apply_penalties( # We follow the definition in OpenAI API. 
# Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts - logits -= presence_penalties.unsqueeze(dim=1) * output_mask + logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze_(dim=1) * output_mask return logits -def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]: - # Collect the temperatures for the logits. - temperatures: List[float] = [] - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group - temperature = sampling_params.temperature - if temperature < _SAMPLING_EPS: - # NOTE: Zero temperature means deterministic sampling - # (i.e., greedy sampling or beam search). - # Set the temperature to 1 to avoid division by zero. - temperature = 1.0 - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - prompt_len = sampling_metadata.prompt_lens[i] - temperatures += [temperature] * (prompt_len - 1) - temperatures += [temperature] * len(seq_ids) - return temperatures - - -def _get_top_p_top_k_min_p( - sampling_metadata: SamplingMetadata, - vocab_size: int, -) -> Tuple[List[float], List[int], List[float]]: - top_ps: List[float] = [] - top_ks: List[int] = [] - min_ps: List[float] = [] - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group - top_p = sampling_params.top_p - min_p = sampling_params.min_p - # k should not be greater than the vocab size. - top_k = min(sampling_params.top_k, vocab_size) - # k=-1 means no truncation. - top_k = vocab_size if top_k == -1 else top_k - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - prompt_len = sampling_metadata.prompt_lens[i] - top_ps += [top_p] * (prompt_len - 1) - top_ks += [top_k] * (prompt_len - 1) - min_ps += [min_p] * (prompt_len - 1) - top_ps += [top_p] * len(seq_ids) - top_ks += [top_k] * len(seq_ids) - min_ps += [min_p] * len(seq_ids) - return top_ps, top_ks, min_ps - - def _apply_top_p_top_k( logits: torch.Tensor, - top_ps: List[float], - top_ks: List[int], + p: torch.Tensor, + k: torch.Tensor, ) -> torch.Tensor: - p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) - k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) logits_sort, logits_idx = logits.sort(dim=-1, descending=True) # Apply top-p. probs_sort = logits_sort.softmax(dim=-1) - probs_sum = probs_sort.cumsum(dim=-1) - top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1) - logits_sort[top_p_mask] = -float("inf") + probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort) + top_p_mask = probs_sum > p.unsqueeze_(dim=1) # Apply top-k. # Create a mask for the top-k elements. top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device) top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1) - top_k_mask = top_k_mask >= k.unsqueeze(dim=1) - logits_sort[top_k_mask] = -float("inf") + top_k_mask = top_k_mask >= k.unsqueeze_(dim=1) + + # Final mask. + mask = (top_p_mask | top_k_mask) + logits_sort.masked_fill_(mask, -float("inf")) # Re-sort the probabilities. 
- logits = torch.gather(logits_sort, - dim=-1, - index=torch.argsort(logits_idx, dim=-1)) + src = torch.arange(logits_idx.shape[-1], + device=logits_idx.device).expand_as(logits_idx) + logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1, + index=logits_idx, + src=src) + logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv) return logits def _apply_min_p( logits: torch.Tensor, - min_ps: List[float], + min_p: torch.Tensor, ) -> torch.Tensor: """ Adapted from https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17 """ - min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device) probs = torch.softmax(logits, dim=-1) top_probs, _ = probs.max(dim=-1, keepdim=True) - scaled_min_p = min_p.unsqueeze(dim=1) * top_probs + scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs tokens_to_remove = probs < scaled_min_p - logits = logits.masked_fill(tokens_to_remove, -float("inf")) + logits = logits.masked_fill_(tokens_to_remove, -float("inf")) return logits def _greedy_sample( selected_seq_groups: List[Tuple[List[int], SamplingParams]], - logprobs: torch.Tensor, + samples: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: - samples = torch.argmax(logprobs, dim=-1).cpu() + samples = samples.tolist() sample_idx = 0 results = [] for seq_group in selected_seq_groups: @@ -375,27 +259,19 @@ def _greedy_sample( assert num_parent_seqs == 1, ( "Greedy sampling should have only one seq.") parent_ids = list(range(num_parent_seqs)) - next_token_ids = [samples[sample_idx].item()] + next_token_ids = [samples[sample_idx]] results.append((next_token_ids, parent_ids)) sample_idx += num_parent_seqs - assert sample_idx == logprobs.size(0) return results def _random_sample( selected_seq_groups: List[Tuple[List[int], SamplingParams]], is_prompts: List[bool], - probs: torch.Tensor, + random_samples: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: # Find the maximum best_of value of the prompt phase requests. - max_best_of = 1 - for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): - if is_prompt: - seq_ids, sampling_params = seq_group - max_best_of = max(max_best_of, sampling_params.best_of) - random_samples = torch.multinomial(probs, - num_samples=max_best_of, - replacement=True).cpu() + random_samples = random_samples.cpu() sample_idx = 0 results = [] for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): @@ -403,8 +279,6 @@ def _random_sample( num_parent_seqs = len(seq_ids) if is_prompt: # Prompt phase. - assert num_parent_seqs == 1, ( - "Prompt input should have only one seq.") parent_ids = [0] * sampling_params.best_of next_token_ids = random_samples[ sample_idx, :sampling_params.best_of].tolist() @@ -415,7 +289,6 @@ def _random_sample( num_parent_seqs, 0].tolist() results.append((next_token_ids, parent_ids)) sample_idx += num_parent_seqs - assert sample_idx == probs.size(0) return results @@ -472,6 +345,28 @@ def _beam_search_sample( return results +# torch.multinomial forces a GPU<->CPU sync. +# Therefore, we use an optimized implementation instead. +# Note that we always sample with replacement. +# probs will be modified in place, but this is fine, as we pass +# in a copy already. +def _multinomial( + probs: torch.Tensor, + num_samples: int, +): + if num_samples > 1: + # This is equivalent to torch.repeat_interleaved (which also + # forces a GPU<->CPU sync). 
+ # This allows us to do sampling with replacement by creating + # num_samples copies of each row in the tensor, and then + # batch sampling the resulting tensor. + probs = probs[:, None, :].expand(probs.shape[0], num_samples, + probs.shape[1]).contiguous().view( + -1, probs.shape[1]) + q = torch.empty_like(probs).exponential_(1) + return probs.div_(q).argmax(dim=1).view(-1, num_samples) + + def _sample( probs: torch.Tensor, logprobs: torch.Tensor, @@ -485,28 +380,51 @@ def _sample( categorized_seq_group_ids[sampling_type].append(i) sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} + sample_metadata = {} + + # Counterintiutively, having two loops here is actually faster. + # The first loop can run without waiting on GPU<->CPU sync. for sampling_type in SamplingType: - seq_group_ids = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] - is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] sample_indices = categorized_sample_indices[sampling_type] num_tokens = len(sample_indices) if num_tokens == 0: continue + seq_group_ids = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] + is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] + sample_metadata[sampling_type] = (seq_group_ids, seq_groups, + is_prompts, sample_indices) if sampling_type == SamplingType.GREEDY: - category_logprobs = logprobs[sample_indices] - sample_results = _greedy_sample(seq_groups, category_logprobs) + greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1) + elif sampling_type == SamplingType.RANDOM: + max_best_of = 1 + for seq_group, is_prompt in zip(seq_groups, is_prompts): + if is_prompt: + _, sampling_params = seq_group + max_best_of = max(max_best_of, sampling_params.best_of) + multinomial_samples = _multinomial(probs[sample_indices], + max_best_of) + elif sampling_type == SamplingType.BEAM: + beam_search_logprobs = logprobs[sample_indices] + else: + raise ValueError(f"Unsupported sampling type: {sampling_type}") + + # GPU<->CPU sync happens in the loop below. 
+ + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[ + sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) elif sampling_type == SamplingType.RANDOM: - category_probs = probs[sample_indices] sample_results = _random_sample(seq_groups, is_prompts, - category_probs) + multinomial_samples) elif sampling_type == SamplingType.BEAM: - category_logprobs = logprobs[sample_indices] sample_results = _beam_search_sample(seq_groups, is_prompts, sampling_metadata.seq_data, - category_logprobs) - else: - raise ValueError(f"Unsupported sampling type: {sampling_type}") + beam_search_logprobs) sample_results_dict.update(zip(seq_group_ids, sample_results)) sample_results = [ @@ -557,7 +475,7 @@ def _get_logprobs( batched_logprobs_query_result = logprobs[[ batched_logprobs_query_seq_indices, batched_logprobs_query_token_indices - ]].cpu() + ]] # Batched query for logprobs of topk tokens if largest_num_logprobs > 0: @@ -569,6 +487,8 @@ def _get_logprobs( else: top_logprobs, top_token_ids = None, None + batched_logprobs_query_result = batched_logprobs_query_result.cpu() + # Gather results result_prompt_logprobs: List[Optional[PromptLogprobs]] = [] result_sample_logprobs: List[SampleLogprobs] = [] diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 708a25454ff81..80e7988ab3169 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -17,6 +17,7 @@ "BloomForCausalLM": ("bloom", "BloomForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), + "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py new file mode 100644 index 0000000000000..984be0cccd16d --- /dev/null +++ b/vllm/model_executor/models/decilm.py @@ -0,0 +1,123 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 DeciAI Research Team. All rights reserved. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on MistralAI GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only DeciLM model compatible with HuggingFace weights.""" + +from typing import Optional + +import torch +from transformers import PretrainedConfig + +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) + + +class DeciLMForCausalLM(LlamaForCausalLM): + """ + Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct. + Based on the llama executor. + + The main difference is that DeciLM uses Variable Grouped Query Attention. + The constant number of GQA heads in the decoder is overriden with a value + per layer. + + Usually, in the HuggingFace implementation, instead of + "config.num_key_value_heads", we use + "config.num_key_value_heads_per_layer[i]" which varies. + + Currently, PagedAttention does not work well with variable GQA, so we + normalize the weights upon loading, and use uniform GQA with the max value + instead. + """ + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + config.num_key_value_heads = max(config.num_key_value_heads_per_layer) + delattr(config, "num_key_value_heads_per_layer") + super().__init__(config=config, linear_method=linear_method) + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + + if "k_proj" in name or "v_proj" in name: + loaded_weight = self._degroup_weight(loaded_weight) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor: + hidden_size = self.config.hidden_size + head_size = self.config.hidden_size // self.config.num_attention_heads + target_num_kv_heads = self.config.num_key_value_heads + num_kv_heads = loaded_weight.shape[0] // head_size + n_repeats = target_num_kv_heads / num_kv_heads + assert n_repeats == int(n_repeats) + + n_repeats = int(n_repeats) + loaded_weight = loaded_weight.view(num_kv_heads, head_size, + hidden_size) + loaded_weight = torch.repeat_interleave(loaded_weight, + repeats=n_repeats, + dim=0) + loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size, + hidden_size) + + return loaded_weight diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 1298cc52a6488..1939a09fd736f 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -251,9 +251,10 @@ def load_weights(self, for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if name.startswith("model."): - name = name[6:] # remove "model." prefix + name = name[6:] # remove "model." prefix - if name.startswith("language_model"): # load language model weights + if name.startswith( + "language_model"): # load language model weights # name = name[6:] # remove "model." prefix if "rotary_emb.inv_freq" in name: continue @@ -279,8 +280,7 @@ def load_weights(self, default_weight_loader) weight_loader(param, loaded_weight) elif name.startswith("vision_tower") or name.startswith( - 'multi_modal_projector' - ): # load vision model weights + 'multi_modal_projector'): # load vision model weights # name = name[6:] # remove "model." 
prefix if params_dict.get(name, None) is None: unused_keys.append(name) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 13473857b3309..e61b401a78a2b 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -49,7 +49,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) -from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -94,30 +93,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return current_hidden_states -class DummyModule(nn.Module): - - def __init__(self) -> None: - super().__init__() - - self.w1 = nn.Linear(0, 0, bias=False) - self.w2 = nn.Linear(0, 0, bias=False) - self.w3 = nn.Linear(0, 0, bias=False) - - set_weight_attrs(self.w1.weight, - {"weight_loader": self.dummy_weight_loader}) - set_weight_attrs(self.w2.weight, - {"weight_loader": self.dummy_weight_loader}) - set_weight_attrs(self.w3.weight, - {"weight_loader": self.dummy_weight_loader}) - - def forward(self, *args, **kwargs) -> None: - raise NotImplementedError() - - def dummy_weight_loader(self, *args, **kwargs) -> None: # pylint: disable=unused-argument - # Noop - return - - class MixtralMoE(nn.Module): def __init__( @@ -147,7 +122,7 @@ def __init__( config.hidden_size, config.intermediate_size, linear_method=linear_method) - if idx in self.expert_indicies else DummyModule() + if idx in self.expert_indicies else None for idx in range(self.num_total_experts) ]) self.gate = ReplicatedLinear(config.hidden_size, @@ -427,6 +402,10 @@ def load_weights(self, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip experts that are not assigned to this worker. + if ("block_sparse_moe.experts." 
in name + and name not in params_dict): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index deb779f537c69..49013ec273787 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -1,9 +1,13 @@ +from dataclasses import dataclass from typing import Dict, List, Tuple import torch from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData +from vllm.utils import in_wsl + +_SAMPLING_EPS = 1e-5 class SamplingMetadata: @@ -41,3 +45,186 @@ def __repr__(self) -> str: f"prompt_lens={self.prompt_lens}, " f"selected_token_indices={self.selected_token_indices}, " f"categorized_sample_indices={self.categorized_sample_indices})") + + +@dataclass +class SamplingTensors: + """Tensors for sampling.""" + + temperatures: torch.Tensor + top_ps: torch.Tensor + top_ks: torch.Tensor + min_ps: torch.Tensor + presence_penalties: torch.Tensor + frequency_penalties: torch.Tensor + repetition_penalties: torch.Tensor + prompt_tokens: torch.Tensor + output_tokens: torch.Tensor + + @classmethod + def from_sampling_metadata( + cls, sampling_metadata: "SamplingMetadata", vocab_size: int, + device: torch.device, + dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]: + prompt_tokens: List[List[int]] = [] + output_tokens: List[List[int]] = [] + top_ks: List[int] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + min_ps: List[float] = [] + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + repetition_penalties: List[float] = [] + do_penalties = False + do_top_p_top_k = False + do_min_p = False + for i, seq_group in enumerate(sampling_metadata.seq_groups): + seq_ids, sampling_params = seq_group + temperature = sampling_params.temperature + p = sampling_params.presence_penalty + f = sampling_params.frequency_penalty + r = sampling_params.repetition_penalty + top_p = sampling_params.top_p + min_p = sampling_params.min_p + # k should not be greater than the vocab size. + top_k = min(sampling_params.top_k, vocab_size) + top_k = vocab_size if top_k == -1 else top_k + if temperature < _SAMPLING_EPS: + # NOTE: Zero temperature means deterministic sampling + # (i.e., greedy sampling or beam search). + # Set the temperature to 1 to avoid division by zero. 
+ temperature = 1.0 + if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS + or top_k != vocab_size): + do_top_p_top_k = True + if not do_min_p and min_p > _SAMPLING_EPS: + do_min_p = True + if not do_penalties and (abs(p) >= _SAMPLING_EPS + or abs(f) >= _SAMPLING_EPS + or abs(r - 1.0) >= _SAMPLING_EPS): + do_penalties = True + if (i < sampling_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + # For tokens in the prompt that we only need to get their logprobs + prompt_len = sampling_metadata.prompt_lens[i] + temperatures += [temperature] * (prompt_len - 1) + top_ps += [top_p] * (prompt_len - 1) + top_ks += [top_k] * (prompt_len - 1) + min_ps += [min_p] * (prompt_len - 1) + presence_penalties += [0] * (prompt_len - 1) + frequency_penalties += [0] * (prompt_len - 1) + repetition_penalties += [1] * (prompt_len - 1) + prompt_tokens.extend([] for _ in range(prompt_len - 1)) + output_tokens.extend([] for _ in range(prompt_len - 1)) + for seq_id in seq_ids: + seq_data = sampling_metadata.seq_data[seq_id] + prompt_tokens.append(seq_data.prompt_token_ids) + output_tokens.append(seq_data.output_token_ids) + temperatures += [temperature] * len(seq_ids) + top_ps += [top_p] * len(seq_ids) + top_ks += [top_k] * len(seq_ids) + min_ps += [min_p] * len(seq_ids) + presence_penalties += [p] * len(seq_ids) + frequency_penalties += [f] * len(seq_ids) + repetition_penalties += [r] * len(seq_ids) + + sampling_tensors = SamplingTensors.from_lists( + temperatures, top_ps, top_ks, min_ps, presence_penalties, + frequency_penalties, repetition_penalties, prompt_tokens, + output_tokens, vocab_size, device, dtype) + return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) + + @classmethod + def from_lists(cls, temperatures: List[float], top_ps: List[float], + top_ks: List[int], min_ps: List[float], + presence_penalties: List[float], + frequency_penalties: List[float], + repetition_penalties: List[float], + prompt_tokens: List[List[int]], + output_tokens: List[List[int]], vocab_size: int, + device: torch.device, + dtype: torch.dtype) -> "SamplingTensors": + # Note that the performance will be very bad without + # pinned memory. 
+ pin_memory = not in_wsl() + prompt_max_len = max(len(tokens) for tokens in prompt_tokens) + prompt_padded_tokens = [ + tokens + [vocab_size] * (prompt_max_len - len(tokens)) + for tokens in prompt_tokens + ] + output_max_len = max(len(tokens) for tokens in output_tokens) + output_padded_tokens = [ + tokens + [vocab_size] * (output_max_len - len(tokens)) + for tokens in output_tokens + ] + + temperatures_t = torch.tensor( + temperatures, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + top_ps_t = torch.tensor( + top_ps, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + min_ps_t = torch.tensor( + min_ps, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + presence_penalties_t = torch.tensor( + presence_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + frequency_penalties_t = torch.tensor( + frequency_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + repetition_penalties_t = torch.tensor( + repetition_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + top_ks_t = torch.tensor( + top_ks, + device="cpu", + dtype=torch.int, + pin_memory=pin_memory, + ) + prompt_tensor = torch.tensor( + prompt_padded_tokens, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + output_tensor = torch.tensor( + output_padded_tokens, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + # Because the memory is pinned, we can do non-blocking + # transfer to device. + return cls( + temperatures=temperatures_t.to(device=device, non_blocking=True), + top_ps=top_ps_t.to(device=device, non_blocking=True), + top_ks=top_ks_t.to(device=device, non_blocking=True), + min_ps=min_ps_t.to(device=device, non_blocking=True), + presence_penalties=presence_penalties_t.to(device=device, + non_blocking=True), + frequency_penalties=frequency_penalties_t.to(device=device, + non_blocking=True), + repetition_penalties=repetition_penalties_t.to(device=device, + non_blocking=True), + prompt_tokens=prompt_tensor.to(device=device, non_blocking=True), + output_tokens=output_tensor.to(device=device, non_blocking=True), + ) diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index bff4fb2f7729e..365c847a435fe 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -146,7 +146,7 @@ def prepare_hf_model_weights( raise ValueError(f"Unknown load_format: {load_format}") if fall_back_to_pt: - allow_patterns += [".pt"] + allow_patterns += ["*.pt"] if not is_local: # Use file lock to prevent multiple processes from diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fedb0ecfef092..9bd4ab2ae44b3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -11,6 +11,7 @@ from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.utils import in_wsl logger = init_logger(__name__) @@ -53,6 +54,8 @@ def __init__( # The shape of the cached block table will be # (max batch size to capture, max context len to capture / block size). self.graph_block_tables = None # Set after initial profiling. 
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index fedb0ecfef092..9bd4ab2ae44b3 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -11,6 +11,7 @@
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.utils import in_wsl
 
 logger = init_logger(__name__)
 
@@ -53,6 +54,8 @@ def __init__(
         # The shape of the cached block table will be
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables = None  # Set after initial profiling.
+        # Cache the in_wsl() result, since it is checked on every step.
+        self.in_wsl = in_wsl()
 
     def load_model(self) -> None:
         self.model = get_model(self.model_config)
@@ -204,24 +207,29 @@ def _prepare_decode(
         # When using CUDA graph, we don't need to make the tensors on the GPU
         # because they will be eventually copied to the designated GPU buffer.
         device = "cpu" if use_captured_graph else "cuda"
+        pin_memory = use_captured_graph and not self.in_wsl
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device=device)
+                                                device=device,
+                                                pin_memory=pin_memory)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device=device)
+                                    device=device,
+                                    pin_memory=pin_memory)
 
         if use_captured_graph:
             # The shape of graph_block_tables is
@@ -230,7 +238,7 @@ def _prepare_decode(
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.from_numpy(input_block_tables).to(device)
+            block_tables = torch.tensor(input_block_tables, device=device)
         else:
             block_tables = _make_tensor_with_pad(
                 block_tables,
@@ -298,11 +306,11 @@ def _prepare_sample(
                         categorized_sample_indices_start_idx + num_seqs))
                 categorized_sample_indices_start_idx += num_seqs
 
-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long,
-                                              device="cuda")
+        selected_token_indices = _async_h2d(selected_token_indices,
+                                            dtype=torch.long,
+                                            pin_memory=not self.in_wsl)
         categorized_sample_indices = {
-            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
+            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
             for t, seq_ids in categorized_sample_indices.items()
         }
 
@@ -335,8 +343,6 @@ def execute_model(
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
             input_tokens, input_positions, input_metadata = inputs
-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)
 
         # Execute the model.
         if input_metadata.use_cuda_graph:
@@ -351,6 +357,9 @@ def execute_model(
             input_metadata=input_metadata,
         )
 
+        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                 input_metadata.prompt_lens)
+
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,
@@ -388,8 +397,6 @@ def execute_llava_model(
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
             input_tokens, input_positions, input_metadata = inputs
-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)
 
         # Execute the model.
         if input_metadata.use_cuda_graph:
@@ -403,6 +410,8 @@ def execute_llava_model(
             input_metadata=input_metadata,
             **extra_kwargs)
 
+        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                 input_metadata.prompt_lens)
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,
@@ -448,6 +457,9 @@ def capture_model(self, kv_caches: List[KVCache]) -> None:
                     "unexpected consequences if the model is not static. To "
                     "run the model in eager mode, set 'enforce_eager=True' or "
                     "use '--enforce-eager' in the CLI.")
+        logger.info("CUDA graphs can take additional 1~3 GiB of memory per GPU. "
+                    "If you are running out of memory, consider decreasing "
+                    "`gpu_memory_utilization` or enforcing eager mode.")
         start_time = time.perf_counter()
 
         # Prepare dummy inputs. These will be reused for all batch sizes.
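Moving _prepare_sample after the forward pass (in both execute_model and execute_llava_model above) exploits the fact that CUDA kernel launches return before the GPU has finished: the Python-side metadata work then overlaps with model execution instead of sitting on the critical path. A rough, self-contained illustration of the effect with arbitrary sizes (not vLLM code):

    import torch

    if torch.cuda.is_available():
        x = torch.randn(4096, 4096, device="cuda")
        w = torch.randn(4096, 4096, device="cuda")
        y = x @ w                                   # launch returns immediately
        metadata = [i * 2 for i in range(100_000)]  # CPU work overlaps with the matmul
        torch.cuda.synchronize()                    # wait only when the result is needed
        print(y.shape, len(metadata))
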
" + "If you are running out of memory, consider decreasing " + "`gpu_memory_utilization` or enforcing eager mode.") start_time = time.perf_counter() # Prepare dummy inputs. These will be reused for all batch sizes. @@ -565,7 +577,6 @@ def forward( self.input_buffers[key].copy_(value) else: self.input_buffers[key] = value - # Run the graph. self.graph.replay() @@ -587,9 +598,13 @@ def _make_tensor_with_pad( pad: int, dtype: torch.dtype, device: Union[str, torch.device] = "cuda", + pin_memory: bool = False, ) -> torch.Tensor: padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] - return torch.tensor(padded_x, dtype=dtype, device=device) + return torch.tensor(padded_x, + dtype=dtype, + device=device, + pin_memory=pin_memory and str(device) == "cpu") def _get_graph_batch_size(batch_size: int) -> int: @@ -599,3 +614,8 @@ def _get_graph_batch_size(batch_size: int) -> int: return 4 else: return (batch_size + 7) // 8 * 8 + + +def _async_h2d(data: list, dtype, pin_memory): + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory) + return t.to(device="cuda", non_blocking=True)