
[Misc] [Core] Implement RFC "Augment BaseExecutor interfaces to enable hardware-agnostic speculative decoding" (vllm-project#3837)
cadedaniel authored and andy-neuma committed Apr 12, 2024
1 parent f7db9ea commit 7e06ab2
Showing 20 changed files with 453 additions and 277 deletions.
6 changes: 3 additions & 3 deletions tests/core/block/e2e/test_correctness.py
@@ -16,7 +16,7 @@
# Allow only 5 sequences of ~1024 tokens in worst case.
"block_size": 16,
"forced_num_gpu_blocks": 5 * (64 + 1),
"num_gpu_blocks_override": 5 * (64 + 1),
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
@@ -162,14 +162,14 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
# Allow only 2 sequences of ~128 tokens in worst case.
# Note 8 = 128/block_size
"forced_num_gpu_blocks": 2 * (8 + 1),
"num_gpu_blocks_override": 2 * (8 + 1),
},
{
"block_size": 8,
# Allow only 2 sequences of ~128 tokens in worst case.
# Note 16 = 128/block_size
"forced_num_gpu_blocks": 2 * (16 + 1),
"num_gpu_blocks_override": 2 * (16 + 1),
}
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
8 changes: 6 additions & 2 deletions tests/lora/test_worker.py
@@ -3,8 +3,8 @@
import tempfile
from unittest.mock import patch

from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.worker.worker import Worker
@@ -27,6 +27,10 @@ def test_worker_apply_lora(sql_lora_files):
parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig(32, 32, 32),
device_config=DeviceConfig("cuda"),
cache_config=CacheConfig(block_size=16,
gpu_memory_utilization=1.,
swap_space=0,
cache_dtype="auto"),
local_rank=0,
rank=0,
lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
35 changes: 13 additions & 22 deletions tests/spec_decode/test_spec_decode_worker.py
@@ -512,8 +512,8 @@ def test_init_device():


@torch.inference_mode()
def test_init_cache_engine():
"""Verify SpecDecodeWorker invokes init_cache_engine on proposer/scorer
def test_initialize_cache():
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
workers.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
@@ -525,23 +525,22 @@ def test_init_cache_engine():
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)

cache_config = MagicMock()
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
worker.initialize_cache(**kwargs)

worker.init_cache_engine(cache_config)

draft_worker.init_cache_engine.assert_called_once_with(cache_config)
target_worker.init_cache_engine.assert_called_once_with(cache_config)
draft_worker.initialize_cache.assert_called_once_with(**kwargs)
target_worker.initialize_cache.assert_called_once_with(**kwargs)


@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.skip_global_cleanup
def test_profile_num_available_blocks(available_gpu_blocks: int,
available_cpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int):
def test_determine_num_available_blocks(available_gpu_blocks: int,
available_cpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int):
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
Specifically, it should run profiling in the scorer worker, and then evenly
split the blocks between proposer and scorer worker.
@@ -552,7 +551,7 @@ def test_profile_num_available_blocks(available_gpu_blocks: int,
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)

target_worker.profile_num_available_blocks.return_value = (
target_worker.determine_num_available_blocks.return_value = (
available_gpu_blocks, available_cpu_blocks)
target_worker.get_cache_block_size_bytes.return_value = (
target_cache_block_size_bytes)
@@ -561,17 +560,9 @@ def test_profile_num_available_blocks(available_gpu_blocks: int,
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)

# These values do not directly impact the adjusted block size calculation,
# so they can be fixed.
gpu_memory_utilization = 0.9
cpu_swap_space = 100
block_size = 16

num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype="auto")
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()

target_worker.profile_num_available_blocks.assert_called_once_with(
block_size, gpu_memory_utilization, cpu_swap_space, "auto")
target_worker.determine_num_available_blocks.assert_called_once()
assert num_cpu_blocks == available_cpu_blocks

assert num_gpu_blocks == split_num_cache_blocks_evenly(
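For context, the final assertion uses the split_num_cache_blocks_evenly helper imported by this test. Its implementation is not part of this diff; the sketch below is only an assumed reading, consistent with the docstring above (profiling happens in the scorer worker, and the block count is then shrunk so the scorer and proposer caches together fit in the profiled memory):

def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    # Assumed behavior: scale the profiled block count down so that allocating
    # this many blocks in BOTH the scorer and the proposer uses no more memory
    # than the scorer's profiled total. A draft KV size of 0 leaves the count
    # unchanged.
    return int(total_num_gpu_blocks * scorer_cache_block_size_bytes /
               (scorer_cache_block_size_bytes + proposer_cache_block_size_bytes))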
6 changes: 4 additions & 2 deletions tests/spec_decode/utils.py
@@ -117,6 +117,7 @@ def create_worker(cls: type,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
@@ -128,8 +129,9 @@

engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
engine_config.cache_config.num_cpu_blocks = 0
worker.init_cache_engine(engine_config.cache_config)
worker.warm_up_model()
worker.initialize_cache(
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)

return worker

10 changes: 6 additions & 4 deletions tests/worker/test_swap.py
@@ -11,8 +11,8 @@ def test_swap() -> None:
dtype="half",
load_format="dummy")
engine_config = engine_args.create_engine_config()
engine_config.cache_config.num_gpu_blocks = 100
engine_config.cache_config.num_cpu_blocks = 100
engine_config.cache_config.num_gpu_blocks = 1000
engine_config.cache_config.num_cpu_blocks = 1000

# Create the worker.
distributed_init_method = get_distributed_init_method(
@@ -22,6 +22,7 @@ def test_swap() -> None:
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
@@ -31,8 +32,9 @@
# Initialize the worker.
worker.init_device()
worker.load_model()
worker.init_cache_engine(engine_config.cache_config)
worker.warm_up_model()
worker.initialize_cache(
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)

# Randomly initialize the cache.
gpu_cache = worker.cache_engine.gpu_cache
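Taken together with the tests/spec_decode/utils.py change above, worker bring-up now separates memory profiling from cache allocation. A minimal standalone sketch of the new sequence (the model name, dtype, and is_driver_worker flag are illustrative assumptions, not prescribed by this commit):

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.worker import Worker

engine_args = EngineArgs(model="facebook/opt-125m", dtype="half",
                         load_format="dummy")
engine_config = engine_args.create_engine_config()

worker = Worker(
    model_config=engine_config.model_config,
    parallel_config=engine_config.parallel_config,
    scheduler_config=engine_config.scheduler_config,
    device_config=engine_config.device_config,
    cache_config=engine_config.cache_config,  # now required at construction time
    local_rank=0,
    rank=0,
    distributed_init_method=get_distributed_init_method(get_ip(),
                                                        get_open_port()),
    is_driver_worker=True,
)

worker.init_device()
worker.load_model()

# Profiling and allocation replace the old init_cache_engine/warm_up_model pair.
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                        num_cpu_blocks=num_cpu_blocks)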
6 changes: 3 additions & 3 deletions vllm/config.py
@@ -366,7 +366,7 @@ class CacheConfig:
vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB).
cache_dtype: Data type for kv cache storage.
forced_num_gpu_blocks: Number of GPU blocks to use. This overrides the
num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
profiled num_gpu_blocks if specified. Does nothing if None.
"""

@@ -376,14 +376,14 @@ def __init__(
gpu_memory_utilization: float,
swap_space: int,
cache_dtype: str,
forced_num_gpu_blocks: Optional[int] = None,
num_gpu_blocks_override: Optional[int] = None,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB
self.forced_num_gpu_blocks = forced_num_gpu_blocks
self.num_gpu_blocks_override = num_gpu_blocks_override
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
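For reference, the renamed keyword sits in the constructor exactly where forced_num_gpu_blocks used to, so constructing the config directly looks roughly like this (values are illustrative, not recommendations):

from vllm.config import CacheConfig

cache_config = CacheConfig(
    block_size=16,
    gpu_memory_utilization=0.9,
    swap_space=4,                    # CPU swap space per GPU, in GiB
    cache_dtype="auto",
    num_gpu_blocks_override=None,    # set to an int to bypass the profiled value
)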
6 changes: 3 additions & 3 deletions vllm/engine/arg_utils.py
@@ -59,7 +59,7 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None
device: str = 'auto'
ray_workers_use_nsight: bool = False
forced_num_gpu_blocks: Optional[int] = None
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0

# Related to Vision-language models such as llava
@@ -250,7 +250,7 @@
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
parser.add_argument(
'--forced-num-gpu-blocks',
'--num-gpu-blocks-override',
type=int,
default=None,
help='If specified, ignore GPU profiling result and use this number'
@@ -454,7 +454,7 @@ def create_engine_config(self, ) -> EngineConfig:
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
self.forced_num_gpu_blocks,
self.num_gpu_blocks_override,
model_config.get_sliding_window(),
self.enable_prefix_caching)
parallel_config = ParallelConfig(
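End to end, the override can be supplied either through EngineArgs/LLM or with the new --num-gpu-blocks-override flag added above. A hedged Python example (the model name and block count are placeholders, and LLM is assumed to forward extra keyword arguments to EngineArgs):

from vllm import LLM

# Replaces the old forced_num_gpu_blocks keyword; a fixed block count skips
# the profiled num_gpu_blocks.
llm = LLM(model="facebook/opt-125m", num_gpu_blocks_override=128)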
22 changes: 22 additions & 0 deletions vllm/engine/llm_engine.py
@@ -129,6 +129,8 @@ def __init__(
speculative_config=speculative_config,
)

self._initialize_kv_caches()

# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (
@@ -180,6 +182,26 @@ def __init__(
labels=dict(model_name=model_config.model))
self.stat_logger.info("cache_config", self.cache_config)

def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
The workers will determine the number of blocks in both the GPU cache
and the swap CPU cache.
"""
num_gpu_blocks, num_cpu_blocks = (
self.model_executor.determine_num_available_blocks())

if self.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
logger.info(f"Overriding {num_gpu_blocks=} with "
f"{num_gpu_blocks_override=}")
num_gpu_blocks = num_gpu_blocks_override

self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

@classmethod
def from_engine_args(
cls,
58 changes: 22 additions & 36 deletions vllm/executor/cpu_executor.py
@@ -35,7 +35,6 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,

# Instantiate the worker and load the model to CPU.
self._init_worker()
self._init_cache()

def _init_worker(self):
from vllm.worker.cpu_worker import CPUWorker
@@ -46,10 +45,11 @@ def _init_worker(self):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = CPUWorker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
@@ -60,35 +60,21 @@ def _init_worker(self):
self.driver_worker.init_device()
self.driver_worker.load_model()

def _init_cache(self) -> None:
num_cpu_blocks = self.driver_worker.get_cpu_cache_block_num(
block_size=self.cache_config.block_size,
cache_space=self.cache_config.cpu_kvcache_space_bytes,
cache_dtype=self.cache_config.cache_dtype,
)

def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info(f"# CPU blocks: {num_cpu_blocks}")
if num_cpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
"initializing the engine.")

max_seq_len = self.cache_config.block_size * num_cpu_blocks
if self.model_config.max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({self.model_config.max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
"initializing the engine.")

# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
self.cache_config.num_gpu_blocks = num_cpu_blocks # type: ignore
self.cache_config.num_cpu_blocks = 0 # type: ignore

# Initialize the cache.
self.driver_worker.init_cache_engine(cache_config=self.cache_config)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
Expand All @@ -104,13 +90,13 @@ def execute_model(self,
return output

def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError("LoRA is not implemented for cpu backend.")
return self.driver_worker.add_lora(lora_request)

def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError("LoRA is not implemented for cpu backend.")
return self.driver_worker.remove_lora(lora_id)

def list_loras(self) -> List[int]:
raise NotImplementedError("LoRA is not implemented for cpu backend.")
return self.driver_worker.list_loras()

def check_health(self) -> None:
# CPUExecutor will always be healthy as long as
23 changes: 23 additions & 0 deletions vllm/executor/executor_base.py
@@ -30,6 +30,29 @@
) -> None:
raise NotImplementedError

@abstractmethod
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
Normally, this should simply delegate to the underlying Worker. Some
ExecutorBase may require modification of the result, e.g. to ensure the
selected cache sizes are compatible with all workers.
Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError

@abstractmethod
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks.
"""
raise NotImplementedError

@abstractmethod
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
Expand Down
(Diffs for the remaining changed files are not shown.)
