diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 5a7f828456e2d..94b65401e1dd4 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -16,7 +16,7 @@ # Allow only 5 sequences of ~1024 tokens in worst case. "block_size": 16, - "forced_num_gpu_blocks": 5 * (64 + 1), + "num_gpu_blocks_override": 5 * (64 + 1), }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{ @@ -162,14 +162,14 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator, # Allow only 2 sequences of ~128 tokens in worst case. # Note 8 = 128/block_size - "forced_num_gpu_blocks": 2 * (8 + 1), + "num_gpu_blocks_override": 2 * (8 + 1), }, { "block_size": 8, # Allow only 2 sequences of ~128 tokens in worst case. # Note 16 = 128/block_size - "forced_num_gpu_blocks": 2 * (16 + 1), + "num_gpu_blocks_override": 2 * (16 + 1), } ]) @pytest.mark.parametrize("baseline_llm_kwargs", [{ diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 60aa90fe4ee8a..54594690f7922 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -3,8 +3,8 @@ import tempfile from unittest.mock import patch -from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.worker.worker import Worker @@ -27,6 +27,10 @@ def test_worker_apply_lora(sql_lora_files): parallel_config=ParallelConfig(1, 1, False), scheduler_config=SchedulerConfig(32, 32, 32), device_config=DeviceConfig("cuda"), + cache_config=CacheConfig(block_size=16, + gpu_memory_utilization=1., + swap_space=0, + cache_dtype="auto"), local_rank=0, rank=0, lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 825d360671965..47aff8f575413 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -512,8 +512,8 @@ def test_init_device(): @torch.inference_mode() -def test_init_cache_engine(): - """Verify SpecDecodeWorker invokes init_cache_engine on proposer/scorer +def test_initialize_cache(): + """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer workers. 
""" draft_worker = mock_worker(cls=MultiStepWorker) @@ -525,12 +525,11 @@ def test_init_cache_engine(): worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - cache_config = MagicMock() + kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} + worker.initialize_cache(**kwargs) - worker.init_cache_engine(cache_config) - - draft_worker.init_cache_engine.assert_called_once_with(cache_config) - target_worker.init_cache_engine.assert_called_once_with(cache_config) + draft_worker.initialize_cache.assert_called_once_with(**kwargs) + target_worker.initialize_cache.assert_called_once_with(**kwargs) @pytest.mark.parametrize('available_gpu_blocks', [1, 1024]) @@ -538,10 +537,10 @@ def test_init_cache_engine(): @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096]) @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) @pytest.mark.skip_global_cleanup -def test_profile_num_available_blocks(available_gpu_blocks: int, - available_cpu_blocks: int, - target_cache_block_size_bytes: int, - draft_kv_size_bytes: int): +def test_determine_num_available_blocks(available_gpu_blocks: int, + available_cpu_blocks: int, + target_cache_block_size_bytes: int, + draft_kv_size_bytes: int): """Verify SpecDecodeWorker correctly profiles num available GPU blocks. Specifically, it should run profiling in the scorer worker, and then evenly split the blocks between proposer and scorer worker. @@ -552,7 +551,7 @@ def test_profile_num_available_blocks(available_gpu_blocks: int, rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) - target_worker.profile_num_available_blocks.return_value = ( + target_worker.determine_num_available_blocks.return_value = ( available_gpu_blocks, available_cpu_blocks) target_worker.get_cache_block_size_bytes.return_value = ( target_cache_block_size_bytes) @@ -561,17 +560,9 @@ def test_profile_num_available_blocks(available_gpu_blocks: int, worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - # These values do not directly impact the adjusted block size calculation, - # so they can be fixed. 
- gpu_memory_utilization = 0.9 - cpu_swap_space = 100 - block_size = 16 - - num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks( - block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype="auto") + num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks() - target_worker.profile_num_available_blocks.assert_called_once_with( - block_size, gpu_memory_utilization, cpu_swap_space, "auto") + target_worker.determine_num_available_blocks.assert_called_once() assert num_cpu_blocks == available_cpu_blocks assert num_gpu_blocks == split_num_cache_blocks_evenly( diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 5ef1cc28253e9..4637826f254d6 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -117,6 +117,7 @@ def create_worker(cls: type, parallel_config=engine_config.parallel_config, scheduler_config=engine_config.scheduler_config, device_config=engine_config.device_config, + cache_config=engine_config.cache_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -128,8 +129,9 @@ def create_worker(cls: type, engine_config.cache_config.num_gpu_blocks = num_gpu_blocks engine_config.cache_config.num_cpu_blocks = 0 - worker.init_cache_engine(engine_config.cache_config) - worker.warm_up_model() + worker.initialize_cache( + num_gpu_blocks=engine_config.cache_config.num_gpu_blocks, + num_cpu_blocks=engine_config.cache_config.num_cpu_blocks) return worker diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 5d6ba51ea0f06..8edb1cf05c08e 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -11,8 +11,8 @@ def test_swap() -> None: dtype="half", load_format="dummy") engine_config = engine_args.create_engine_config() - engine_config.cache_config.num_gpu_blocks = 100 - engine_config.cache_config.num_cpu_blocks = 100 + engine_config.cache_config.num_gpu_blocks = 1000 + engine_config.cache_config.num_cpu_blocks = 1000 # Create the worker. distributed_init_method = get_distributed_init_method( @@ -22,6 +22,7 @@ def test_swap() -> None: parallel_config=engine_config.parallel_config, scheduler_config=engine_config.scheduler_config, device_config=engine_config.device_config, + cache_config=engine_config.cache_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -31,8 +32,9 @@ def test_swap() -> None: # Initialize the worker. worker.init_device() worker.load_model() - worker.init_cache_engine(engine_config.cache_config) - worker.warm_up_model() + worker.initialize_cache( + num_gpu_blocks=engine_config.cache_config.num_gpu_blocks, + num_cpu_blocks=engine_config.cache_config.num_cpu_blocks) # Randomly initialize the cache. gpu_cache = worker.cache_engine.gpu_cache diff --git a/vllm/config.py b/vllm/config.py index 3498d2a285818..84746bd41c205 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -366,7 +366,7 @@ class CacheConfig: vLLM execution. swap_space: Size of the CPU swap space per GPU (in GiB). cache_dtype: Data type for kv cache storage. - forced_num_gpu_blocks: Number of GPU blocks to use. This overrides the + num_gpu_blocks_override: Number of GPU blocks to use. This overrides the profiled num_gpu_blocks if specified. Does nothing if None. 
""" @@ -376,14 +376,14 @@ def __init__( gpu_memory_utilization: float, swap_space: int, cache_dtype: str, - forced_num_gpu_blocks: Optional[int] = None, + num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB - self.forced_num_gpu_blocks = forced_num_gpu_blocks + self.num_gpu_blocks_override = num_gpu_blocks_override self.cache_dtype = cache_dtype self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7e8760d0b0e04..01fd5f023b69b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -59,7 +59,7 @@ class EngineArgs: max_cpu_loras: Optional[int] = None device: str = 'auto' ray_workers_use_nsight: bool = False - forced_num_gpu_blocks: Optional[int] = None + num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 # Related to Vision-language models such as llava @@ -250,7 +250,7 @@ def add_cli_args( 'the model executor, which can range from 0 to 1.' 'If unspecified, will use the default value of 0.9.') parser.add_argument( - '--forced-num-gpu-blocks', + '--num-gpu-blocks-override', type=int, default=None, help='If specified, ignore GPU profiling result and use this number' @@ -454,7 +454,7 @@ def create_engine_config(self, ) -> EngineConfig: cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, - self.forced_num_gpu_blocks, + self.num_gpu_blocks_override, model_config.get_sliding_window(), self.enable_prefix_caching) parallel_config = ParallelConfig( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e260d7a620b07..2649561101433 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -129,6 +129,8 @@ def __init__( speculative_config=speculative_config, ) + self._initialize_kv_caches() + # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): from vllm.model_executor.model_loader import ( @@ -180,6 +182,26 @@ def __init__( labels=dict(model_name=model_config.model)) self.stat_logger.info("cache_config", self.cache_config) + def _initialize_kv_caches(self) -> None: + """Initialize the KV cache in the worker(s). + + The workers will determine the number of blocks in both the GPU cache + and the swap CPU cache. + """ + num_gpu_blocks, num_cpu_blocks = ( + self.model_executor.determine_num_available_blocks()) + + if self.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override + logger.info(f"Overriding {num_gpu_blocks=} with " + f"{num_gpu_blocks_override=}") + num_gpu_blocks = num_gpu_blocks_override + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) + @classmethod def from_engine_args( cls, diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 7b3cc784c98e5..2bf97338da0ed 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -35,7 +35,6 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, # Instantiate the worker and load the model to CPU. 
self._init_worker() - self._init_cache() def _init_worker(self): from vllm.worker.cpu_worker import CPUWorker @@ -46,10 +45,11 @@ def _init_worker(self): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) self.driver_worker = CPUWorker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -60,35 +60,21 @@ def _init_worker(self): self.driver_worker.init_device() self.driver_worker.load_model() - def _init_cache(self) -> None: - num_cpu_blocks = self.driver_worker.get_cpu_cache_block_num( - block_size=self.cache_config.block_size, - cache_space=self.cache_config.cpu_kvcache_space_bytes, - cache_dtype=self.cache_config.cache_dtype, - ) - + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. logger.info(f"# CPU blocks: {num_cpu_blocks}") - if num_cpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `VLLM_CPU_KVCACHE_SPACE` when " - "initializing the engine.") - - max_seq_len = self.cache_config.block_size * num_cpu_blocks - if self.model_config.max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({self.model_config.max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when " - "initializing the engine.") - - # Note: To reuse the cache management procedure, - # use cpu cache as 'gpu cache'. - self.cache_config.num_gpu_blocks = num_cpu_blocks # type: ignore - self.cache_config.num_cpu_blocks = 0 # type: ignore - - # Initialize the cache. 
- self.driver_worker.init_cache_engine(cache_config=self.cache_config) + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -104,13 +90,13 @@ def execute_model(self, return output def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError("LoRA is not implemented for cpu backend.") + return self.driver_worker.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for cpu backend.") + return self.driver_worker.remove_lora(lora_id) def list_loras(self) -> List[int]: - raise NotImplementedError("LoRA is not implemented for cpu backend.") + return self.driver_worker.list_loras() def check_health(self) -> None: # CPUExecutor will always be healthy as long as diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 8ec5dfe1e00eb..c18edd75d7a4d 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -30,6 +30,29 @@ def __init__( ) -> None: raise NotImplementedError + @abstractmethod + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + Normally, this should simply delegate to the underlying Worker. Some + ExecutorBase may require modification of the result, e.g. to ensure the + selected cache sizes are compatible with all workers. + + Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ + raise NotImplementedError + + @abstractmethod + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks. + """ + raise NotImplementedError + @abstractmethod def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 7b683107d30e5..80ca5cb7367c5 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -4,7 +4,6 @@ ParallelConfig, SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.executor.utils import check_block_size_valid from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -41,9 +40,6 @@ def __init__( # Instantiate the worker and load the model to GPU. self._init_worker() - # Profile the memory usage and initialize the cache. 
- self._init_cache() - def _init_worker(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker @@ -55,61 +51,37 @@ def _init_worker(self): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) self.driver_worker = Worker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, - kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) self.driver_worker.init_device() self.driver_worker.load_model() - def _init_cache(self) -> None: - """Profiles the memory usage and initializes the KV cache. - - The engine first profiles the existing memory usage. - Then, it allocates the remaining memory for KV blocks. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_gpu_blocks, num_cpu_blocks = ( - self.driver_worker.profile_num_available_blocks( - block_size=self.cache_config.block_size, - gpu_memory_utilization=self.cache_config. - gpu_memory_utilization, - cpu_swap_space=self.cache_config.swap_space_bytes, - cache_dtype=self.cache_config.cache_dtype, - )) - - if self.cache_config.forced_num_gpu_blocks is not None: - forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks - logger.info(f"Replacing profiled {num_gpu_blocks=} with " - f"{forced_num_gpu_blocks=}") - num_gpu_blocks = forced_num_gpu_blocks + return self.driver_worker.determine_num_available_blocks() + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. logger.info(f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}") - check_block_size_valid(num_gpu_blocks, self.cache_config.block_size, - self.model_config.max_model_len) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - # Initialize the cache. - self.driver_worker.init_cache_engine(cache_config=self.cache_config) - # Warm up the model. This includes capturing the model into CUDA graph - # if enforce_eager is False. - self.driver_worker.warm_up_model() + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index c0af058cb90b5..57436a85cfa27 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -25,7 +25,6 @@ def __init__( speculative_config: Optional[SpeculativeConfig], ) -> None: self.model_config = model_config - self.cache_config = cache_config assert lora_config is None, "LoRA is not supported for Neuron backend." 
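With the abstract methods on ExecutorBase above, GPUExecutor reduces to plain delegation to its single driver worker, and NeuronExecutor below follows the same pattern. A skeletal sketch of what a hypothetical single-worker executor provides under the new contract (constructor, execute_model, LoRA hooks and check_health omitted; self.driver_worker is assumed to be a WorkerBase set up in __init__):

from typing import Tuple

from vllm.executor.executor_base import ExecutorBase


class SingleWorkerExecutor(ExecutorBase):
    """Hypothetical executor that owns one driver worker."""

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        # Profiling happens in the worker; the executor just forwards the call.
        return self.driver_worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)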
self.parallel_config = parallel_config self.scheduler_config = scheduler_config @@ -33,12 +32,6 @@ def __init__( assert (not speculative_config ), "Speculative decoding not yet supported for Neuron backend." - # Set the number of GPU blocks to be the same as the maximum number of - # sequences that can be processed in a single batch. This is equivalent - # to schedule without PagedAttention. - self.cache_config.num_gpu_blocks = self.scheduler_config.max_num_seqs - self.cache_config.num_cpu_blocks = 0 - # Instantiate the worker and load the model to the device. self._init_worker() @@ -54,6 +47,18 @@ def _init_worker(self): self.driver_worker.init_device() self.driver_worker.load_model() + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], @@ -68,16 +73,13 @@ def execute_model(self, return output def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError( - "LoRA is not implemented for neuron backend.") + return self.driver_worker.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "LoRA is not implemented for neuron backend.") + return self.driver_worker.remove_lora(lora_id) def list_loras(self) -> List[int]: - raise NotImplementedError( - "LoRA is not implemented for neuron backend.") + return self.driver_worker.list_loras() def check_health(self) -> None: # NeuronExecutor will always be healthy as long as diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 43cb37cfb5e0c..6c0ccd7e64c90 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -10,7 +10,6 @@ VisionLanguageConfig) from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.executor.utils import check_block_size_valid from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -65,9 +64,6 @@ def __init__( # Create the parallel GPU workers. self._init_workers_ray(placement_group) - # Profile the memory usage and initialize the cache. - self._init_cache() - self.forward_dag = None if USE_RAY_COMPILED_DAG: self.forward_dag = self._compiled_ray_dag() @@ -154,8 +150,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", scheduler_config = copy.deepcopy(self.scheduler_config) device_config = copy.deepcopy(self.device_config) lora_config = copy.deepcopy(self.lora_config) + cache_config = copy.deepcopy(self.cache_config) vision_language_config = copy.deepcopy(self.vision_language_config) - kv_cache_dtype = self.cache_config.cache_dtype # Initialize the actual workers with the Worker class. 
for rank, (worker, (node_id, _)) in enumerate( @@ -165,32 +161,32 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", local_rank = node_workers[node_id].index(rank) worker.init_worker.remote( lambda rank=rank, local_rank=local_rank: Worker( - model_config, - parallel_config, - scheduler_config, - device_config, - local_rank, - rank, - distributed_init_method, + model_config=model_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + cache_config=cache_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, lora_config=lora_config, vision_language_config=vision_language_config, - kv_cache_dtype=kv_cache_dtype, )) # Initialize the driver worker with the Worker class. driver_rank = 0 driver_local_rank = node_workers[driver_node_id].index(driver_rank) self.driver_worker = Worker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, - driver_local_rank, - driver_rank, - distributed_init_method, + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=driver_local_rank, + rank=driver_rank, + distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, - kv_cache_dtype=kv_cache_dtype, is_driver_worker=True, ) @@ -201,35 +197,18 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_parallel_loading_workers, ) - def _init_cache(self) -> None: - """Profiles the memory usage and initializes the KV cache. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - More details can be found in the - :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method - from class :class:`~vllm.worker.Worker`. + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks. - Afterwards, as there may be multiple workers, - we take the minimum number of blocks across all workers - to ensure this can be applied to all of them. + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. - Finally, the engine will initialize the KV cache - with the calculated number of blocks. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] """ # Get the maximum number of blocks that can be allocated on GPU and CPU. 
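In the multi-worker case the division of labour is now: the engine drives the sequence, the Ray executor aggregates. A condensed, illustrative sketch of the call order (method names are from this diff; `executor` and `cache_config` stand for the engine's model_executor and cache_config):

# LLMEngine._initialize_kv_caches (vllm/engine/llm_engine.py above):
num_gpu_blocks, num_cpu_blocks = executor.determine_num_available_blocks()
#   -> RayGPUExecutor runs "determine_num_available_blocks" on every worker
#      and takes the per-worker minimum, so the chosen sizes fit everywhere.
if cache_config.num_gpu_blocks_override is not None:
    num_gpu_blocks = cache_config.num_gpu_blocks_override
executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
#   -> broadcast to all workers, which allocate the cache and warm up the model.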
- num_blocks = self._run_workers( - "profile_num_available_blocks", - block_size=self.cache_config.block_size, - gpu_memory_utilization=self.cache_config.gpu_memory_utilization, - cpu_swap_space=self.cache_config.swap_space_bytes, - cache_dtype=self.cache_config.cache_dtype, - ) + num_blocks = self._run_workers("determine_num_available_blocks", ) # Since we use a shared centralized controller, we take the minimum # number of blocks across all workers to make sure all the memory @@ -237,26 +216,25 @@ def _init_cache(self) -> None: num_gpu_blocks = min(b[0] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks) - if self.cache_config.forced_num_gpu_blocks is not None: - forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks - logger.info(f"Replacing profiled {num_gpu_blocks=} with " - f"{forced_num_gpu_blocks=}") - num_gpu_blocks = forced_num_gpu_blocks + return num_gpu_blocks, num_cpu_blocks + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers. + """ + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. logger.info(f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}") - check_block_size_valid(num_gpu_blocks, self.cache_config.block_size, - self.model_config.max_model_len) - self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - # Initialize the cache. - self._run_workers("init_cache_engine", cache_config=self.cache_config) - # Warm up the model. This includes capturing the model into CUDA graph - # if enforce_eager is False. - self._run_workers("warm_up_model") + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], diff --git a/vllm/executor/utils.py b/vllm/executor/utils.py deleted file mode 100644 index 44976696a77c6..0000000000000 --- a/vllm/executor/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -def check_block_size_valid(num_gpu_blocks, block_size, max_model_len) -> None: - if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - max_seq_len = block_size * num_gpu_blocks - if max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 59f9d5b5107f3..885bf537568e3 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -3,7 +3,6 @@ import torch -from vllm.config import CacheConfig from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput) @@ -15,9 +14,10 @@ from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, split_batch_by_proposal_len) from vllm.worker.worker import Worker +from vllm.worker.worker_base import LoraNotSupportedWorkerBase -class SpecDecodeWorker: +class SpecDecodeWorker(LoraNotSupportedWorkerBase): """Worker which implements speculative decoding. 
Speculative decoding reduces decoding per-token latency by using a proposal @@ -94,10 +94,7 @@ def init_device(self) -> None: device=self.device, vocab_size=self._vocab_size) - def profile_num_available_blocks(self, block_size: int, - gpu_memory_utilization: float, - cpu_swap_space: int, - cache_dtype: str) -> Tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of cache blocks to use. This is done by profiling the scorer model (which is typically the @@ -106,27 +103,26 @@ def profile_num_available_blocks(self, block_size: int, such that the number of blocks is equal in both KV caches. """ num_gpu_blocks, num_cpu_blocks = ( - self.scorer_worker.profile_num_available_blocks( - block_size, gpu_memory_utilization, cpu_swap_space, - cache_dtype)) + self.scorer_worker.determine_num_available_blocks()) scorer_cache_block_size_bytes = ( - self.scorer_worker.get_cache_block_size_bytes( - block_size, cache_dtype)) + self.scorer_worker.get_cache_block_size_bytes()) proposer_cache_block_size_bytes = ( - self.proposer_worker.get_cache_block_size_bytes( - block_size, cache_dtype)) + self.proposer_worker.get_cache_block_size_bytes()) new_num_gpu_blocks = split_num_cache_blocks_evenly( scorer_cache_block_size_bytes, proposer_cache_block_size_bytes, num_gpu_blocks) return new_num_gpu_blocks, num_cpu_blocks - def init_cache_engine(self, cache_config: CacheConfig): + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: """Initialize the cache engine of the scorer and proposer workers. """ - self.scorer_worker.init_cache_engine(cache_config) - self.proposer_worker.init_cache_engine(cache_config) + self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) @torch.inference_mode() def execute_model( @@ -351,6 +347,16 @@ def rank(self): def device(self): return self.scorer_worker.device + def get_cache_block_size_bytes(self): + """Return the size of a cache block in bytes. + + This function is only used to compose workers within a SpecDecodeWorker. + We leave composing a SpecDecodeWorker within a SpecDecodeWorker + undefined for now, although it could be implemented in the future. + See https://arxiv.org/abs/2308.04623. 
+ """ + raise NotImplementedError + def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, proposer_cache_block_size_bytes: int, diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 27d1727cd16a3..c34ee0648626b 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -82,8 +82,7 @@ def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: @staticmethod def get_cache_block_size( - block_size: int, - cache_dtype: str, + cache_config: CacheConfig, model_config: ModelConfig, parallel_config: ParallelConfig, ) -> int: @@ -91,13 +90,13 @@ def get_cache_block_size( num_heads = model_config.get_num_kv_heads(parallel_config) num_layers = model_config.get_num_layers(parallel_config) - key_cache_block = block_size * num_heads * head_size + key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block total = num_layers * (key_cache_block + value_cache_block) - if cache_dtype == "auto": + if cache_config.cache_dtype == "auto": dtype = model_config.dtype else: - dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] dtype_size = _get_dtype_size(dtype) return dtype_size * total diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index e1daa64346a9c..42f0828b826e2 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -17,6 +17,7 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.model_runner import ModelRunner +from vllm.worker.worker_base import LoraNotSupportedWorkerBase logger = init_logger(__name__) @@ -112,7 +113,7 @@ def get_cache_block_size( return dtype_size * total -class CPUWorker: +class CPUWorker(LoraNotSupportedWorkerBase): """A worker class that executes (a partition of) the model on a CPU socket. Each worker is associated with a single CPU socket. The worker is @@ -127,6 +128,7 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, local_rank: int, rank: int, distributed_init_method: str, @@ -138,6 +140,7 @@ def __init__( self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + self.cache_config = cache_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -154,8 +157,7 @@ def __init__( kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by - # self.init_cache_engine(). - self.cache_config = None + # initialize_cache. self.cache_engine = None self.cpu_cache = None @@ -167,28 +169,70 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() - def get_cpu_cache_block_num( - self, - block_size: int, - cache_space: int, - cache_dtype: str, - ) -> int: - """ - Args: - block_size: The size of the cache block. - cache_space: The size of the CPU KV cache space in bytes. + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of blocks available for the KV cache. + + This determines how many KV blocks can fit into the configured CPU + KV cache space. + + Note that since vLLM assumes a block resides on GPU if it can be + modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0. + This allows us to reuse the scheduler of vLLM without generalizing it + to different devices. 
""" # For CPU device, the block number will be calculated based on the # cpu_kvcache_space. - cache_block_size = CPUCacheEngine.get_cache_block_size( - block_size, cache_dtype, self.model_config, self.parallel_config) - num_cpu_blocks = int(cache_space // cache_block_size) + cache_block_size = self.get_cache_block_size_bytes() + num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes // + cache_block_size) num_cpu_blocks = max(num_cpu_blocks, 0) - return num_cpu_blocks + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_gpu_blocks = num_cpu_blocks + num_cpu_blocks = 0 + return num_gpu_blocks, num_cpu_blocks - def init_cache_engine(self, cache_config: CacheConfig) -> None: - self.cache_config = cache_config + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. Currently, swappable CPU memory is not + supported. + + Since this worker does not support GPUs, we use the num_gpu_blocks to + determine how many non-swappable CPU blocks to allocate. + """ + assert (num_cpu_blocks == 0 + ), f"{type(self)} does not support swappable cache" + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_cpu_blocks = num_gpu_blocks + + self._validate_num_cpu_blocks(num_cpu_blocks) + self.cache_config.num_gpu_blocks = num_cpu_blocks + self.cache_config.num_cpu_blocks = 0 + + # Initialize the cache. + self._init_cache_engine() + + def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: + """Raise errors if the num_cpu_blocks is invalid. + """ + if num_cpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `VLLM_CPU_KVCACHE_SPACE` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_cpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when " + "initializing the engine.") + + def _init_cache_engine(self) -> None: self.cache_engine = CPUCacheEngine(self.cache_config, self.model_config, self.parallel_config, @@ -264,3 +308,10 @@ def init_distributed_environment(self) -> None: ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + + def get_cache_block_size_bytes(self) -> int: + """Return the size in bytes of a single KV cache block. 
+ """ + return CPUCacheEngine.get_cache_block_size( + self.cache_config.block_size, self.cache_config.cache_dtype, + self.model_config, self.parallel_config) diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 0ae067aafb29b..6136d50d0c068 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -4,14 +4,15 @@ import torch import torch.distributed -from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.model_executor import set_random_seed from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.neuron_model_runner import NeuronModelRunner +from vllm.worker.worker_base import LoraNotSupportedWorkerBase -class NeuronWorker: +class NeuronWorker(LoraNotSupportedWorkerBase): """A worker class that executes the model on a group of neuron cores. """ @@ -21,11 +22,13 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, ) -> None: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + self.cache_config = cache_config self.model_runner = NeuronModelRunner(model_config, parallel_config, scheduler_config, device_config) @@ -37,6 +40,35 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks. + + Swapping is not yet supported, so always return num_cpu_blocks=0. + + We configure num_gpu_blocks to be equal to max_num_seqs. + """ + # Set the number of GPU blocks to be the same as the maximum number of + # sequences that can be processed in a single batch. This is equivalent + # to schedule without PagedAttention. + num_gpu_blocks = self.scheduler_config.max_num_seqs + + # Swap not yet supported with Neuron backend. + num_cpu_blocks = 0 + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. + """ + + # Different values are not tested. + assert num_cpu_blocks == 0 + assert num_gpu_blocks == self.scheduler_config.max_num_seqs + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + @torch.inference_mode() def execute_model( self, @@ -50,3 +82,10 @@ def execute_model( output = self.model_runner.execute_model(seq_group_metadata_list) return output + + def get_cache_block_size_bytes(self) -> int: + """Determine the size in bytes of a cache block. + + This is required for speculative decoding; it is not yet implemented. + """ + raise NotImplementedError diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index bf0c6073ea9a9..19de33089b2db 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -19,9 +19,10 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner +from vllm.worker.worker_base import WorkerBase -class Worker: +class Worker(WorkerBase): """A worker class that executes (a partition of) the model on a GPU. Each worker is associated with a single GPU. 
The worker is responsible for @@ -35,18 +36,19 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, local_rank: int, rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, - kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, ) -> None: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + self.cache_config = cache_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -66,12 +68,11 @@ def __init__( scheduler_config, device_config, lora_config=self.lora_config, - kv_cache_dtype=kv_cache_dtype, + kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, vision_language_config=vision_language_config) # Uninitialized cache engine. Will be initialized by - # self.init_cache_engine(). - self.cache_config = None + # initialize_cache. self.cache_engine = None self.gpu_cache = None @@ -107,20 +108,17 @@ def load_model(self): self.model_runner.load_model() @torch.inference_mode() - def profile_num_available_blocks( - self, - block_size: int, - gpu_memory_utilization: float, - cpu_swap_space: int, - cache_dtype: str, - ) -> Tuple[int, int]: - """Profiles the peak memory usage of the model and returns the maximum - number of GPU and CPU cache blocks that can be allocated. - - Args: - block_size: The size of the cache block. - gpu_memory_utilization: The fraction of the total GPU memory to use. - cpu_swap_space: The size of the CPU swap space in bytes. + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. """ # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. @@ -141,12 +139,12 @@ def profile_num_available_blocks( "Error in memory profiling. This happens when the GPU memory was " "not properly cleaned up before initializing the vLLM instance.") - cache_block_size = self.get_cache_block_size_bytes( - block_size, cache_dtype) + cache_block_size = self.get_cache_block_size_bytes() num_gpu_blocks = int( - (total_gpu_memory * gpu_memory_utilization - peak_memory) // - cache_block_size) - num_cpu_blocks = int(cpu_swap_space // cache_block_size) + (total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) if self.model_runner.lora_manager: @@ -155,14 +153,30 @@ def profile_num_available_blocks( torch.cuda.empty_cache() return num_gpu_blocks, num_cpu_blocks - def init_cache_engine(self, cache_config: CacheConfig) -> None: - self.cache_config = cache_config + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Allocate GPU and CPU KV cache with the specified number of blocks. 
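For intuition, the sizing computed by determine_num_available_blocks above works out roughly as follows; every number below is illustrative (a 7B-class fp16 model on a 24 GiB GPU is assumed):

GiB = 1 << 30
total_gpu_memory = 24 * GiB
gpu_memory_utilization = 0.90     # the default mentioned in arg_utils above
peak_memory = 15 * GiB            # weights + activations seen during profiling
swap_space_bytes = 4 * GiB

# Per-block bytes, as in CacheEngine.get_cache_block_size above:
#   num_layers * 2 (key+value) * block_size * num_kv_heads * head_size * dtype_size
cache_block_size = 32 * 2 * 16 * 32 * 128 * 2   # fp16 -> 8 MiB per block

num_gpu_blocks = int(
    (total_gpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size)
num_cpu_blocks = int(swap_space_bytes // cache_block_size)
print(num_gpu_blocks, num_cpu_blocks)   # 844 512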
+ + This also warms up the model, which may record CUDA graphs. + """ + raise_if_cache_size_invalid(num_gpu_blocks, + self.cache_config.block_size, + self.model_config.max_model_len) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._init_cache_engine() + self._warm_up_model() + + def _init_cache_engine(self): + assert self.cache_config.num_gpu_blocks is not None self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache self.model_runner.set_block_size(self.cache_engine.block_size) - def warm_up_model(self) -> None: + def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: self.model_runner.capture_model(self.gpu_cache) # Reset the seed to ensure that the random state is not affected by @@ -239,11 +253,10 @@ def max_model_len(self) -> int: def vocab_size(self) -> int: return self.model_runner.vocab_size - def get_cache_block_size_bytes(self, block_size: int, - cache_dtype: str) -> int: + def get_cache_block_size_bytes(self) -> int: """Get the size of the KV cache block size in bytes. """ - return CacheEngine.get_cache_block_size(block_size, cache_dtype, + return CacheEngine.get_cache_block_size(self.cache_config, self.model_config, self.parallel_config) @@ -300,3 +313,19 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): f"{compute_capability[0]}.{compute_capability[1]}. " "You can use float16 instead by explicitly setting the" "`dtype` flag in CLI, for example: --dtype=half.") + + +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, + max_model_len) -> None: + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + max_seq_len = block_size * num_gpu_blocks + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py new file mode 100644 index 0000000000000..e3027c406ffeb --- /dev/null +++ b/vllm/worker/worker_base.py @@ -0,0 +1,83 @@ +from abc import ABC, abstractmethod +from typing import Dict, List + +from vllm.lora.request import LoRARequest +from vllm.sequence import SamplerOutput, SequenceGroupMetadata + + +class WorkerBase(ABC): + """Worker interface that allows vLLM to cleanly separate implementations for + different hardware. + """ + + @abstractmethod + def init_device(self) -> None: + """Initialize device state, such as loading the model or other on-device + memory allocations. + """ + raise NotImplementedError + + @abstractmethod + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + The implementation may run profiling or other heuristics to determine + the size of caches. + + Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. 
+ """ + raise NotImplementedError + + @abstractmethod + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks. + """ + raise NotImplementedError + + @abstractmethod + def execute_model(self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + """Executes one model step on the given sequences.""" + raise NotImplementedError + + @abstractmethod + def get_cache_block_size_bytes() -> int: + """Return the size of a single cache block, in bytes. Used in + speculative decoding. + """ + raise NotImplementedError + + @abstractmethod + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def list_loras(self) -> List[int]: + raise NotImplementedError + + +class LoraNotSupportedWorkerBase(WorkerBase): + """Partial implementation of WorkerBase that raises exceptions when LoRA + methods are invoked. + """ + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise ValueError(f"{type(self)} does not support LoRA") + + def remove_lora(self, lora_id: int) -> bool: + raise ValueError(f"{type(self)} does not support LoRA") + + def list_loras(self) -> List[int]: + raise ValueError(f"{type(self)} does not support LoRA")