diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index 1184a23224568..bf243a044769f 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -58,11 +58,10 @@ Next, you need to rewrite the :code:`forward` methods of your model by following + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, - + cache_events: Optional[List[torch.cuda.Event]], - +) -> SamplerOutput: + +) -> Optional[SamplerOutput]: -3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors. -4. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture. +1. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors. +2. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture. .. note:: Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings. diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 81bc19580274c..fd537f9cd4611 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -3,8 +3,6 @@ typing-extensions>=4.8.0 starlette psutil ray >= 2.5.1 -pandas # Required for Ray data. -pyarrow # Required for Ray data. sentencepiece # Required for LLaMA tokenizer. numpy tokenizers>=0.15.0 diff --git a/requirements.txt b/requirements.txt index 92ba0a716c45c..cee7f190db317 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,6 @@ ninja # For faster builds. psutil ray >= 2.5.1 -pandas # Required for Ray data. -pyarrow # Required for Ray data. sentencepiece # Required for LLaMA tokenizer. 
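Note on the adding_model.rst change above: a hypothetical skeleton (class name and the KVCache/InputMetadata/SamplerOutput types are assumed, not part of this diff) of the interface the doc now describes — `input_ids` and `positions` arrive as flattened 1-D tensors, `cache_events` is gone, and the return type is Optional because only the driver worker produces a SamplerOutput.

```python
from typing import List, Optional
import torch

# Hypothetical skeleton of the signature described in adding_model.rst above.
class MyModelForCausalLM(torch.nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,        # flattened, no [batch, seq_len] shape
        positions: torch.Tensor,        # flattened positions
        kv_caches: List["KVCache"],
        input_metadata: "InputMetadata",
    ) -> Optional["SamplerOutput"]:
        ...
```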
numpy torch == 2.1.2 diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index f1891f78491c2..0b45e10dc5550 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -8,11 +8,11 @@ import requests -def _query_server(prompt: str) -> dict: +def _query_server(prompt: str, max_tokens: int = 5) -> dict: response = requests.post("http://localhost:8000/generate", json={ "prompt": prompt, - "max_tokens": 100, + "max_tokens": max_tokens, "temperature": 0, "ignore_eos": True }) @@ -20,6 +20,10 @@ def _query_server(prompt: str) -> dict: return response.json() +def _query_server_long(prompt: str) -> dict: + return _query_server(prompt, max_tokens=500) + + @pytest.fixture def api_server(): script_path = Path(__file__).parent.joinpath( @@ -68,10 +72,11 @@ def test_api_server(api_server): for result in pool.map(_query_server, prompts): assert result + with Pool(32) as pool: # Cancel requests prompts = ["canceled requests"] * 100 - pool.map_async(_query_server, prompts) - time.sleep(0.001) + pool.map_async(_query_server_long, prompts) + time.sleep(0.01) pool.terminate() pool.join() diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 1d8d41e013b03..3749592a0ec71 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -49,12 +49,13 @@ def test_copy_blocks( src_blocks = random.sample(range(num_blocks), num_mappings) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) - block_mapping = {} + copy_src = [] + copy_dst = [] for i in range(num_mappings): - src = src_blocks[i] - dst1 = dst_blocks[2 * i] - dst2 = dst_blocks[2 * i + 1] - block_mapping[src] = [dst1, dst2] + copy_src.append(src_blocks[i]) + copy_dst.append(dst_blocks[2 * i]) + copy_src.append(src_blocks[i]) + copy_dst.append(dst_blocks[2 * i + 1]) # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, @@ -66,15 +67,14 @@ def test_copy_blocks( cloned_value_caches = [value_cache.clone() for value_cache in value_caches] # Call the copy blocks kernel. - cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + cache_ops.copy_blocks(key_caches, value_caches, copy_src, copy_dst) # Run the reference implementation. - for src, dsts in block_mapping.items(): - for dst in dsts: - for cloned_key_cache in cloned_key_caches: - cloned_key_cache[dst].copy_(cloned_key_cache[src]) - for cloned_value_cache in cloned_value_caches: - cloned_value_cache[dst].copy_(cloned_value_cache[src]) + for src, dst in zip(copy_src, copy_dst): + for cloned_key_cache in cloned_key_caches: + cloned_key_cache[dst].copy_(cloned_key_cache[src]) + for cloned_value_cache in cloned_value_caches: + cloned_value_cache[dst].copy_(cloned_value_cache[src]) # Compare the results. 
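On the test_copy_blocks change above: the kernel now takes two parallel lists (sources and destinations) instead of a `{src: [dst, ...]}` dict, so a one-to-many mapping gets flattened before the call. A small sketch of that conversion (the helper name is illustrative, not part of this diff):

```python
from typing import Dict, List, Tuple

def flatten_block_mapping(
        block_mapping: Dict[int, List[int]]) -> Tuple[List[int], List[int]]:
    # Turn {src: [dst1, dst2, ...]} into two aligned lists, one entry per copy.
    copy_src: List[int] = []
    copy_dst: List[int] = []
    for src, dsts in block_mapping.items():
        for dst in dsts:
            copy_src.append(src)
            copy_dst.append(dst)
    return copy_src, copy_dst

# Example: {3: [7, 9]} -> ([3, 3], [7, 9]); the reference copy in the updated
# test then simply iterates zip(copy_src, copy_dst).
```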
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 949a7e2292a4c..250d84caf56d4 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -33,8 +33,9 @@ def test_prepare_prompt(): expected_selected_token_indices.append(selected_token_start_idx + prompt_len - 1) selected_token_start_idx += max_seq_len - input_tokens, input_positions, _ = model_runner._prepare_prompt( - seq_group_metadata_list) + input_tokens, input_positions, _, return_prompt_lens = ( + model_runner._prepare_prompt(seq_group_metadata_list)) + assert return_prompt_lens == prompt_lens sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens) assert input_tokens.shape == (batch_size, max_seq_len) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 611da51f61931..fbe4a4e5d4599 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -185,14 +185,21 @@ async def step_async(self) -> List[RequestOutput]: """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - # Execute the model. - output = (await self._run_workers_async( - "execute_model", - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, - blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, - blocks_to_copy=scheduler_outputs.blocks_to_copy, - )) if not scheduler_outputs.is_empty() else [] + if not scheduler_outputs.is_empty(): + # Execute the model. + all_outputs = await self._run_workers_async( + "execute_model", + driver_kwargs={ + "seq_group_metadata_list": seq_group_metadata_list, + "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, + "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, + "blocks_to_copy": scheduler_outputs.blocks_to_copy, + }) + + # Only the driver worker returns the sampling results. + output = all_outputs[0] + else: + output = [] return self._process_model_outputs(output, scheduler_outputs) @@ -200,30 +207,29 @@ async def _run_workers_async( self, method: str, *args, - get_all_outputs: bool = False, + driver_args: Optional[List[Any]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> Any: """Runs the given method on all workers.""" coros = [] - for worker in self.workers: - if self.parallel_config.worker_use_ray: - coros.append( - worker.execute_method.remote(method, *args, **kwargs)) - else: - executor = getattr(worker, method) - coros.append(asyncio.get_event_loop().run_in_executor( - None, partial(executor, *args, **kwargs))) - all_outputs = await asyncio.gather(*coros) + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs - if get_all_outputs: - return all_outputs + # Run the driver worker asynchronously. + driver_executor = getattr(self.driver_worker, method) + coros.append(asyncio.get_event_loop().run_in_executor( + None, partial(driver_executor, *driver_args, **driver_kwargs))) - # Make sure all workers have the same results. - output = all_outputs[0] - for other_output in all_outputs[1:]: - assert output == other_output - return output + # Run the ray workers asynchronously. 
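A condensed sketch (helper name assumed) of the dispatch pattern the rewritten `step_async`/`_run_workers_async` implements in the hunk around this point (it continues just below): the driver worker's method runs in an executor thread of the current event loop, the Ray workers are invoked via `.remote()`, and everything is awaited together. Index 0 of the result list is always the driver's output, which is the only one carrying sampling results.

```python
import asyncio
from functools import partial

async def run_driver_and_workers(driver_worker, ray_workers, method,
                                 *args, **kwargs):
    loop = asyncio.get_event_loop()
    # Driver runs in-process (executor thread), Ray workers run remotely.
    coros = [loop.run_in_executor(
        None, partial(getattr(driver_worker, method), *args, **kwargs))]
    coros += [w.execute_method.remote(method, *args, **kwargs)
              for w in ray_workers]
    # Ray ObjectRefs are awaitable, so a single gather collects both kinds.
    return await asyncio.gather(*coros)
```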
+ for worker in self.workers: + coros.append(worker.execute_method.remote(method, *args, **kwargs)) + + all_outputs = await asyncio.gather(*coros) + return all_outputs class AsyncLLMEngine: @@ -488,13 +494,12 @@ def from_engine_args(cls, engine_configs = engine_args.create_engine_configs() parallel_config = engine_configs[2] # Initialize the cluster. - distributed_init_method, placement_group = initialize_cluster( - parallel_config, engine_args.engine_use_ray) + placement_group = initialize_cluster(parallel_config, + engine_args.engine_use_ray) # Create the async LLM engine. engine = cls(parallel_config.worker_use_ray, engine_args.engine_use_ray, *engine_configs, - distributed_init_method, placement_group, log_requests=not engine_args.disable_log_requests, log_stats=not engine_args.disable_log_stats, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 43bf9747ee184..0e36a50a57094 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,8 +1,9 @@ import copy +from collections import defaultdict import os import time -from functools import partial -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union +from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, + Union) from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -17,10 +18,9 @@ SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, get_tokenizer) -from vllm.utils import Counter +from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port if ray: - from ray.air.util.torch_dist import init_torch_dist_process_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy if TYPE_CHECKING: @@ -53,8 +53,6 @@ class LLMEngine: management. parallel_config: The configuration related to distributed execution. scheduler_config: The configuration related to the request scheduler. - distributed_init_method: The initialization method for distributed - execution. See `torch.distributed.init_process_group` for details. placement_group: Ray placement group for distributed execution. Required for distributed execution. log_stats: Whether to log statistics. @@ -66,7 +64,6 @@ def __init__( cache_config: CacheConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, - distributed_init_method: str, placement_group: Optional["PlacementGroup"], log_stats: bool, ) -> None: @@ -111,7 +108,7 @@ def __init__( os.environ["RAY_USAGE_STATS_ENABLED"] = "0" self._init_workers_ray(placement_group) else: - self._init_workers(distributed_init_method) + self._init_workers() # Profile the memory usage and initialize the cache. 
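`_init_workers` above now assembles its own `tcp://{get_ip()}:{get_open_port()}` endpoint instead of receiving `distributed_init_method` from `initialize_cluster`. As context, a minimal sketch of how such an endpoint is typically consumed by `torch.distributed`; the actual call site (the worker's distributed setup) is not shown in this diff, so treat this as an illustration rather than the project's code:

```python
import torch.distributed as dist

def init_distributed(init_method: str, world_size: int, rank: int) -> None:
    # init_method looks like "tcp://10.0.0.5:29500"; every rank must pass the
    # same endpoint so the NCCL process group can rendezvous.
    dist.init_process_group(
        backend="nccl",
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )
```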
self._init_cache() @@ -126,7 +123,7 @@ def __init__( # List of (timestamp, num_tokens) self.num_generation_tokens: List[Tuple[float, int]] = [] - def _init_workers(self, distributed_init_method: str): + def _init_workers(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker from vllm.worker.worker import Worker @@ -135,70 +132,122 @@ def _init_workers(self, distributed_init_method: str): "Ray is required if parallel_config.world_size > 1.") self.workers: List[Worker] = [] - worker = Worker( + distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}" + self.driver_worker = Worker( self.model_config, self.parallel_config, self.scheduler_config, - 0, - distributed_init_method, - ) - self.workers.append(worker) - self._run_workers( - "init_model", - get_all_outputs=True, - ) - self._run_workers( - "load_model", - get_all_outputs=True, - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + is_driver_worker=True, ) + self._run_workers("init_model") + self._run_workers("load_model") def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker + if self.parallel_config.tensor_parallel_size == 1: + num_gpus = self.cache_config.gpu_memory_utilization + else: + num_gpus = 1 - self.workers: List[Worker] = [] - for bundle in placement_group.bundle_specs: + self.driver_dummy_worker: RayWorkerVllm = None + self.workers: List[RayWorkerVllm] = [] + + driver_ip = get_ip() + for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get("GPU", 0): continue - if self.parallel_config.tensor_parallel_size == 1: - num_gpus = self.cache_config.gpu_memory_utilization - else: - num_gpus = 1 + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) worker = ray.remote( num_cpus=0, num_gpus=num_gpus, - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=placement_group, - placement_group_capture_child_tasks=True), + scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, )(RayWorkerVllm).remote(self.model_config.trust_remote_code) - self.workers.append(worker) + + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + else: + self.workers.append(worker) + + if self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node. 
Consider " + "adjusting the Ray placement group or running the driver on a " + "GPU node.") + + driver_node_id, driver_gpu_ids = ray.get( + self.driver_dummy_worker.get_node_and_gpu_ids.remote()) + worker_node_and_gpu_ids = ray.get( + [worker.get_node_and_gpu_ids.remote() for worker in self.workers]) + + node_workers = defaultdict(list) + node_gpus = defaultdict(list) + + node_workers[driver_node_id].append(0) + node_gpus[driver_node_id].extend(driver_gpu_ids) + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, + start=1): + node_workers[node_id].append(i) + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + # Set CUDA_VISIBLE_DEVICES for the driver. + set_cuda_visible_devices(node_gpus[driver_node_id]) + for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): + worker.set_cuda_visible_devices.remote(node_gpus[node_id]) + + distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}" + + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from vllm.worker.worker import Worker # Initialize torch distributed process group for the workers. - init_torch_dist_process_group(self.workers, backend="nccl") model_config = copy.deepcopy(self.model_config) parallel_config = copy.deepcopy(self.parallel_config) scheduler_config = copy.deepcopy(self.scheduler_config) - self._run_workers("init_worker", - get_all_outputs=True, - worker_init_fn=lambda: Worker( - model_config, - parallel_config, - scheduler_config, - None, - None, - )) - self._run_workers( - "init_model", - get_all_outputs=True, + + for rank, (worker, (node_id, + _)) in enumerate(zip(self.workers, + worker_node_and_gpu_ids), + start=1): + local_rank = node_workers[node_id].index(rank) + worker.init_worker.remote( + lambda rank=rank, local_rank=local_rank: Worker( + model_config, + parallel_config, + scheduler_config, + local_rank, + rank, + distributed_init_method, + )) + + driver_rank = 0 + driver_local_rank = node_workers[driver_node_id].index(driver_rank) + self.driver_worker = Worker( + model_config, + parallel_config, + scheduler_config, + driver_local_rank, + driver_rank, + distributed_init_method, + is_driver_worker=True, ) + + self._run_workers("init_model") self._run_workers( "load_model", - get_all_outputs=True, max_concurrent_workers=self.parallel_config. max_parallel_loading_workers, ) @@ -212,7 +261,6 @@ def _init_cache(self) -> None: # Get the maximum number of blocks that can be allocated on GPU and CPU. num_blocks = self._run_workers( "profile_num_available_blocks", - get_all_outputs=True, block_size=self.cache_config.block_size, gpu_memory_utilization=self.cache_config.gpu_memory_utilization, cpu_swap_space=self.cache_config.swap_space_bytes, @@ -256,11 +304,9 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": engine_configs = engine_args.create_engine_configs() parallel_config = engine_configs[2] # Initialize the cluster. - distributed_init_method, placement_group = initialize_cluster( - parallel_config) + placement_group = initialize_cluster(parallel_config) # Create the LLM engine. engine = cls(*engine_configs, - distributed_init_method, placement_group, log_stats=not engine_args.disable_log_stats) return engine @@ -577,14 +623,21 @@ def step(self) -> List[RequestOutput]: """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - # Execute the model. 
- output = self._run_workers( - "execute_model", - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, - blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, - blocks_to_copy=scheduler_outputs.blocks_to_copy, - ) if not scheduler_outputs.is_empty() else [] + if not scheduler_outputs.is_empty(): + # Execute the model. + all_outputs = self._run_workers( + "execute_model", + driver_kwargs={ + "seq_group_metadata_list": seq_group_metadata_list, + "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, + "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, + "blocks_to_copy": scheduler_outputs.blocks_to_copy, + }) + + # Only the driver worker returns the sampling results. + output = all_outputs[0] + else: + output = [] return self._process_model_outputs(output, scheduler_outputs) @@ -712,53 +765,38 @@ def _check_stop(self, seq: Sequence, seq.status = SequenceStatus.FINISHED_STOPPED return - def _run_workers_in_batch( - self, - workers, - method: str, - *args, - **kwargs, - ): - all_outputs = [] - for worker in workers: - if self.parallel_config.worker_use_ray: - executor = partial(worker.execute_method.remote, method) - else: - executor = getattr(worker, method) - - output = executor(*args, **kwargs) - all_outputs.append(output) - if self.parallel_config.worker_use_ray: - all_outputs = ray.get(all_outputs) - return all_outputs - def _run_workers( self, method: str, *args, - get_all_outputs: bool = False, + driver_args: Optional[List[Any]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: """Runs the given method on all workers.""" - all_outputs = [] + if max_concurrent_workers: - work_groups = [ - self.workers[i:i + max_concurrent_workers] - for i in range(0, len(self.workers), max_concurrent_workers) - ] - else: - work_groups = [self.workers] + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *args, **kwargs) + for worker in self.workers + ] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs - for workers in work_groups: - all_outputs.extend( - self._run_workers_in_batch(workers, method, *args, **kwargs)) + # Start the driver worker after all the ray workers. + driver_worker_output = getattr(self.driver_worker, + method)(*driver_args, **driver_kwargs) - if get_all_outputs: - return all_outputs + # Get the results of the ray workers. + if self.workers: + ray_worker_outputs = ray.get(ray_worker_outputs) - # Make sure all workers have the same results. 
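A sketch (simplified signatures) of the dispatch order used by the new synchronous `_run_workers` in this hunk: Ray tasks are launched first without blocking, the driver then executes the same method in-process, and only afterwards are the Ray results collected, so the driver overlaps with the remote workers. The driver's output stays at index 0 of the returned list.

```python
import ray  # assumes Ray is installed when ray_workers is non-empty

def run_workers(driver_worker, ray_workers, method, *args, **kwargs):
    # Kick off remote work first; .remote() returns immediately.
    futures = [w.execute_method.remote(method, *args, **kwargs)
               for w in ray_workers]
    # Then do the driver's share in-process.
    driver_output = getattr(driver_worker, method)(*args, **kwargs)
    # Finally block on the remote results.
    return [driver_output] + (ray.get(futures) if futures else [])
```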
- output = all_outputs[0] - for other_output in all_outputs[1:]: - assert output == other_output - return output + return [driver_worker_output] + ray_worker_outputs diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index f402da4c621dd..52e5be0227cd9 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -1,16 +1,15 @@ -from typing import Optional, Tuple, TYPE_CHECKING +from typing import Optional, List, Tuple, TYPE_CHECKING from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.utils import get_open_port, is_hip +from vllm.utils import is_hip, set_cuda_visible_devices, get_ip logger = init_logger(__name__) try: import ray - from ray.air.util.torch_dist import TorchDistributedWorker - class RayWorkerVllm(TorchDistributedWorker): + class RayWorkerVllm: """Ray wrapper for vllm.worker.Worker, allowing Worker to be lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES.""" @@ -30,12 +29,22 @@ def execute_method(self, method, *args, **kwargs): executor = getattr(self, method) return executor(*args, **kwargs) + def get_node_ip(self) -> str: + return get_ip() + + def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: + node_id = ray.get_runtime_context().get_node_id() + gpu_ids = ray.get_gpu_ids() + return node_id, gpu_ids + + def set_cuda_visible_devices(self, device_ids) -> None: + set_cuda_visible_devices(device_ids) + except ImportError as e: logger.warning(f"Failed to import Ray with {e!r}. " "For distributed inference, please install Ray with " "`pip install ray pandas pyarrow`.") ray = None - TorchDistributedWorker = None RayWorkerVllm = None if TYPE_CHECKING: @@ -75,13 +84,11 @@ def initialize_cluster( ray.init(address=ray_address, ignore_reinit_error=True) if not parallel_config.worker_use_ray: - # Initialize cluster locally. - port = get_open_port() - # We need to setup the distributed init method to make sure - # the distributed megatron code (e.g., get world size) works correctly. - distributed_init_method = f"tcp://localhost:{port}" - return distributed_init_method, None + assert parallel_config.world_size == 1, ( + "Ray is required if parallel_config.world_size > 1.") + return None + # Create placement group for worker processes current_placement_group = ray.util.get_current_placement_group() if current_placement_group: # We are in a placement group @@ -106,12 +113,12 @@ def initialize_cluster( "The number of required GPUs exceeds the total number of " "available GPUs in the cluster.") # Create a new placement group - current_placement_group = ray.util.placement_group([{ - "GPU": 1 - }] * parallel_config.world_size) + placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size) + current_placement_group = ray.util.placement_group( + placement_group_specs) # Wait until PG is ready - this will block until all # requested resources are available, and will timeout # if they cannot be provisioned. 
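A minimal sketch of the placement-group setup mirrored by `initialize_cluster` above: one 1-GPU bundle per rank, created eagerly and waited on so that unsatisfiable resource requests surface as a timeout instead of hanging later.

```python
import ray

def create_tp_placement_group(world_size: int):
    ray.init(ignore_reinit_error=True)
    # One bundle per tensor-parallel rank, each pinning a single GPU.
    pg = ray.util.placement_group([{"GPU": 1}] * world_size)
    # Block until all bundles are provisioned (or fail after the timeout).
    ray.get(pg.ready(), timeout=1800)
    return pg
```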
ray.get(current_placement_group.ready(), timeout=1800) - return None, current_placement_group + return current_placement_group diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index af6f4921856e1..da615ecccf993 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional import torch @@ -16,28 +16,27 @@ class InputMetadata: def __init__( self, - prompt_lens: List[int], + is_prompt: bool, slot_mapping: torch.Tensor, max_context_len: Optional[int], context_lens: Optional[torch.Tensor], block_tables: Optional[torch.Tensor], use_cuda_graph: bool, ) -> None: - self.prompt_lens = prompt_lens + self.is_prompt = is_prompt self.max_context_len = max_context_len self.slot_mapping = slot_mapping self.context_lens = context_lens self.block_tables = block_tables self.use_cuda_graph = use_cuda_graph - self.is_prompt = len(prompt_lens) > 0 # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack. self.attn_bias = None def __repr__(self) -> str: return ("InputMetadata(" - f"prompt_lens={self.prompt_lens}, " + f"is_prompt={self.is_prompt}, " f"max_context_len={self.max_context_len}, " f"slot_mapping={self.slot_mapping}, " f"context_lens={self.context_lens}, " diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 25ba48c2aa6d6..ebc9afc1be672 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -5,7 +5,7 @@ import torch.nn as nn from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_gather) + tensor_model_parallel_gather) from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, @@ -37,7 +37,7 @@ def forward( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, embedding_bias: Optional[torch.Tensor] = None, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: # Get the hidden states that we use for sampling. hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) @@ -45,6 +45,14 @@ def forward( logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size) + # Only perform sampling in the driver worker. + # Note: `_get_logits` is still distributed across TP workers because + # the `embedding` weight is distributed across TP workers. + # TODO(zhuohan): Change the get_logits part to a separate stage. + if not sampling_metadata.perform_sampling: + return None + + assert logits is not None _, vocab_size = logits.shape # Apply logits processors (if any). @@ -92,14 +100,15 @@ def forward( def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, embedding_bias: Optional[torch.Tensor], - vocab_size: int) -> torch.Tensor: + vocab_size: int) -> Optional[torch.Tensor]: # Get the logits for the next tokens. logits = torch.matmul(hidden_states, embedding.t()) if embedding_bias is not None: logits += embedding_bias - logits = tensor_model_parallel_all_gather(logits) + logits = tensor_model_parallel_gather(logits) # Remove paddings in vocab (if any). 
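A sketch (assuming an already-initialized process group) of the gather-to-one-rank pattern that `tensor_model_parallel_gather` introduces in place of all_gather in `_get_logits` above: only rank `dst` materializes the concatenated vocab logits, every other rank gets `None` back, which is why `_get_logits` and the sampler now return Optionals.

```python
import torch
import torch.distributed as dist

def gather_logits(shard: torch.Tensor, dst: int = 0, dim: int = -1):
    world_size = dist.get_world_size()
    if world_size == 1:
        return shard
    # Only the destination rank allocates buffers for the gathered shards.
    gather_list = ([torch.empty_like(shard) for _ in range(world_size)]
                   if dist.get_rank() == dst else None)
    dist.gather(shard, gather_list, dst=dst)
    return torch.cat(gather_list, dim=dim) if gather_list is not None else None
```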
- logits = logits[:, :vocab_size] + if logits is not None: + logits = logits[:, :vocab_size] return logits diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index 2a1a0d76801c8..2f2bd5ffb4a63 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -298,7 +298,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head.weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index cd8ab444677ea..f08c3c8d257ff 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -313,7 +313,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head.weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 6d1aeeed78e93..4adfb6b78102f 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -290,7 +290,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head_weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index aa957b36b36d7..dca8d724f976b 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -349,7 +349,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head_weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 7055d08521c47..2b5e022312e3b 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -394,7 +394,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head.weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index d9b561cd8b225..661da0fe0434e 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -235,7 +235,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head_weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 4d8144bad351f..ef4c1d4143c88 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -254,7 +254,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head_weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 
ab3480a77a43d..5bab30d9d442e 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -240,7 +240,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head.weight, hidden_states, sampling_metadata, self.lm_head.bias) return next_tokens diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 773fed36a9fb0..8f7e1063e0c1d 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -255,7 +255,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.embed_out.weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index 00bb70fc3f87f..5d0b93793c89d 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -255,7 +255,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head.weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index b3b24ea6fea44..3791aa893893a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -291,7 +291,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head.weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 57230fcced9ff..70d033fec69fc 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -287,7 +287,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head.weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e61b401a78a2b..a8dadce24aa1d 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -320,7 +320,7 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): @@ -361,7 +361,7 @@ def sample( self, hidden_states: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head.weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index d6e9a76d2ba42..22a876e2ef691 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -276,7 +276,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head_weight, hidden_states, sampling_metadata) return next_tokens diff --git 
a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 22d3b5ccadfde..393b2dcabcd5a 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -309,7 +309,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head_weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi_1_5.py index 9f3c6f68d24e5..9d4424dd08903 100644 --- a/vllm/model_executor/models/phi_1_5.py +++ b/vllm/model_executor/models/phi_1_5.py @@ -280,7 +280,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: head = self.lm_head.linear next_tokens = self.sampler(head.weight, hidden_states, sampling_metadata, head.bias) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 2d394a6b914c5..fbc7320fb45a4 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -247,7 +247,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head.weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py index 78715a8873fce..53daa6c4cd939 100644 --- a/vllm/model_executor/models/yi.py +++ b/vllm/model_executor/models/yi.py @@ -286,7 +286,7 @@ def sample( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: + ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head.weight, hidden_states, sampling_metadata) return next_tokens diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index b1d5f5b9fb88e..8bf04f3d1f056 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -1,6 +1,7 @@ import torch from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tensor_model_parallel_group, ) @@ -45,3 +46,61 @@ def tensor_model_parallel_all_gather(input_, dim=-1): (world_size * input_size[dim], ) + input_size[dim + 1:]) return output_tensor + + +def tensor_model_parallel_gather(input_, dst=0, dim=-1): + """Gather the input tensor across model parallel group. + + NOTE: We assume that the input tensor is on the same device across + all the ranks. + """ + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if get_tensor_model_parallel_rank() == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. 
+ torch.distributed.gather(input_, + gather_list, + dst=dst, + group=get_tensor_model_parallel_group()) + if get_tensor_model_parallel_rank() == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + +def broadcast(input_, src=0): + """Broadcast the input tensor.""" + world_size = torch.distributed.get_world_size() + assert 0 <= src < world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast(input_, src=src) + return input_ + + +def broadcast_object_list(obj_list, src=0): + """Broadcast the input object list.""" + world_size = torch.distributed.get_world_size() + assert 0 <= src < world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, src=src) + return obj_list diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 49013ec273787..2d41d40e04678 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import torch @@ -18,24 +18,29 @@ class SamplingMetadata: seq_data: Seq_id -> SequenceData. prompt_lens: Lengths of prompts. selected_token_indices: Token indices selected for sampling. - categorized_sample_indices: SamplingType -> token indicies to sample. + categorized_sample_indices: SamplingType -> token indices to sample. + perform_sampling: Whether to perform sampling. This option is used to + make the sampling only happens in the driver worker, and disable + sampling in other worker processes. 
""" def __init__( self, - seq_groups: List[Tuple[List[int], SamplingParams]], - seq_data: Dict[int, SequenceData], - prompt_lens: List[int], + seq_groups: Optional[List[Tuple[List[int], SamplingParams]]], + seq_data: Optional[Dict[int, SequenceData]], + prompt_lens: Optional[List[int]], selected_token_indices: torch.Tensor, - categorized_sample_indices: Dict[SamplingType, torch.Tensor], + categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]], + perform_sampling: bool = True, ) -> None: self.seq_groups = seq_groups self.seq_data = seq_data self.prompt_lens = prompt_lens self.selected_token_indices = selected_token_indices self.categorized_sample_indices = categorized_sample_indices + self.perform_sampling = perform_sampling - self.num_prompts = len(prompt_lens) + self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0 def __repr__(self) -> str: return ( @@ -44,7 +49,8 @@ def __repr__(self) -> str: f"seq_data={self.seq_data}, " f"prompt_lens={self.prompt_lens}, " f"selected_token_indices={self.selected_token_indices}, " - f"categorized_sample_indices={self.categorized_sample_indices})") + f"categorized_sample_indices={self.categorized_sample_indices}), " + f"perform_sampling={self.perform_sampling})") @dataclass diff --git a/vllm/utils.py b/vllm/utils.py index eff5d10fd4ee0..c32047ac27dc6 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,7 +1,9 @@ import enum +import os import socket import uuid from platform import uname +from typing import List import psutil import torch @@ -55,7 +57,15 @@ def in_wsl() -> bool: return "microsoft" in " ".join(uname()).lower() -def get_open_port(): +def get_ip() -> str: + return socket.gethostbyname(socket.gethostname()) + + +def get_open_port() -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", 0)) return s.getsockname()[1] + + +def set_cuda_visible_devices(device_ids: List[int]) -> None: + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fb7a0c17d6f9f..be2803089f51b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,5 +1,5 @@ import time -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -8,6 +8,8 @@ from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig from vllm.logger import init_logger from vllm.model_executor import get_model, InputMetadata, SamplingMetadata +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast, broadcast_object_list) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import in_wsl @@ -28,10 +30,12 @@ def __init__( model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, + is_driver_worker: bool = False, ): self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + self.is_driver_worker = is_driver_worker # model_config can be None in tests/samplers/test_sampler.py. # FIXME(woosuk): This is a hack to make the tests work. Refactor this. 
@@ -70,7 +74,7 @@ def set_block_size(self, block_size: int) -> None: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] @@ -135,14 +139,14 @@ def _prepare_prompt( dtype=torch.long) input_metadata = InputMetadata( - prompt_lens=prompt_lens, + is_prompt=True, slot_mapping=slot_mapping, max_context_len=None, context_lens=None, block_tables=None, use_cuda_graph=False, ) - return input_tokens, input_positions, input_metadata + return input_tokens, input_positions, input_metadata, prompt_lens def _prepare_decode( self, @@ -203,32 +207,24 @@ def _prepare_decode( block_tables.append([]) batch_size = graph_batch_size - # When using CUDA graph, we don't need to make the tensors on the GPU - # because they will be eventually copied to the designated GPU buffer. - device = "cpu" if use_captured_graph else "cuda" - pin_memory = use_captured_graph and not self.in_wsl input_tokens = _make_tensor_with_pad(input_tokens, max_len=1, pad=0, dtype=torch.long, - device=device, - pin_memory=pin_memory) + device="cuda") input_positions = _make_tensor_with_pad(input_positions, max_len=1, pad=0, dtype=torch.long, - device=device, - pin_memory=pin_memory) + device="cuda") slot_mapping = _make_tensor_with_pad(slot_mapping, max_len=1, pad=_PAD_SLOT_ID, dtype=torch.long, - device=device, - pin_memory=pin_memory) + device="cuda") context_lens = torch.tensor(context_lens, dtype=torch.int, - device=device, - pin_memory=pin_memory) + device="cuda") if use_captured_graph: # The shape of graph_block_tables is @@ -237,17 +233,18 @@ def _prepare_decode( for i, block_table in enumerate(block_tables): if block_table: input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device=device) + block_tables = torch.tensor(input_block_tables, device="cuda") else: block_tables = _make_tensor_with_pad( block_tables, max_len=max_context_len, pad=0, dtype=torch.int, + device="cuda", ) input_metadata = InputMetadata( - prompt_lens=[], + is_prompt=False, slot_mapping=slot_mapping, max_context_len=max_context_len, context_lens=context_lens, @@ -326,23 +323,127 @@ def _prepare_sample( ) return sampling_metadata + def prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]: + if self.is_driver_worker: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, input_metadata, + prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, input_metadata + ) = self._prepare_decode(seq_group_metadata_list) + prompt_lens = [] + sampling_metadata = self._prepare_sample(seq_group_metadata_list, + prompt_lens) + + def get_size_or_none(x: Optional[torch.Tensor]): + return x.size() if x is not None else None + + # Broadcast the input data. For input tensors, we first broadcast + # its shape and then broadcast the tensor to avoid high + # serialization cost. 
+ py_data = { + "input_tokens_size": + input_tokens.size(), + "input_positions_size": + input_positions.size(), + "is_prompt": + input_metadata.is_prompt, + "slot_mapping_size": + get_size_or_none(input_metadata.slot_mapping), + "max_context_len": + input_metadata.max_context_len, + "context_lens_size": + get_size_or_none(input_metadata.context_lens), + "block_tables_size": + get_size_or_none(input_metadata.block_tables), + "use_cuda_graph": + input_metadata.use_cuda_graph, + "selected_token_indices_size": + sampling_metadata.selected_token_indices.size(), + } + broadcast_object_list([py_data], src=0) + # TODO(zhuohan): Combine the broadcasts or set async_op=True. + broadcast(input_tokens, src=0) + broadcast(input_positions, src=0) + if input_metadata.slot_mapping is not None: + broadcast(input_metadata.slot_mapping, src=0) + if input_metadata.context_lens is not None: + broadcast(input_metadata.context_lens, src=0) + if input_metadata.block_tables is not None: + broadcast(input_metadata.block_tables, src=0) + broadcast(sampling_metadata.selected_token_indices, src=0) + else: + receving_list = [None] + broadcast_object_list(receving_list, src=0) + py_data = receving_list[0] + input_tokens = torch.empty(*py_data["input_tokens_size"], + dtype=torch.long, + device="cuda") + broadcast(input_tokens, src=0) + input_positions = torch.empty(*py_data["input_positions_size"], + dtype=torch.long, + device="cuda") + broadcast(input_positions, src=0) + if py_data["slot_mapping_size"] is not None: + slot_mapping = torch.empty(*py_data["slot_mapping_size"], + dtype=torch.long, + device="cuda") + broadcast(slot_mapping, src=0) + else: + slot_mapping = None + if py_data["context_lens_size"] is not None: + context_lens = torch.empty(*py_data["context_lens_size"], + dtype=torch.int, + device="cuda") + broadcast(context_lens, src=0) + else: + context_lens = None + if py_data["block_tables_size"] is not None: + block_tables = torch.empty(*py_data["block_tables_size"], + dtype=torch.int, + device="cuda") + broadcast(block_tables, src=0) + else: + block_tables = None + selected_token_indices = torch.empty( + *py_data["selected_token_indices_size"], + dtype=torch.long, + device="cuda") + broadcast(selected_token_indices, src=0) + input_metadata = InputMetadata( + is_prompt=py_data["is_prompt"], + slot_mapping=slot_mapping, + max_context_len=py_data["max_context_len"], + context_lens=context_lens, + block_tables=block_tables, + use_cuda_graph=py_data["use_cuda_graph"], + ) + sampling_metadata = SamplingMetadata( + seq_groups=None, + seq_data=None, + prompt_lens=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + perform_sampling=False, + ) + + return input_tokens, input_positions, input_metadata, sampling_metadata + @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> SamplerOutput: - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. 
- if is_prompt: - inputs = self._prepare_prompt(seq_group_metadata_list) - input_tokens, input_positions, input_metadata = inputs - else: - inputs = self._prepare_decode(seq_group_metadata_list) - input_tokens, input_positions, input_metadata = inputs - + ) -> Optional[SamplerOutput]: + input_tokens, input_positions, input_metadata, sampling_metadata = ( + self.prepare_input_tensors(seq_group_metadata_list)) # Execute the model. if input_metadata.use_cuda_graph: graph_batch_size = input_tokens.shape[0] @@ -356,9 +457,6 @@ def execute_model( input_metadata=input_metadata, ) - sampling_metadata = self._prepare_sample(seq_group_metadata_list, - input_metadata.prompt_lens) - # Sample the next token. output = self.model.sample( hidden_states=hidden_states, @@ -424,7 +522,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE): # Create dummy input_metadata. input_metadata = InputMetadata( - prompt_lens=[], + is_prompt=False, slot_mapping=slot_mapping[:batch_size], max_context_len=self.max_context_len_to_capture, context_lens=context_lens[:batch_size], diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 8698b15721507..6c83f708bd9c6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,6 +8,8 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.model_executor import set_random_seed +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast_object_list) from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -28,17 +30,23 @@ def __init__( model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, - rank: Optional[int] = None, - distributed_init_method: Optional[str] = None, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, ) -> None: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." self.model_runner = ModelRunner(model_config, parallel_config, - scheduler_config) + scheduler_config, is_driver_worker) # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). self.cache_config = None @@ -57,13 +65,7 @@ def init_model(self) -> None: # This env var set by Ray causes exceptions with graph building. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) - # Env vars will be set by Ray. - self.rank = self.rank if self.rank is not None else int( - os.getenv("RANK", "-1")) - local_rank = int(os.getenv("LOCAL_RANK", "0")) - self.device = torch.device(f"cuda:{local_rank}") - if self.rank < 0: - raise ValueError("Invalid or unspecified rank.") + self.device = torch.device(f"cuda:{self.local_rank}") torch.cuda.set_device(self.device) _check_if_gpu_supports_dtype(self.model_config.dtype) @@ -125,14 +127,12 @@ def warm_up_model(self) -> None: # the model initialization and profiling. 
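An illustrative sketch (simplified, not the real Sampler) of why `sample()` now returns `Optional[SamplerOutput]` across all the model files touched above: non-driver ranks get a SamplingMetadata built with `perform_sampling=False` in `prepare_input_tensors`, so the sampler computes its tensor-parallel share of the logits and then returns `None` instead of token ids.

```python
from typing import Optional
import torch

def maybe_sample(logits: Optional[torch.Tensor],
                 sampling_metadata) -> Optional[list]:
    if not sampling_metadata.perform_sampling:
        return None                       # non-driver tensor-parallel ranks
    # Greedy placeholder for the real sampling logic, which is unchanged here.
    return logits.argmax(dim=-1).tolist()
```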
set_random_seed(self.model_config.seed) - @torch.inference_mode() - def execute_model( + def cache_swap( self, - seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: + ) -> None: # Issue cache operations. issued_cache_op = False if blocks_to_swap_in: @@ -152,8 +152,38 @@ def execute_model( if cache_events is not None: for event in cache_events: event.wait() + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, + blocks_to_swap_in: Optional[Dict[int, int]] = None, + blocks_to_swap_out: Optional[Dict[int, int]] = None, + blocks_to_copy: Optional[Dict[int, List[int]]] = None, + ) -> Optional[SamplerOutput]: + if self.is_driver_worker: + assert seq_group_metadata_list is not None + num_seq_groups = len(seq_group_metadata_list) + assert blocks_to_swap_in is not None + assert blocks_to_swap_out is not None + assert blocks_to_copy is not None + block_swapping_info = [ + blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy + ] + broadcast_object_list([num_seq_groups] + block_swapping_info, + src=0) + else: + # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out, + # blocks_to_copy (4 elements) + recv_data = [None] * 4 + broadcast_object_list(recv_data, src=0) + num_seq_groups = recv_data[0] + block_swapping_info = recv_data[1:] + + self.cache_swap(*block_swapping_info) + # If there is no input, we don't need to execute the model. - if not seq_group_metadata_list: + if num_seq_groups == 0: return {} output = self.model_runner.execute_model(seq_group_metadata_list,
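Finally, a sketch of the control-message handshake that the rewritten `Worker.execute_model` above performs before the `model_runner.execute_model` call where this excerpt ends: the driver broadcasts the number of scheduled sequence groups together with the cache swap/copy maps so that non-driver ranks can issue the same cache operations and skip the forward pass when there is no work. The function name and flat-argument form are illustrative; it assumes `torch.distributed` is already initialized.

```python
import torch.distributed as dist

def exchange_schedule(is_driver: bool, num_seq_groups=None,
                      blocks_to_swap_in=None, blocks_to_swap_out=None,
                      blocks_to_copy=None):
    if is_driver:
        payload = [num_seq_groups, blocks_to_swap_in,
                   blocks_to_swap_out, blocks_to_copy]
    else:
        payload = [None] * 4              # filled in-place by the broadcast
    dist.broadcast_object_list(payload, src=0)
    return tuple(payload)                 # (num_seq_groups, swap_in, swap_out, copy)
```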