From 95c869ec4e884899371c3cc98387a74f43616d3d Mon Sep 17 00:00:00 2001 From: Murali Andoorveedu <37849411+andoorve@users.noreply.github.com> Date: Tue, 2 Jul 2024 10:58:08 -0700 Subject: [PATCH] [Core] Pipeline Parallel Support (#4412) Signed-off-by: Muralidhar Andoorveedu --- .buildkite/test-pipeline.yaml | 10 + tests/async_engine/test_async_llm_engine.py | 14 +- tests/async_engine/test_openapi_server_ray.py | 4 +- tests/basic_correctness/test_preemption.py | 24 +- tests/distributed/test_comm_ops.py | 20 +- tests/distributed/test_pipeline_parallel.py | 149 ++++++++ .../output_processor/test_multi_step.py | 8 +- tests/entrypoints/openai/test_chat.py | 4 +- tests/entrypoints/openai/test_completion.py | 4 +- tests/entrypoints/openai/test_embedding.py | 4 +- tests/entrypoints/openai/test_models.py | 4 +- tests/entrypoints/openai/test_vision.py | 4 +- tests/spec_decode/utils.py | 6 +- tests/tensorizer_loader/test_tensorizer.py | 4 +- tests/utils.py | 16 +- tests/worker/test_swap.py | 4 +- vllm/config.py | 25 +- vllm/core/block_manager_v1.py | 3 + vllm/core/block_manager_v2.py | 3 + vllm/core/scheduler.py | 13 +- vllm/distributed/parallel_state.py | 50 +-- vllm/distributed/utils.py | 11 +- vllm/engine/async_llm_engine.py | 79 ++++- vllm/engine/llm_engine.py | 65 +++- vllm/engine/output_processor/interfaces.py | 2 +- vllm/engine/output_processor/multi_step.py | 5 +- vllm/engine/output_processor/single_step.py | 20 +- vllm/executor/distributed_gpu_executor.py | 12 +- vllm/executor/executor_base.py | 25 ++ vllm/executor/gpu_executor.py | 3 +- vllm/executor/multiproc_gpu_executor.py | 12 +- vllm/executor/ray_gpu_executor.py | 71 +++- vllm/model_executor/models/arctic.py | 3 +- vllm/model_executor/models/baichuan.py | 3 +- vllm/model_executor/models/bloom.py | 3 +- vllm/model_executor/models/chatglm.py | 3 +- vllm/model_executor/models/commandr.py | 3 +- vllm/model_executor/models/dbrx.py | 3 +- vllm/model_executor/models/deepseek.py | 3 +- vllm/model_executor/models/deepseek_v2.py | 3 +- vllm/model_executor/models/falcon.py | 3 +- vllm/model_executor/models/gemma.py | 3 +- vllm/model_executor/models/gemma2.py | 3 +- vllm/model_executor/models/gpt2.py | 88 +++-- vllm/model_executor/models/gpt_bigcode.py | 3 +- vllm/model_executor/models/gpt_j.py | 3 +- vllm/model_executor/models/gpt_neox.py | 3 +- vllm/model_executor/models/internlm2.py | 3 +- vllm/model_executor/models/jais.py | 3 +- vllm/model_executor/models/llama.py | 101 ++++-- vllm/model_executor/models/llava.py | 4 +- vllm/model_executor/models/llava_next.py | 4 +- vllm/model_executor/models/minicpm.py | 3 +- vllm/model_executor/models/mixtral.py | 3 +- vllm/model_executor/models/mixtral_quant.py | 3 +- vllm/model_executor/models/mpt.py | 3 +- vllm/model_executor/models/olmo.py | 3 +- vllm/model_executor/models/opt.py | 3 +- vllm/model_executor/models/orion.py | 3 +- vllm/model_executor/models/phi.py | 3 +- vllm/model_executor/models/phi3_small.py | 3 +- vllm/model_executor/models/phi3v.py | 11 +- vllm/model_executor/models/qwen.py | 3 +- vllm/model_executor/models/qwen2.py | 3 +- vllm/model_executor/models/qwen2_moe.py | 3 +- vllm/model_executor/models/stablelm.py | 3 +- vllm/model_executor/models/starcoder2.py | 3 +- vllm/model_executor/models/xverse.py | 3 +- vllm/sequence.py | 31 ++ vllm/spec_decode/draft_model_runner.py | 15 +- vllm/worker/cache_engine.py | 4 + vllm/worker/cpu_model_runner.py | 5 +- vllm/worker/cpu_worker.py | 38 +- vllm/worker/embedding_model_runner.py | 9 +- vllm/worker/model_runner.py | 326 +++++++++++------- 
vllm/worker/model_runner_base.py | 5 +- vllm/worker/neuron_model_runner.py | 5 +- vllm/worker/neuron_worker.py | 2 +- vllm/worker/worker.py | 36 +- vllm/worker/worker_base.py | 40 ++- vllm/worker/xpu_model_runner.py | 5 +- vllm/worker/xpu_worker.py | 4 +- 82 files changed, 1100 insertions(+), 404 deletions(-) create mode 100644 tests/distributed/test_pipeline_parallel.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d96e3c6d192e2..d127278aaae2d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -74,6 +74,16 @@ steps: - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py +- label: Pipeline Parallelism Test + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + commands: + - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py + - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py + - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py + - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py + + - label: Engine Test mirror_hardwares: [amd] command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 52d3394a96a13..aa2b6e22208f3 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -5,6 +5,7 @@ import torch from vllm import SamplingParams +from vllm.config import ParallelConfig from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine from ..utils import wait_for_gpu_memory_to_clear @@ -23,8 +24,11 @@ def __init__(self): self.add_request_calls = 0 self.abort_request_calls = 0 self.request_id = None + # Ugly, remove dependency when possible + self.parallel_config = ParallelConfig(1, 1, False) - async def step_async(self): + async def step_async(self, virtual_engine): + # PP size is 1, ignore virtual engine self.step_calls += 1 return [RequestOutput( request_id=self.request_id)] if self.request_id else [] @@ -32,6 +36,9 @@ async def step_async(self): async def process_model_inputs_async(self, *args, **kwargs): pass + async def stop_remote_worker_execution_loop_async(self): + pass + def generate(self, request_id): self.request_id = request_id @@ -41,6 +48,7 @@ def stop_generating(self): def add_request(self, **kwargs): del kwargs # Unused self.add_request_calls += 1 + print(f'Request calls: {self.add_request_calls}') async def add_request_async(self, **kwargs): self.add_request_calls += 1 @@ -53,6 +61,9 @@ def abort_request(self, request_id): def has_unfinished_requests(self): return self.request_id is not None + def has_unfinished_requests_for_virtual_engine(self, virtual_engine): + return self.request_id is not None + class MockAsyncLLMEngine(AsyncLLMEngine): @@ -76,6 +87,7 @@ async def test_new_requests_event(): engine.engine.generate("2") await asyncio.sleep(0) await asyncio.sleep(0) + await asyncio.sleep(0) assert engine.engine.add_request_calls == 2 assert engine.engine.step_calls >= 2 await asyncio.sleep(0.001) diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index 332937b874e93..cc05d79e56874 100644 --- a/tests/async_engine/test_openapi_server_ray.py 
+++ b/tests/async_engine/test_openapi_server_ray.py @@ -4,7 +4,7 @@ # and debugging. import ray -from ..utils import RemoteOpenAIServer +from ..utils import VLLM_PATH, RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "facebook/opt-125m" @@ -12,7 +12,7 @@ @pytest.fixture(scope="module") def ray_ctx(): - ray.init() + ray.init(runtime_env={"working_dir": VLLM_PATH}) yield ray.shutdown() diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index d60cc95d75433..7aed0d5e1fa69 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -56,8 +56,8 @@ def test_chunked_prefill_recompute( max_num_seqs=max_num_seqs, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) + assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + < ARTIFICIAL_PREEMPTION_MAX_CNT) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] @@ -91,10 +91,10 @@ def test_preemption( disable_log_stats=False, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) + assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + < ARTIFICIAL_PREEMPTION_MAX_CNT) total_preemption = ( - vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) + vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption) check_outputs_equal( outputs_0_lst=hf_outputs, @@ -147,10 +147,10 @@ def test_swap( ) as vllm_model: vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, max_tokens) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) + assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + < ARTIFICIAL_PREEMPTION_MAX_CNT) total_preemption = ( - vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) + vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption) for i in range(len(example_prompts)): hf_output_ids, _ = hf_outputs[i] @@ -214,8 +214,8 @@ def test_swap_infeasible( example_prompts, sampling_params=sampling_params, ) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) + assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + < ARTIFICIAL_PREEMPTION_MAX_CNT) # Verify the request is ignored and not hang. assert req_outputs[0].outputs[0].finish_reason == "length" @@ -252,8 +252,8 @@ def test_preemption_infeasible( sampling_params=sampling_params, ) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) + assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + < ARTIFICIAL_PREEMPTION_MAX_CNT) # Verify the request is ignored and not hang. 
for req_output in req_outputs: diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index bf0f31df02fa5..7302d484954f7 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -32,7 +32,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int, (r + 1) for r in range(tp_size) ] expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) - t = all_tensors[rank] + t = all_tensors[rank % tp_size] t = tensor_model_parallel_all_reduce(t) assert torch.allclose(t, expected) @@ -60,7 +60,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int, for r in range(tp_size) ] expected = torch.cat(all_tensors, dim=all_gather_dimension) - t = all_tensors[rank] + t = all_tensors[rank % tp_size] t = tensor_model_parallel_all_gather(t, all_gather_dimension) assert torch.allclose(t, expected) @@ -91,7 +91,7 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, "f": torch.tensor([], dtype=torch.float32, device="cuda"), } - if rank == 0: + if (rank % tp_size) == 0: broadcast_tensor_dict(test_dict, src=0) else: recv_dict = broadcast_tensor_dict(src=0) @@ -184,3 +184,17 @@ def test_multi_process_tensor_parallel(tp_size, test_target): "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]) def test_multi_process_pipeline_parallel(pp_size, test_target): multi_process_parallel(1, pp_size, test_target) + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pp_size", [2]) +@pytest.mark.parametrize("test_target", [ + send_recv_test_worker, send_recv_tensor_dict_test_worker, + all_reduce_test_worker, all_gather_test_worker, + broadcast_tensor_dict_test_worker +]) +def test_multi_process_tensor_parallel_pipeline_parallel( + tp_size, pp_size, test_target): + multi_process_parallel(tp_size, pp_size, test_target) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py new file mode 100644 index 0000000000000..6072a2dd71800 --- /dev/null +++ b/tests/distributed/test_pipeline_parallel.py @@ -0,0 +1,149 @@ +import os + +import openai # use the official client for correctness check +import pytest +# using Ray for overall ease of process management, parallel requests, +# and debugging. 
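+# The CI entries added to .buildkite/test-pipeline.yaml above drive this file
+# purely through environment variables (TP_SIZE, PP_SIZE, EAGER_MODE,
+# CHUNKED_PREFILL), so the same test covers both the TP=2/PP=2 and the PP=4
+# configurations without code changes.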
+import ray + +from ..utils import VLLM_PATH, RemoteOpenAIServer + +# downloading lora to test lora requests + +# any model with a chat template should work here +MODEL_NAME = "meta-llama/Meta-Llama-3-8B" +EAGER_MODE = bool(int(os.getenv("EAGER_MODE", 0))) +CHUNKED_PREFILL = bool(int(os.getenv("CHUNKED_PREFILL", 0))) +TP_SIZE = int(os.getenv("TP_SIZE", 1)) +PP_SIZE = int(os.getenv("PP_SIZE", 1)) + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture(scope="module") +def ray_ctx(): + ray.init(runtime_env={"working_dir": VLLM_PATH}) + yield + ray.shutdown() + + +@pytest.fixture(scope="module") +def server(ray_ctx): + args = [ + "--model", + MODEL_NAME, + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--pipeline-parallel-size", + str(PP_SIZE), + "--tensor-parallel-size", + str(TP_SIZE), + "--distributed-executor-backend", + "ray", + ] + if CHUNKED_PREFILL: + args += [ + "--enable-chunked-prefill", + ] + if EAGER_MODE: + args += [ + "--enforce-eager", + ] + return RemoteOpenAIServer(args, num_gpus=PP_SIZE * TP_SIZE) + + +@pytest.fixture(scope="module") +def client(server): + return server.get_async_client() + + +async def test_check_models(server, client: openai.AsyncOpenAI): + models = await client.models.list() + models = models.data + served_model = models[0] + assert served_model.id == MODEL_NAME + assert all(model.root == MODEL_NAME for model in models) + + +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_single_completion(server, client: openai.AsyncOpenAI, + model_name: str): + completion = await client.completions.create(model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + assert completion.choices[0].finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) + + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + + +@pytest.mark.parametrize( + # just test 1 lora hereafter + "model_name", + [MODEL_NAME], +) +async def test_batch_completions(server, client: openai.AsyncOpenAI, + model_name: str): + # test simple list + batch = await client.completions.create( + model=model_name, + prompt=["Hello, my name is", "Hello, my name is"], + max_tokens=5, + temperature=0.0, + ) + assert len(batch.choices) == 2 + assert batch.choices[0].text == batch.choices[1].text + + # test n = 2 + batch = await client.completions.create( + model=model_name, + prompt=["Hello, my name is", "Hello, my name is"], + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but not necessary + # for official client. 
+ use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" + + # test streaming + batch = await client.completions.create( + model=model_name, + prompt=["Hello, my name is", "Hello, my name is"], + max_tokens=5, + temperature=0.0, + stream=True, + ) + texts = [""] * 2 + async for chunk in batch: + assert len(chunk.choices) == 1 + choice = chunk.choices[0] + texts[choice.index] += choice.text + assert texts[0] == texts[1] diff --git a/tests/engine/output_processor/test_multi_step.py b/tests/engine/output_processor/test_multi_step.py index 4f32a622546f0..88f3fad4c79f8 100644 --- a/tests/engine/output_processor/test_multi_step.py +++ b/tests/engine/output_processor/test_multi_step.py @@ -32,7 +32,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): output_processor = MultiStepOutputProcessor( detokenizer=detokenizer, - scheduler=scheduler, + scheduler=[scheduler], seq_counter=seq_counter, get_tokenizer_for_seq=lambda _: mock_tokenizer(), stop_checker=stop_checker, @@ -86,7 +86,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, output_processor = MultiStepOutputProcessor( detokenizer=detokenizer, - scheduler=scheduler, + scheduler=[scheduler], seq_counter=seq_counter, get_tokenizer_for_seq=lambda _: mock_tokenizer(), stop_checker=stop_checker, @@ -148,7 +148,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, output_processor = MultiStepOutputProcessor( detokenizer=detokenizer, - scheduler=scheduler, + scheduler=[scheduler], seq_counter=seq_counter, get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), stop_checker=stop_checker, @@ -215,7 +215,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, output_processor = MultiStepOutputProcessor( detokenizer=detokenizer, - scheduler=scheduler, + scheduler=[scheduler], seq_counter=seq_counter, get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), stop_checker=stop_checker, diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index f4c0af1adfdf9..3e80214f24dc5 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -14,7 +14,7 @@ from huggingface_hub import snapshot_download from openai import BadRequestError -from ...utils import RemoteOpenAIServer +from ...utils import VLLM_PATH, RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @@ -77,7 +77,7 @@ def zephyr_lora_files(): @pytest.fixture(scope="module") def ray_ctx(): - ray.init() + ray.init(runtime_env={"working_dir": VLLM_PATH}) yield ray.shutdown() diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index b05035713d7be..4fe925495eec8 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -16,7 +16,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer -from ...utils import RemoteOpenAIServer +from ...utils import VLLM_PATH, RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @@ -79,7 +79,7 @@ def zephyr_lora_files(): @pytest.fixture(scope="module") def 
ray_ctx(): - ray.init() + ray.init(runtime_env={"working_dir": VLLM_PATH}) yield ray.shutdown() diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index 7c7232dbccaa7..f8aa1c9143a3b 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -5,14 +5,14 @@ import pytest import ray -from ...utils import RemoteOpenAIServer +from ...utils import VLLM_PATH, RemoteOpenAIServer EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct" @pytest.fixture(scope="module") def ray_ctx(): - ray.init() + ray.init(runtime_env={"working_dir": VLLM_PATH}) yield ray.shutdown() diff --git a/tests/entrypoints/openai/test_models.py b/tests/entrypoints/openai/test_models.py index fddfd7550483a..914ef6e19e109 100644 --- a/tests/entrypoints/openai/test_models.py +++ b/tests/entrypoints/openai/test_models.py @@ -6,7 +6,7 @@ # downloading lora to test lora requests from huggingface_hub import snapshot_download -from ...utils import RemoteOpenAIServer +from ...utils import VLLM_PATH, RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @@ -22,7 +22,7 @@ def zephyr_lora_files(): @pytest.fixture(scope="module") def ray_ctx(): - ray.init() + ray.init(runtime_env={"working_dir": VLLM_PATH}) yield ray.shutdown() diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index a7f7fdae8d16c..7200b94f841a3 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -24,13 +24,13 @@ @pytest.fixture(scope="module") def ray_ctx(): - ray.init() + ray.init(runtime_env={"working_dir": VLLM_PATH}) yield ray.shutdown() @pytest.fixture(scope="module") -def server(): +def server(ray_ctx): return RemoteOpenAIServer([ "--model", MODEL_NAME, diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 68802f0b8468d..86148291ae6ff 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -54,9 +54,9 @@ def new_execute_model(*args, **kwargs): return new_execute_model -def zero_kv_cache(cache_engine: CacheEngine): - assert cache_engine.gpu_cache - for key_blocks, value_blocks in cache_engine.gpu_cache: +def zero_kv_cache(cache_engine: List[CacheEngine]): + assert cache_engine[0].gpu_cache + for key_blocks, value_blocks in cache_engine[0].gpu_cache: key_blocks.zero_() value_blocks.zero_() diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index c8f86133f41ac..b2ebcc15cd0fc 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -22,7 +22,7 @@ tensorize_vllm_model) from ..conftest import VllmRunner, cleanup -from ..utils import RemoteOpenAIServer +from ..utils import VLLM_PATH, RemoteOpenAIServer # yapf conflicts with isort for this docstring @@ -220,6 +220,8 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): json.dumps(model_loader_extra_config), ] + ray.init(runtime_env={"working_dir": VLLM_PATH}) + server = RemoteOpenAIServer(openai_args) print("Server ready.") diff --git a/tests/utils.py b/tests/utils.py index 09107b5e7e2b7..ad4d097b0e8ed 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -49,7 +49,6 @@ class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds - @ray.remote(num_gpus=1) class _RemoteRunner: def __init__(self, cli_args: 
List[str], *, wait_url: str, @@ -92,7 +91,11 @@ def __del__(self): if hasattr(self, "proc"): self.proc.terminate() - def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None: + def __init__(self, + cli_args: List[str], + *, + auto_port: bool = True, + num_gpus: int = 1) -> None: if auto_port: if "-p" in cli_args or "--port" in cli_args: raise ValueError("You have manually specified the port" @@ -105,10 +108,11 @@ def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None: self.host = str(args.host or 'localhost') self.port = int(args.port) - self._runner = self._RemoteRunner.remote( # type: ignore - cli_args, - wait_url=self.url_for("health"), - wait_timeout=self.MAX_SERVER_START_WAIT_S) + self._runner = ray.remote(num_gpus=num_gpus)( + self._RemoteRunner).remote( + cli_args, + wait_url=self.url_for("health"), + wait_timeout=self.MAX_SERVER_START_WAIT_S) self._wait_until_ready() diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index d941ffdb5588a..7aa439ba0a154 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -39,8 +39,8 @@ def test_swap() -> None: num_cpu_blocks=engine_config.cache_config.num_cpu_blocks) # Randomly initialize the cache. - gpu_cache = worker.cache_engine.gpu_cache - cpu_cache = worker.cache_engine.cpu_cache + gpu_cache = worker.cache_engine[0].gpu_cache + cpu_cache = worker.cache_engine[0].cpu_cache num_layers = len(gpu_cache) for i in range(num_layers): gpu_key_cache, gpu_value_cache = gpu_cache[i] diff --git a/vllm/config.py b/vllm/config.py index 66338cb0d8825..9a7e0ea7a3a10 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -27,6 +27,17 @@ _GB = 1 << 30 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +_PP_SUPPORTED_MODELS = [ + "AquilaModel", + "AquilaForCausalLM", + "InternLMForCausalLM", + "LlamaForCausalLM", + "LLaMAForCausalLM", + "MistralForCausalLM", + "Phi3ForCausalLM", + "GPT2LMHeadModel", +] + class ModelConfig: """Configuration for the model. @@ -258,6 +269,13 @@ def verify_with_parallel_config( total_num_hidden_layers = getattr(self.hf_text_config, "num_hidden_layers", 0) pipeline_parallel_size = parallel_config.pipeline_parallel_size + architectures = getattr(self.hf_config, "architectures", []) + if not all(arch in _PP_SUPPORTED_MODELS + for arch in architectures) and pipeline_parallel_size > 1: + raise NotImplementedError( + "Pipeline parallelism is only supported for the following " + f" architectures: {_PP_SUPPORTED_MODELS}.") + if total_num_hidden_layers % pipeline_parallel_size != 0: raise ValueError( f"Total number of hidden layers ({total_num_hidden_layers}) " @@ -665,9 +683,10 @@ def __init__( self._verify_args() def _verify_args(self) -> None: - if self.pipeline_parallel_size > 1: - raise NotImplementedError( - "Pipeline parallelism is not supported yet.") + if (self.pipeline_parallel_size > 1 + and self.distributed_executor_backend == "mp"): + raise NotImplementedError("Pipeline parallelism is not supported " + "yet with multiprocessing.") if self.distributed_executor_backend not in ("ray", "mp", None): raise ValueError( "Unrecognized distributed executor backend. Supported values " diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 995ea04a5b3d6..e29eba375f4dd 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -471,6 +471,9 @@ def append_slots( def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: # NOTE: fork does not allocate a new physical block. # Thus, it is always safe from OOM. 
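+        # With pipeline parallelism there is one Scheduler (and hence one block
+        # manager) per virtual engine, and fork_seq is now invoked on every
+        # scheduler (see single_step.py further down), so a block manager that
+        # never allocated the parent sequence must treat the fork as a no-op.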
+ if parent_seq.seq_id not in self.block_tables: + # Parent sequence has either been freed or never existed. + return src_block_table = self.block_tables[parent_seq.seq_id] self.block_tables[child_seq.seq_id] = src_block_table.copy() # When using a sliding window, blocks will be eventually reused. diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 6a6eebc39c58e..b48ea1b19b82a 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -317,6 +317,9 @@ def get_common_computed_block_ids( computed_seq_block_ids) # type: ignore def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + if parent_seq.seq_id not in self.block_tables: + # Parent sequence has either been freed or never existed. + return src_block_table = self.block_tables[parent_seq.seq_id] self.block_tables[child_seq.seq_id] = src_block_table.fork() diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 48c34625c08ae..5fb3b78141b12 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -256,6 +256,7 @@ def __init__( scheduler_config: SchedulerConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], + pipeline_parallel_size: int = 1, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -273,11 +274,19 @@ def __init__( BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version) + num_gpu_blocks = cache_config.num_gpu_blocks + if num_gpu_blocks: + num_gpu_blocks //= pipeline_parallel_size + + num_cpu_blocks = cache_config.num_cpu_blocks + if num_cpu_blocks: + num_cpu_blocks //= pipeline_parallel_size + # Create the block space manager. self.block_manager = BlockSpaceManagerImpl( block_size=self.cache_config.block_size, - num_gpu_blocks=self.cache_config.num_gpu_blocks, - num_cpu_blocks=self.cache_config.num_cpu_blocks, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, sliding_window=self.cache_config.sliding_window, enable_caching=self.cache_config.enable_prefix_caching) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 4ebb8703e0f44..faf9177adc8d3 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -416,7 +416,7 @@ def send_object(self, obj: Any, dst: int) -> None: assert dst < self.world_size, f"Invalid dst rank ({dst})" - assert dst != self.rank, ( + assert dst != self.rank_in_group, ( "Invalid destination rank. Destination rank is the same " "as the current rank.") @@ -446,7 +446,7 @@ def recv_object(self, src: int) -> Any: assert src < self.world_size, f"Invalid src rank ({src})" - assert src != self.rank, ( + assert src != self.rank_in_group, ( "Invalid source rank. Source rank is the same as the current rank." ) @@ -454,7 +454,7 @@ def recv_object(self, src: int) -> Any: # Receive object size rank_size = torch.distributed.recv(size_tensor, - src=src, + src=self.ranks[src], group=self.cpu_group) # Tensor to receive serialized objects into. 
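# Note on the rank handling changed throughout GroupCoordinator in this file:
# the `src`/`dst` arguments are ranks local to the group, while the underlying
# torch.distributed point-to-point calls expect global ranks -- hence the new
# `self.ranks[src]` / `self.ranks[dst]` lookups and the comparisons against
# `self.rank_in_group` rather than `self.rank`. A minimal sketch with
# hypothetical values:
#     ranks = [4, 5, 6, 7]      # global ranks that form this group
#     rank_in_group = 2         # this process is global rank ranks[2] == 6
#     peer = ranks[3]           # group-local rank 3 -> global rank 7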
@@ -464,7 +464,7 @@ def recv_object(self, src: int) -> Any: device="cpu") rank_object = torch.distributed.recv(object_tensor, - src=src, + src=self.ranks[src], group=self.cpu_group) assert rank_object == rank_size, ( @@ -491,10 +491,9 @@ def broadcast_tensor_dict( group = self.device_group metadata_group = self.cpu_group assert src < self.world_size, f"Invalid src rank ({src})" - src = self.ranks[src] - rank = self.rank - if rank == src: + rank_in_group = self.rank_in_group + if rank_in_group == src: metadata_list: List[Tuple[Any, Any]] = [] assert isinstance( tensor_dict, @@ -512,13 +511,13 @@ def broadcast_tensor_dict( if tensor.is_cpu: # use metadata_group for CPU tensors handle = torch.distributed.broadcast(tensor, - src=src, + src=self.ranks[src], group=metadata_group, async_op=True) else: # use group for GPU tensors handle = torch.distributed.broadcast(tensor, - src=src, + src=self.ranks[src], group=group, async_op=True) async_handles.append(handle) @@ -542,15 +541,16 @@ def broadcast_tensor_dict( # use metadata_group for CPU tensors handle = torch.distributed.broadcast( tensor, - src=src, + src=self.ranks[src], group=metadata_group, async_op=True) else: # use group for GPU tensors - handle = torch.distributed.broadcast(tensor, - src=src, - group=group, - async_op=True) + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=group, + async_op=True) async_handles.append(handle) _update_nested_dict(tensor_dict, key, tensor) else: @@ -575,7 +575,7 @@ def send_tensor_dict( metadata_group = self.cpu_group if dst is None: - dst = self.next_rank + dst = (self.rank_in_group + 1) % self.world_size assert dst < self.world_size, f"Invalid dst rank ({dst})" metadata_list: List[Tuple[Any, Any]] = [] @@ -593,10 +593,14 @@ def send_tensor_dict( continue if tensor.is_cpu: # use metadata_group for CPU tensors - torch.distributed.send(tensor, dst=dst, group=metadata_group) + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=metadata_group) else: # use group for GPU tensors - torch.distributed.send(tensor, dst=dst, group=group) + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=group) return None def recv_tensor_dict( @@ -614,7 +618,7 @@ def recv_tensor_dict( metadata_group = self.cpu_group if src is None: - src = self.prev_rank + src = (self.rank_in_group - 1) % self.world_size assert src < self.world_size, f"Invalid src rank ({src})" recv_metadata_list = self.recv_object(src=src) @@ -631,11 +635,13 @@ def recv_tensor_dict( if tensor.is_cpu: # use metadata_group for CPU tensors torch.distributed.recv(tensor, - src=src, + src=self.ranks[src], group=metadata_group) else: # use group for GPU tensors - torch.distributed.recv(tensor, src=src, group=group) + torch.distributed.recv(tensor, + src=self.ranks[src], + group=group) _update_nested_dict(tensor_dict, key, tensor) else: _update_nested_dict(tensor_dict, key, value) @@ -654,7 +660,7 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" if dst is None: - dst = self.next_rank + dst = (self.rank_in_group + 1) % self.world_size pynccl_comm = self.pynccl_comm if pynccl_comm is not None and not pynccl_comm.disabled: @@ -669,7 +675,7 @@ def recv(self, """Receives a tensor from the src rank.""" """NOTE: `src` is the local rank of the destination rank.""" if src is None: - src = self.prev_rank + src = (self.rank_in_group - 1) % self.world_size tensor = 
torch.empty(size, dtype=dtype, device=self.device) pynccl_comm = self.pynccl_comm diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 0cd420c8e11b5..4e4206e5893aa 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -2,7 +2,7 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -from typing import Sequence +from typing import Sequence, Tuple import torch @@ -46,3 +46,12 @@ def split_tensor_along_last_dim( return tuple(chunk.contiguous() for chunk in tensor_list) return tensor_list + + +def get_pp_indices(num_hidden_layers: int, pp_rank: int, + pp_size: int) -> Tuple[int, int]: + layers_per_partition = divide(num_hidden_layers, pp_size) + start_layer = pp_rank * layers_per_partition + end_layer = start_layer + layers_per_partition + + return (start_layer, end_layer) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f45d72cb7168e..d29bc1b1f6c42 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -211,7 +211,8 @@ class _AsyncLLMEngine(LLMEngine): """Extension of LLMEngine to add async methods.""" async def step_async( - self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + self, virtual_engine: int + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. The workers are ran asynchronously if possible. @@ -221,7 +222,8 @@ async def step_async( and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. """ - seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + seq_group_metadata_list, scheduler_outputs = self.scheduler[ + virtual_engine].schedule() if not scheduler_outputs.is_empty(): # Execute the model. @@ -230,6 +232,7 @@ async def step_async( blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_copy=scheduler_outputs.blocks_to_copy, + virtual_engine=virtual_engine, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, ) @@ -248,16 +251,12 @@ async def step_async( # Tracing self.do_tracing(scheduler_outputs) - if not request_outputs: - # Stop the execute model loop in parallel workers until there are - # more requests to process. This avoids waiting indefinitely in - # torch.distributed ops which may otherwise timeout, and unblocks - # the RPC thread in the workers so that they can process any other - # queued control plane messages, such as add/remove lora adapters. - await self.model_executor.stop_remote_worker_execution_loop_async() - return request_outputs + async def stop_remote_worker_execution_loop_async(self) -> None: + """Stop the remote worker execution loop.""" + await self.model_executor.stop_remote_worker_execution_loop_async() + async def process_model_inputs_async( self, request_id: str, @@ -494,7 +493,8 @@ def _init_engine(self, *args, # order of the arguments. 
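        # (On get_pp_indices, added in vllm/distributed/utils.py above: it gives
        # each pipeline stage a contiguous, equal-sized slice of layers; divide()
        # raises unless num_hidden_layers is a multiple of pp_size, matching the
        # check in ModelConfig.verify_with_parallel_config. With 32 layers and
        # pp_size=4 it returns (0, 8) for rank 0 and (24, 32) for rank 3.)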
cache_config = kwargs["cache_config"] parallel_config = kwargs["parallel_config"] - if parallel_config.tensor_parallel_size == 1: + if (parallel_config.tensor_parallel_size == 1 + and parallel_config.pipeline_parallel_size == 1): num_gpus = cache_config.gpu_memory_utilization else: num_gpus = 1 @@ -502,7 +502,7 @@ def _init_engine(self, *args, self._engine_class).remote return engine_class(*args, **kwargs) - async def engine_step(self) -> bool: + async def engine_step(self, virtual_engine: int) -> bool: """Kick the engine to process the waiting requests. Returns True if there are in-progress requests.""" @@ -533,7 +533,7 @@ async def engine_step(self) -> bool: if self.engine_use_ray: request_outputs = await self.engine.step.remote() # type: ignore else: - request_outputs = await self.engine.step_async() + request_outputs = await self.engine.step_async(virtual_engine) # Put the outputs into the corresponding streams. for request_output in request_outputs: @@ -549,18 +549,65 @@ async def _engine_abort(self, request_ids: Iterable[str]): self.engine.abort_request(request_ids) async def run_engine_loop(self): - has_requests_in_progress = False + if self.engine_use_ray: + pipeline_parallel_size = 1 # type: ignore + else: + pipeline_parallel_size = \ + self.engine.parallel_config.pipeline_parallel_size + has_requests_in_progress = [False] * pipeline_parallel_size while True: - if not has_requests_in_progress: + if not any(has_requests_in_progress): logger.debug("Waiting for new requests...") + # Stop the execute model loop in parallel workers until there + # are more requests to process. This avoids waiting + # indefinitely in torch.distributed ops which may otherwise + # timeout, and unblocks the RPC thread in the workers so that + # they can process any other queued control plane messages, + # such as add/remove lora adapters. + if self.engine_use_ray: + await (self.engine.stop_remote_worker_execution_loop. + remote() # type: ignore + ) + else: + await self.engine.stop_remote_worker_execution_loop_async() await self._request_tracker.wait_for_new_requests() logger.debug("Got new requests!") + requests_in_progress = [ + asyncio.create_task(self.engine_step(ve)) + for ve in range(pipeline_parallel_size) + ] + has_requests_in_progress = [True] * pipeline_parallel_size # Abort if iteration takes too long due to unrecoverable errors # (eg. NCCL timeouts). try: async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S): - has_requests_in_progress = await self.engine_step() + done, _ = await asyncio.wait( + requests_in_progress, + return_when=asyncio.FIRST_COMPLETED) + for _ in range(pipeline_parallel_size): + await asyncio.sleep(0) + for task in done: + result = task.result() + virtual_engine = requests_in_progress.index(task) + if self.engine_use_ray: + has_unfinished_requests = ( + await (self.engine. + has_unfinished_requests_for_virtual_engine. + remote( # type: ignore + virtual_engine))) + else: + has_unfinished_requests = ( + self.engine. + has_unfinished_requests_for_virtual_engine( + virtual_engine)) + if result or has_unfinished_requests: + requests_in_progress[virtual_engine] = ( + asyncio.create_task( + self.engine_step(virtual_engine))) + has_requests_in_progress[virtual_engine] = True + else: + has_requests_in_progress[virtual_engine] = False except asyncio.TimeoutError as exc: logger.error( "Engine iteration timed out. 
This should never happen!") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c13b174713423..a790570051491 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -173,6 +173,7 @@ def __init__( "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " "disable_custom_all_reduce=%s, quantization=%s, " "enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " @@ -195,6 +196,7 @@ def __init__( load_config.download_dir, load_config.load_format, parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, parallel_config.disable_custom_all_reduce, model_config.quantization, model_config.enforce_eager, @@ -296,7 +298,11 @@ def __init__( # Create the scheduler. # NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. - self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) + self.scheduler = [ + Scheduler(scheduler_config, cache_config, lora_config, + parallel_config.pipeline_parallel_size) + for _ in range(parallel_config.pipeline_parallel_size) + ] # Metric Logging. if self.log_stats: @@ -513,8 +519,16 @@ def _add_processed_request( raise ValueError( "Either SamplingParams or PoolingParams must be provided.") - # Add the sequence group to the scheduler. - self.scheduler.add_seq_group(seq_group) + # Add the sequence group to the scheduler with least unfinished seqs. + costs = [ + scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler + ] + min_cost_scheduler = self.scheduler[costs.index(min(costs))] + min_cost_scheduler.add_seq_group(seq_group) + + def stop_remote_worker_execution_loop(self) -> None: + self.model_executor.stop_remote_worker_execution_loop() def process_model_inputs( self, @@ -684,7 +698,8 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: >>> # abort the request >>> engine.abort_request(request_id) """ - self.scheduler.abort_seq_group(request_id) + for scheduler in self.scheduler: + scheduler.abort_seq_group(request_id) def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" @@ -696,11 +711,20 @@ def get_decoding_config(self) -> DecodingConfig: def get_num_unfinished_requests(self) -> int: """Gets the number of unfinished requests.""" - return self.scheduler.get_num_unfinished_seq_groups() + return sum(scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler) def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" - return self.scheduler.has_unfinished_seqs() + return any(scheduler.has_unfinished_seqs() + for scheduler in self.scheduler) + + def has_unfinished_requests_for_virtual_engine( + self, virtual_engine: int) -> bool: + """ + Returns True if there are unfinished requests for the virtual engine. + """ + return self.scheduler[virtual_engine].has_unfinished_seqs() def _process_sequence_group_outputs( self, @@ -749,7 +773,8 @@ def _process_model_outputs( self.output_processor.process_outputs(seq_group, outputs) # Free the finished sequence groups. - self.scheduler.free_finished_seq_groups() + for scheduler in self.scheduler: + scheduler.free_finished_seq_groups() # Create the outputs. 
request_outputs: List[Union[RequestOutput, @@ -815,7 +840,12 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: >>> if not (engine.has_unfinished_requests() or example_inputs): >>> break """ - seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + if self.parallel_config.pipeline_parallel_size > 1: + raise NotImplementedError( + "Pipeline parallelism is only supported through AsyncLLMEngine " + "as performance will be severely degraded otherwise.") + seq_group_metadata_list, scheduler_outputs = self.scheduler[ + 0].schedule() if not scheduler_outputs.is_empty(): execute_model_req = ExecuteModelRequest( @@ -886,23 +916,28 @@ def _get_stats( # System State # Scheduler State - num_running_sys = len(self.scheduler.running) - num_swapped_sys = len(self.scheduler.swapped) - num_waiting_sys = len(self.scheduler.waiting) + num_running_sys = sum( + len(scheduler.running) for scheduler in self.scheduler) + num_swapped_sys = sum( + len(scheduler.swapped) for scheduler in self.scheduler) + num_waiting_sys = sum( + len(scheduler.waiting) for scheduler in self.scheduler) # KV Cache Usage in % num_total_gpu = self.cache_config.num_gpu_blocks gpu_cache_usage_sys = 0. if num_total_gpu is not None: - num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks( - ) + num_free_gpu = sum( + scheduler.block_manager.get_num_free_gpu_blocks() + for scheduler in self.scheduler) gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) num_total_cpu = self.cache_config.num_cpu_blocks cpu_cache_usage_sys = 0. if num_total_cpu is not None and num_total_cpu > 0: - num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( - ) + num_free_cpu = sum( + scheduler.block_manager.get_num_free_cpu_blocks() + for scheduler in self.scheduler) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) # Iteration stats diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 9ddb6a3648b8c..92aecebe6ec38 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -27,7 +27,7 @@ class SequenceGroupOutputProcessor(ABC): def create_output_processor( scheduler_config: SchedulerConfig, detokenizer: Detokenizer, - scheduler: Scheduler, + scheduler: List[Scheduler], seq_counter: Counter, get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], stop_checker: "StopChecker", diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 8512ff83e41cc..25d15df9f915d 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -34,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): def __init__( self, detokenizer: Detokenizer, - scheduler: Scheduler, + scheduler: List[Scheduler], seq_counter: Counter, get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], stop_checker: StopChecker, @@ -141,4 +141,5 @@ def _process_seq_outputs(self, seq: Sequence, break if seq.is_finished(): - self.scheduler.free_seq(seq) + for scheduler in self.scheduler: + scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 07a68c65a6dd8..fa672e1feda92 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -33,7 +33,7 @@ def __init__( self, scheduler_config: SchedulerConfig, detokenizer: Detokenizer, - scheduler: Scheduler, + scheduler: List[Scheduler], 
seq_counter: Counter, stop_checker: StopChecker, ): @@ -95,7 +95,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # not be used in the future iterations. parent.status = SequenceStatus.FINISHED_ABORTED seq_group.remove(parent.seq_id) - self.scheduler.free_seq(parent) + for scheduler in self.scheduler: + scheduler.free_seq(parent) continue # Fork the parent sequence if there are multiple child samples. for child_sample in child_samples[:-1]: @@ -133,7 +134,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq is not parent: seq_group.add(seq) if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) + for scheduler in self.scheduler: + scheduler.fork_seq(parent, seq) # Free the finished and selected parent sequences' memory in block # manager. Keep them in the sequence group as candidate output. @@ -141,7 +143,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # old sequences. for seq, parent in child_seqs: if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) + for scheduler in self.scheduler: + scheduler.free_seq(seq) return # Beam search case @@ -226,13 +229,15 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq is not parent: seq_group.add(seq) if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) + for scheduler in self.scheduler: + scheduler.fork_seq(parent, seq) # Free the finished and selected parent sequences' memory in block # manager. Keep them in the sequence group as candidate output. for seq, parent in selected_child_seqs: if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) + for scheduler in self.scheduler: + scheduler.free_seq(seq) # Remove the unselected parent sequences from the sequence group and # free their memory in block manager. @@ -241,7 +246,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Remove the parent sequence if it is not selected for next # iteration seq_group.remove(seq.seq_id) - self.scheduler.free_seq(seq) + for scheduler in self.scheduler: + scheduler.free_seq(seq) def _check_beam_search_early_stopping( self, diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index d8693e636ac85..3db82eb1fe790 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -69,7 +69,7 @@ def execute_model( if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", - async_run_remote_workers_only=True, + async_run_tensor_parallel_workers_only=True, **self.extra_execute_model_run_workers_kwargs) # Only the driver worker returns the sampling results. @@ -138,17 +138,17 @@ def _run_workers( self, method: str, *args, - async_run_remote_workers_only: bool = False, + async_run_tensor_parallel_workers_only: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: """Runs the given method on all workers. Args: - async_run_remote_workers_only: If True the method will be run only - in the remote workers, not the driver worker. It will also be - run asynchronously and return a list of futures rather than - blocking on the results. + async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. 
""" raise NotImplementedError diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index d7c19622e270a..9018c329510c9 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,3 +1,4 @@ +import asyncio from abc import ABC, abstractmethod from typing import List, Optional, Set, Tuple @@ -110,6 +111,30 @@ def __del__(self): class ExecutorAsyncBase(ExecutorBase): + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], + ) -> None: + # This locks each pipeline parallel stage so multiple virtual engines + # can't execute on the same stage at the same time + self.pp_locks = [ + asyncio.Lock() + for _ in range(parallel_config.pipeline_parallel_size) + ] + + super().__init__(model_config, cache_config, parallel_config, + scheduler_config, device_config, load_config, + lora_config, vision_language_config, + speculative_config) + @abstractmethod async def execute_model_async( self, diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 5522b5322e66c..c2910ccdcdb7b 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -45,7 +45,8 @@ def _get_worker_kwargs( lora_config=self.lora_config, vision_language_config=self.vision_language_config, speculative_config=self.speculative_config, - is_driver_worker=rank == 0, + is_driver_worker=(not self.parallel_config) + or (rank % self.parallel_config.tensor_parallel_size == 0), ) def _create_worker(self, diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 6aebb4702889a..5bfeac0cf027e 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -91,17 +91,17 @@ def _run_workers( self, method: str, *args, - async_run_remote_workers_only: bool = False, + async_run_tensor_parallel_workers_only: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: """Runs the given method on all workers. Args: - async_run_remote_workers_only: If True the method will be run only - in the remote workers, not the driver worker. It will also be - run asynchronously and return a list of futures rather than - blocking on the results. + async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. """ if max_concurrent_workers: @@ -114,7 +114,7 @@ def _run_workers( for worker in self.workers ] - if async_run_remote_workers_only: + if async_run_tensor_parallel_workers_only: # Just return futures return worker_outputs diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index faa500c2d79ca..e742d11bb3e62 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -62,7 +62,8 @@ def _configure_ray_workers_use_nsight(self, def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): - if self.parallel_config.tensor_parallel_size == 1: + if (self.parallel_config.tensor_parallel_size == 1 + and self.parallel_config.pipeline_parallel_size == 1): # For single GPU case, we use a ray worker with constrained memory. 
num_gpus = self.cache_config.gpu_memory_utilization else: @@ -189,6 +190,26 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_concurrent_workers=self.parallel_config. max_parallel_loading_workers) + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[RayWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[RayWorkerWrapper] = [] + + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + for tp_rank in range(self.parallel_config.tensor_parallel_size): + rank = (pp_rank * + self.parallel_config.tensor_parallel_size) + tp_rank + if rank == 0: + pass + elif rank % self.parallel_config.tensor_parallel_size == 0: + self.tp_driver_workers.append(self.workers[rank - 1]) + else: + self.non_driver_workers.append(self.workers[rank - 1]) + def _driver_execute_model( self, execute_model_req: Optional[ExecuteModelRequest] ) -> Optional[List[SamplerOutput]]: @@ -204,7 +225,7 @@ def _run_workers( self, method: str, *args, - async_run_remote_workers_only: bool = False, + async_run_tensor_parallel_workers_only: bool = False, all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, @@ -215,10 +236,11 @@ def _run_workers( """Runs the given method on all workers. Can be used in the following ways: - - async_run_remote_workers_only: If True the method will be run only - in the remote workers, not the driver worker. It will also be - run asynchronously and return a list of futures rather than blocking - on the results. + Args: + - async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. - args/kwargs: All workers share the same args/kwargs - all_args/all_kwargs: args/kwargs for each worker are specified individually @@ -228,7 +250,9 @@ def _run_workers( raise NotImplementedError( "max_concurrent_workers is not supported yet.") - count = len(self.workers) + count = len(self.workers) if not \ + async_run_tensor_parallel_workers_only \ + else len(self.non_driver_workers) all_worker_args = repeat(args, count) if all_args is None \ else islice(all_args, 1, None) all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ @@ -242,14 +266,17 @@ def _run_workers( ray_worker_outputs = [] else: # Start the ray workers first. 
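+            # async_run_tensor_parallel_workers_only dispatches only to
+            # self.non_driver_workers; the rank-0 worker of each TP group
+            # (self.tp_driver_workers) is instead driven by
+            # _driver_execute_model_async below, one pp_lock per pipeline
+            # stage, and only the last stage's SamplerOutput is returned.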
+ ray_workers = self.workers + if async_run_tensor_parallel_workers_only: + ray_workers = self.non_driver_workers ray_worker_outputs = [ worker.execute_method.remote(method, *worker_args, **worker_kwargs) for (worker, worker_args, worker_kwargs - ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ) in zip(ray_workers, all_worker_args, all_worker_kwargs) ] - if async_run_remote_workers_only: + if async_run_tensor_parallel_workers_only: # Just return futures return ray_worker_outputs @@ -319,12 +346,32 @@ async def _driver_execute_model_async( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: - return await self.driver_exec_method("execute_model", - execute_model_req) + + async def _run_task_with_lock(task, lock, *args, **kwargs): + async with lock: + return await task(*args, **kwargs) + + tasks = [] + tasks.append( + asyncio.create_task( + _run_task_with_lock(self.driver_exec_method, self.pp_locks[0], + "execute_model", execute_model_req))) + for pp_rank, driver_worker in enumerate(self.tp_driver_workers, + start=1): + tasks.append( + asyncio.create_task( + _run_task_with_lock(driver_worker.execute_method.remote, + self.pp_locks[pp_rank], + "execute_model", execute_model_req))) + + results = await asyncio.gather(*tasks) + + # Only the last PP stage has the final results. + return results[-1] async def _start_worker_execution_loop(self): coros = [ worker.execute_method.remote("start_worker_execution_loop") - for worker in self.workers + for worker in self.non_driver_workers ] return await asyncio.gather(*coros) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 5777611079c66..fec52e0168851 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -29,7 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs.arctic import ArcticConfig logger = init_logger(__name__) @@ -426,6 +426,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 5cf5a199b7690..ddc4e908451af 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -43,7 +43,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -338,6 +338,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index a29aee4cffb7d..8387c8e37bdd3 100644 --- a/vllm/model_executor/models/bloom.py +++ 
b/vllm/model_executor/models/bloom.py @@ -39,7 +39,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -286,6 +286,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 5b5a69447e0b8..e6012a6d4e784 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -25,7 +25,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig from .interfaces import SupportsLoRA @@ -365,6 +365,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 600c2990b3691..2961f421eb6fc 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput @torch.compile @@ -353,6 +353,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 59af42445f323..210cf61652661 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs.dbrx import DbrxConfig @@ -381,6 +381,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 8fbda2638aaa3..e9ceca9b18c35 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -48,7 +48,7 @@ 
ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class DeepseekMLP(nn.Module): @@ -387,6 +387,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 3d4f78c664776..3cf62afd9b4ac 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -48,7 +48,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class DeepseekV2MLP(nn.Module): @@ -475,6 +475,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 9618652f70d23..89b0bbf014dea 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -44,7 +44,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs import RWConfig FalconConfig = Union[HF_FalconConfig, RWConfig] @@ -410,6 +410,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer( input_ids, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index efefb34814c90..0a5a7ed3d04e4 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -39,7 +39,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -339,6 +339,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 4e35a9ec34069..1f921c8bd0953 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -37,7 +37,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import 
SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -338,6 +338,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index cc83f6eb6d94d..55f2e27410dd7 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -25,7 +25,9 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import ( + get_pp_group, get_tensor_model_parallel_world_size) +from vllm.distributed.utils import get_pp_indices from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -38,7 +40,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class GPT2Attention(nn.Module): @@ -181,10 +183,18 @@ def __init__( self.embed_dim = config.hidden_size self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - self.h = nn.ModuleList([ - GPT2Block(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer = get_pp_indices( + config.num_hidden_layers, + get_pp_group().rank_in_group, + get_pp_group().world_size) + self.h = nn.ModuleList( + [nn.Identity() for _ in range(self.start_layer)] + [ + GPT2Block(config, cache_config, quant_config) + for _ in range(self.start_layer, self.end_layer) + ] + [ + nn.Identity() + for _ in range(self.end_layer, config.num_hidden_layers) + ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -193,14 +203,24 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] - for i in range(len(self.h)): + for i in range(self.start_layer, self.end_layer): layer = self.h[i] - hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) + hidden_states = layer(hidden_states, + kv_caches[i - self.start_layer], 
+ attn_metadata) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -228,9 +248,10 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, @@ -247,6 +268,16 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: @@ -260,16 +291,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue if not name.startswith("transformer."): name = "transformer." + name - param = params_dict[name] - # The HF's GPT-2 implementation uses Conv1D instead of Linear. - # Because of this, we need to transpose the weights. - # Note(zhuohan): the logic below might break quantized models. - for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: - if conv1d_weight_name not in name: - continue - if not name.endswith(".weight"): - continue - loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + try: + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. 
+ for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + except KeyError: + continue diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 17bbe4e312fc3..7d0bf39c58f42 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -39,7 +39,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -273,6 +273,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 47fd5788a4c35..de7f86af709e8 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -38,7 +38,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class GPTJAttention(nn.Module): @@ -239,6 +239,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index eb0fcc8f26a58..3658b8fbf057e 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -38,7 +38,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class GPTNeoXAttention(nn.Module): @@ -251,6 +251,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.gpt_neox(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index e75c567f589c8..283bc064b596c 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -22,7 +22,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class InternLM2MLP(nn.Module): @@ -263,6 +263,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: 
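
The GPT-2 changes above illustrate the general recipe for every model touched by this patch: each pipeline rank builds only its own slice of the transformer blocks (padding the rest with `nn.Identity`), and weight loading simply skips parameters the rank does not own (the `KeyError` continue). A rough sketch of an even layer split, assuming the same semantics as `vllm.distributed.utils.get_pp_indices` (the real helper's rounding may differ):

    from typing import Tuple

    def pp_indices(num_layers: int, pp_rank: int, pp_size: int) -> Tuple[int, int]:
        per_rank = num_layers // pp_size
        start = pp_rank * per_rank
        end = num_layers if pp_rank == pp_size - 1 else start + per_rank
        return start, end

    # 32 layers over 4 stages: each rank owns 8 contiguous layers.
    assert [pp_indices(32, r, 4) for r in range(4)] == [
        (0, 8), (8, 16), (16, 24), (24, 32)]
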
AttentionMetadata, + intermediate_tensors: IntermediateTensors, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 869b8fc91fd64..2758e2d0b59af 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -40,7 +40,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs import JAISConfig @@ -289,6 +289,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 54d01701f04fb..af75b6bee1041 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -29,7 +29,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_pp_indices, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -46,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import is_hip, print_warning_once from .interfaces import SupportsLoRA @@ -261,12 +262,20 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) - self.layers = nn.ModuleList([ - LlamaDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config) - for idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer = get_pp_indices( + config.num_hidden_layers, + get_pp_group().rank_in_group, + get_pp_group().world_size) + self.layers = nn.ModuleList( + [nn.Identity() for _ in range(self.start_layer)] + [ + LlamaDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config) + for _ in range(self.start_layer, self.end_layer) + ] + [ + nn.Identity() + for _ in range(self.end_layer, config.num_hidden_layers) + ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -278,22 +287,36 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = 
inputs_embeds + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - for i in range(len(self.layers)): + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -372,10 +395,11 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) - return hidden_states + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: @@ -391,6 +415,20 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -416,9 +454,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + try: + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + except KeyError: + pass break else: # Skip loading extra bias for GPTQ models. @@ -437,10 +478,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue else: name = remapped_kv_scale_name - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + try: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + except KeyError: + pass # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). 
Thus, handled exceptions should @@ -452,7 +496,8 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: quantization_param_path, tp_rank, tp_size, self.config.num_hidden_layers, self.config.__class__.model_type): - layer_self_attn = self.model.layers[layer_idx].self_attn + if not isinstance(self.model.layers[layer_idx], nn.Identity): + layer_self_attn = self.model.layers[layer_idx].self_attn if is_hip(): # The scaling factor convention we are assuming is diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index e0134c5c452de..39c47dddf5070 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -18,7 +18,7 @@ from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsVision @@ -202,6 +202,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> SamplerOutput: """Run forward pass for LLaVA-1.5. @@ -247,6 +248,7 @@ def forward( positions, kv_caches, attn_metadata, + None, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 3c0988137f7cf..8b078391b3497 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -22,7 +22,7 @@ from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_patch_grid_length) @@ -376,6 +376,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> SamplerOutput: """Run forward pass for LlaVA-NeXT. 
@@ -430,6 +431,7 @@ def forward( positions, kv_caches, attn_metadata, + None, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index a76ed049828e7..33020432713fb 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -50,7 +50,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -462,6 +462,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a662db6d28d00..05c36b9c03710 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -51,7 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -536,6 +536,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 1894c05e167d6..dde2da20b3b98 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -47,7 +47,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class MixtralMLP(nn.Module): @@ -354,6 +354,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 5f9e4d86f3cd8..28dc5922cfe9c 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -22,7 +22,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs.mpt import MPTConfig @@ -273,6 +273,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, 
positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 39270f71ec46f..53215f32b92a3 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -43,7 +43,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class OlmoAttention(nn.Module): @@ -301,6 +301,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 4bf59105dbabb..d12a51af5a781 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -39,7 +39,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class OPTLearnedPositionalEmbedding(nn.Embedding): @@ -304,6 +304,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 133a10e6bb3e8..a298f0307f3a0 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -26,7 +26,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class OrionMLP(nn.Module): @@ -269,6 +269,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 008fceb624f75..cc8e31fe1adb9 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -57,7 +57,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -278,6 +278,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 0c5298eb6f100..706ae65201d9f 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py 
@@ -21,7 +21,7 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput def load_column_parallel_weight(param: torch.nn.Parameter, @@ -412,6 +412,7 @@ def forward( positions: Optional[torch.LongTensor], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: output_hidden_states = self.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index a16f7f0ea5706..eff4e50294b3a 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -35,7 +35,7 @@ from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsVision @@ -381,9 +381,13 @@ def _parse_and_validate_image_input( return None - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, **kwargs: object): + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object): image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: @@ -398,6 +402,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, positions, kv_caches, attn_metadata, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index b6ea6ab396642..408c206c5e1ec 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -27,7 +27,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import print_warning_once @@ -245,6 +245,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index e2d725af63593..3691a3d2e3614 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -45,7 +45,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -331,6 +331,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + 
intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 564536f2dd248..b3e7dfef93ece 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -50,7 +50,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class Qwen2MoeMLP(nn.Module): @@ -397,6 +397,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index a6ed3800bed0f..1098b3031b1e8 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -41,7 +41,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class StablelmMLP(nn.Module): @@ -250,6 +250,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 4324bf50d4ad1..6f3d5d51d0315 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -40,7 +40,7 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class Starcoder2Attention(nn.Module): @@ -262,6 +262,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index b61721999ca9b..08d3efd3312b9 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -43,7 +43,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -320,6 +320,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, 
attn_metadata) diff --git a/vllm/sequence.py b/vllm/sequence.py index 3e7c31b8c1a8d..b036e76d7ccec 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -770,6 +770,34 @@ def __eq__(self, other: object) -> bool: return self.embeddings == other.embeddings +@dataclass +class IntermediateTensors: + """For all pipeline stages except the last, we need to return the hidden + states and residuals to be sent to the next stage. This data structure + contains the hidden states and residuals for a request. + """ + + tensors: Dict[str, torch.Tensor] + + def __getitem__(self, key: Union[str, slice]): + if isinstance(key, str): + return self.tensors[key] + elif isinstance(key, slice): + return self.__class__({k: v[key] for k, v in self.tensors.items()}) + + def __setitem__(self, key: str, value): + self.tensors[key] = value + + def __len__(self): + return len(self.tensors) + + def __eq__(self, other: object): + return isinstance(other, self.__class__) and self + + def __repr__(self) -> str: + return f"IntermediateTensors(tensors={self.tensors})" + + @dataclass class SamplerOutput: """For each sequence group, we generate a list of SequenceOutput object, @@ -896,6 +924,8 @@ class ExecuteModelRequest: blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) # Blocks to copy. Source to dest block. blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) + # Virtual engine ID for pipeline parallel. + virtual_engine: int = 0 # The number of slots for lookahead decoding. num_lookahead_slots: int = 0 # The number of requests in the running queue. @@ -914,6 +944,7 @@ def clone( blocks_to_swap_in=self.blocks_to_swap_in.copy(), blocks_to_swap_out=self.blocks_to_swap_out.copy(), blocks_to_copy=self.blocks_to_copy.copy(), + virtual_engine=self.virtual_engine, num_lookahead_slots=self.num_lookahead_slots, running_queue_size=self.running_queue_size, previous_hidden_states=self.previous_hidden_states, diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index f30d29376121a..b4c953162e2b4 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -6,7 +6,8 @@ ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) @@ -74,9 +75,9 @@ def __init__( List[SequenceGroupMetadata]] = None def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> ModelInputForGPUWithSamplingMetadata: + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata: """A temporary solution that caches the seq_group_metadata_list for multi-step execution. TODO: In-place update model_input and remove this function. 
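
A small usage sketch of the `IntermediateTensors` container added to `vllm/sequence.py` above (assumes a vLLM build that includes this patch): string keys return a single tensor, while slicing narrows every tensor at once, which is what the capture path relies on when it truncates buffers to a batch size.

    import torch
    from vllm.sequence import IntermediateTensors

    t = IntermediateTensors({
        "hidden_states": torch.zeros(8, 4096),
        "residual": torch.zeros(8, 4096),
    })
    print(t["hidden_states"].shape)           # torch.Size([8, 4096])
    half = t[:4]                              # slices narrow every tensor
    print(len(half), half["residual"].shape)  # 2 torch.Size([4, 4096])
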
@@ -115,6 +116,7 @@ def execute_model( self, model_input: ModelInputForGPUWithSamplingMetadata, kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: # Since we do not broadcast data inside execute_model anymore, @@ -130,6 +132,7 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + virtual_engine = model_input.virtual_engine outputs: List[SamplerOutput] = [] for step in range(num_steps): # Currently cuda graph is only supported by the decode phase. @@ -139,7 +142,8 @@ def execute_model( if prefill_meta is None and decode_meta.use_cuda_graph: assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[graph_batch_size] + model_executable = ( + self.graph_runners[virtual_engine][graph_batch_size]) else: model_executable = self.model @@ -149,6 +153,7 @@ def execute_model( positions=model_input.input_positions, kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, **multi_modal_kwargs, ) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index fbd1343fea19c..891e74f8ab940 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -38,7 +38,11 @@ def __init__( self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks + if self.num_gpu_blocks: + self.num_gpu_blocks //= parallel_config.pipeline_parallel_size self.num_cpu_blocks = cache_config.num_cpu_blocks + if self.num_cpu_blocks: + self.num_cpu_blocks //= parallel_config.pipeline_parallel_size if cache_config.cache_dtype == "auto": self.dtype = model_config.dtype diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index b83cc6f095bf7..f46e9e8aba9db 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -13,7 +13,8 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, @@ -315,6 +316,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, ) -> CPUModelInput: multi_modal_kwargs = None # NOTE: We assume that all sequences in the group are all prompts or @@ -351,6 +353,7 @@ def execute_model( self, model_input: CPUModelInput, kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 30ee262c7a8b3..8089abd690680 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -167,8 +167,8 @@ def __init__( is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. 
- self.cache_engine: CPUCacheEngine - self.cpu_cache: List[torch.Tensor] + self.cache_engine: List[CPUCacheEngine] + self.cpu_cache: List[List[torch.Tensor]] def init_device(self) -> None: self.init_distributed_environment() @@ -242,25 +242,32 @@ def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: "initializing the engine.") def _init_cache_engine(self) -> None: - self.cache_engine = CPUCacheEngine(self.cache_config, - self.model_config, - self.parallel_config, - self.device_config) - self.cpu_cache = self.cache_engine.cpu_cache - self.model_runner.block_size = self.cache_engine.block_size - - assert self.cpu_cache is not None + self.cache_engine = [ + CPUCacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.cpu_cache = [ + self.cache_engine[ve].cpu_cache + for ve in range(self.parallel_config.pipeline_parallel_size) + ] + self.model_runner.block_size = self.cache_engine[0].block_size + + assert all( + self.cpu_cache[ve] is not None + for ve in range(self.parallel_config.pipeline_parallel_size)) # Populate the cache to warmup the memory - for layer_cache in self.cpu_cache: - layer_cache.fill_(0) + for ve in range(self.parallel_config.pipeline_parallel_size): + for layer_cache in self.cpu_cache[ve]: + layer_cache.fill_(0) @property def do_metadata_broadcast(self) -> bool: return self.parallel_config.tensor_parallel_size > 1 @property - def kv_cache(self) -> Optional[List[torch.Tensor]]: + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: return self.cpu_cache def execute_worker( @@ -269,12 +276,14 @@ def execute_worker( ) -> None: if (worker_input.blocks_to_copy is not None and worker_input.blocks_to_copy.numel() > 0): - self.cache_engine.copy(worker_input.blocks_to_copy) + self.cache_engine[worker_input.virtual_engine].copy( + worker_input.blocks_to_copy) @torch.inference_mode() def prepare_worker_input( self, execute_model_req: ExecuteModelRequest) -> WorkerInput: assert execute_model_req is not None + virtual_engine = execute_model_req.virtual_engine num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) blocks_to_copy = execute_model_req.blocks_to_copy blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, @@ -285,6 +294,7 @@ def prepare_worker_input( return WorkerInput( num_seq_groups=num_seq_groups, blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, ) def init_distributed_environment(self) -> None: diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 272917c7272df..faf6e99ab646f 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -9,7 +9,8 @@ from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams -from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, + SequenceGroupMetadata) from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU logger = init_logger(__name__) @@ -57,6 +58,7 @@ def execute_model( self, model_input: ModelInputForGPUWithPoolingMetadata, kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[PoolerOutput]]: if num_steps > 1: @@ -73,10 +75,12 @@ def execute_model( assert model_input.attn_metadata is not None prefill_meta = 
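
Putting the worker-side changes together: with PP > 1 each worker keeps one cache engine per virtual engine, each sized to roughly `num_blocks // pipeline_parallel_size`, and cache operations are dispatched by the request's `virtual_engine`. An illustrative-only sketch with toy classes (not the vLLM `CPUCacheEngine` API):

    from typing import List, Tuple

    class ToyCacheEngine:
        """Stand-in for a per-virtual-engine cache engine (not the vLLM class)."""

        def __init__(self, num_blocks: int) -> None:
            self.num_blocks = num_blocks
            self.copied: List[Tuple[int, int]] = []

        def copy(self, blocks_to_copy: List[Tuple[int, int]]) -> None:
            self.copied.extend(blocks_to_copy)

    pp_size, total_blocks = 2, 8000
    cache_engines = [ToyCacheEngine(total_blocks // pp_size) for _ in range(pp_size)]

    def execute_worker(virtual_engine: int,
                       blocks_to_copy: List[Tuple[int, int]]) -> None:
        # Dispatch to the cache that belongs to this request's virtual engine.
        cache_engines[virtual_engine].copy(blocks_to_copy)

    execute_worker(1, [(0, 3)])
    print(cache_engines[0].num_blocks, cache_engines[1].copied)  # 4000 [(0, 3)]
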
model_input.attn_metadata.prefill_metadata decode_meta = model_input.attn_metadata.decode_metadata + virtual_engine = model_input.virtual_engine if prefill_meta is None and decode_meta.use_cuda_graph: assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[graph_batch_size] + model_executable = self.graph_runners[virtual_engine][ + graph_batch_size] else: model_executable = self.model @@ -115,6 +119,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + virtual_engine: int = 0, ) -> ModelInputForGPUWithPoolingMetadata: assert seq_group_metadata_list is not None model_input = self._prepare_model_input_tensors( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0b20d5010d5ef..28b447c0dc8a9 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -8,6 +8,7 @@ import numpy as np import torch +import torch.distributed import torch.nn as nn try: @@ -25,6 +26,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger @@ -37,7 +39,8 @@ from vllm.model_executor.models.interfaces import supports_lora from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) from vllm.worker.model_runner_base import ( @@ -81,6 +84,7 @@ class ModelInputForGPU(ModelRunnerInputBase): lora_requests: Optional[Set[LoRARequest]] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + virtual_engine: int = 0 def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -89,6 +93,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "virtual_engine": self.virtual_engine, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @@ -122,6 +127,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "virtual_engine": self.virtual_engine, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -179,7 +185,10 @@ def __init__( self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture - self.graph_runners: Dict[int, CUDAGraphRunner] = {} + + self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ + {} for _ in range(self.parallel_config.pipeline_parallel_size) + ] self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. 
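
The same per-virtual-engine indexing applies to CUDA graphs: `graph_runners` becomes a list with one dict per virtual engine, keyed by the padded batch size. A toy sketch of the lookup shape (plain strings stand in for captured `CUDAGraphRunner` objects):

    pipeline_parallel_size = 2
    graph_runners = [{} for _ in range(pipeline_parallel_size)]

    # During capture, each virtual engine records its own runner per batch size.
    for virtual_engine in range(pipeline_parallel_size):
        for batch_size in (8, 16, 32):
            graph_runners[virtual_engine][batch_size] = (
                f"captured-graph(ve={virtual_engine}, bs={batch_size})")

    # At execute time the lookup mirrors the diff: runners[ve][graph_batch_size].
    print(graph_runners[1][16])  # captured-graph(ve=1, bs=16)
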
# When using CUDA graph, the input block tables must be padded to @@ -787,9 +796,11 @@ def profile_run(self) -> None: max_num_seqs = min( max_num_seqs, int(max_num_batched_tokens / vlm_config.image_feature_size)) + batch_size = 0 for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ .dummy_data_for_profiling(model_config, seq_len) @@ -811,7 +822,13 @@ def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers model_input = self.prepare_model_input(seqs) - self.execute_model(model_input, kv_caches) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + dtype=self.model_config.dtype, + device=self.device) + self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() return @@ -847,7 +864,7 @@ def list_loras(self) -> Set[int]: return self.lora_manager.list_loras() @torch.inference_mode() - def capture_model(self, kv_caches: List[torch.Tensor]) -> None: + def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: """Cuda graph capture a model. Note that CUDA graph's performance gain is negligible if number @@ -880,10 +897,18 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: slot_mapping.fill_(_PAD_SLOT_ID) seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() + intermediate_inputs = None + if not get_pp_group().is_first_rank: + intermediate_inputs = self.model.make_empty_intermediate_tensors( + batch_size=max_batch_size, + dtype=self.model_config.dtype, + device=self.device) # Prepare buffer for outputs. These will be reused for all batch sizes. # It will be filled after the first graph capture. - hidden_states: Optional[torch.Tensor] = None + hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [ + None + ] * self.parallel_config.pipeline_parallel_size graph_batch_size = _get_graph_batch_size( self.scheduler_config.max_num_seqs) @@ -912,109 +937,120 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. 
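
For profiling and graph capture, ranks other than the first never compute real embeddings, so they allocate placeholder receive buffers of the right shape via `make_empty_intermediate_tensors`. A sketch under the assumption of a LLaMA-style stage that carries `hidden_states` plus `residual` (hidden size 4096 is illustrative):

    import torch

    def make_empty_intermediate_tensors(batch_size: int, hidden_size: int,
                                        dtype: torch.dtype, device: str):
        # Placeholder buffers with the shapes a non-first rank expects to receive.
        return {
            "hidden_states": torch.zeros(batch_size, hidden_size,
                                         dtype=dtype, device=device),
            "residual": torch.zeros(batch_size, hidden_size,
                                    dtype=dtype, device=device),
        }

    bufs = make_empty_intermediate_tensors(8, 4096, torch.float16, "cpu")
    print(bufs["hidden_states"].shape)  # torch.Size([8, 4096])
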
- for batch_size in reversed(batch_size_capture_list): - if self.attn_backend.get_name() == "flashinfer": - indptr_buffer = indptr_buffer[:batch_size + 1] - last_page_len_buffer = last_page_len_buffer[:batch_size] - - num_qo_heads = self.model_config.get_num_attention_heads( - self.parallel_config) - num_kv_heads = self.model_config.get_num_kv_heads( - self.parallel_config) - if num_qo_heads // num_kv_heads >= 4: - use_tensor_cores = True + for virtual_engine in range( + self.parallel_config.pipeline_parallel_size): + for batch_size in reversed(batch_size_capture_list): + if self.attn_backend.get_name() == "flashinfer": + indptr_buffer = indptr_buffer[:batch_size + 1] + last_page_len_buffer = last_page_len_buffer[: + batch_size] + + num_qo_heads = ( + self.model_config.get_num_attention_heads( + self.parallel_config)) + num_kv_heads = self.model_config.get_num_kv_heads( + self.parallel_config) + if num_qo_heads // num_kv_heads >= 4: + use_tensor_cores = True + else: + use_tensor_cores = False + decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + decode_workspace_buffer, indptr_buffer, + indices_buffer, last_page_len_buffer, "NHD", + use_tensor_cores) + kv_cache_dtype = get_kv_cache_torch_dtype( + self.kv_cache_dtype, self.model_config.dtype) + + paged_kv_indptr_tensor_host = torch.arange( + 0, batch_size + 1, dtype=torch.int32) + paged_kv_indices_tensor_host = torch.arange( + 0, batch_size, dtype=torch.int32) + paged_kv_last_page_len_tensor_host = torch.full( + (batch_size, ), self.block_size, dtype=torch.int32) + query_start_loc_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + slot_mapping=slot_mapping[:batch_size], + num_prefill_tokens=0, + num_decode_tokens=batch_size, + max_prefill_seq_len=0, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor_host, + paged_kv_indices=paged_kv_indices_tensor_host, + paged_kv_last_page_len= + paged_kv_last_page_len_tensor_host, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=self.model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=None, + query_start_loc=query_start_loc_host, + device=self.device, + data_type=kv_cache_dtype, + use_cuda_graph=True, + decode_wrapper=decode_wrapper, + prefill_wrapper=None) + attn_metadata.begin_forward() else: - use_tensor_cores = False - decode_wrapper = \ - CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - decode_workspace_buffer, indptr_buffer, indices_buffer, - last_page_len_buffer, "NHD", use_tensor_cores) - kv_cache_dtype = get_kv_cache_torch_dtype( - self.kv_cache_dtype, self.model_config.dtype) - - paged_kv_indptr_tensor_host = torch.arange( - 0, batch_size + 1, dtype=torch.int32) - paged_kv_indices_tensor_host = torch.arange( - 0, batch_size, dtype=torch.int32) - paged_kv_last_page_len_tensor_host = torch.full( - (batch_size, ), self.block_size, dtype=torch.int32) - query_start_loc_host = torch.arange(0, - batch_size + 1, - dtype=torch.int32) - - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - slot_mapping=slot_mapping[:batch_size], - num_prefill_tokens=0, - num_decode_tokens=batch_size, - max_prefill_seq_len=0, - block_tables=block_tables, - paged_kv_indptr=paged_kv_indptr_tensor_host, - paged_kv_indices=paged_kv_indices_tensor_host, - paged_kv_last_page_len= - paged_kv_last_page_len_tensor_host, - num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_dim=self.model_config.get_head_size(), - page_size=self.block_size, - 
seq_start_loc=None, - query_start_loc=query_start_loc_host, - device=self.device, - data_type=kv_cache_dtype, - use_cuda_graph=True, - decode_wrapper=decode_wrapper, - prefill_wrapper=None) - attn_metadata.begin_forward() - else: - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping[:batch_size], - seq_lens=None, - seq_lens_tensor=seq_lens[:batch_size], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=block_tables[:batch_size], - use_cuda_graph=True, + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=seq_lens[:batch_size], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=block_tables[:batch_size], + use_cuda_graph=True, + ) + + if self.lora_config: + lora_mapping = LoRAMapping( + [0] * batch_size, + [0] * batch_size, + ) + self.set_active_loras(set(), lora_mapping) + + graph_runner = CUDAGraphRunner( + self.model, self.attn_backend.get_name()) + + if self.attn_backend.get_name() == "flashinfer": + graph_runner.flashinfer_indptr_buffer = indptr_buffer + graph_runner.flashinfer_indices_buffer = indices_buffer + graph_runner.flashinfer_last_page_len_buffer = \ + last_page_len_buffer + graph_runner.flashinfer_decode_workspace_buffer = \ + decode_workspace_buffer + graph_runner.flashinfer_decode_wrapper = \ + decode_wrapper + + graph_runner.capture( + input_tokens[:batch_size], + input_positions[:batch_size], + hidden_or_intermediate_states[ + virtual_engine] # type: ignore + [:batch_size] + if hidden_or_intermediate_states[virtual_engine] + is not None else None, + intermediate_inputs[:batch_size] + if intermediate_inputs is not None else None, + kv_caches[virtual_engine], + attn_metadata, + memory_pool=self.graph_memory_pool, + stream=graph_capture_context.stream, ) - - if self.lora_config: - lora_mapping = LoRAMapping( - [0] * batch_size, - [0] * batch_size, - ) - self.set_active_loras(set(), lora_mapping) - - graph_runner = CUDAGraphRunner(self.model, - self.attn_backend.get_name()) - - if self.attn_backend.get_name() == "flashinfer": - graph_runner.flashinfer_indptr_buffer = indptr_buffer - graph_runner.flashinfer_indices_buffer = indices_buffer - graph_runner.flashinfer_last_page_len_buffer = \ - last_page_len_buffer - graph_runner.flashinfer_decode_workspace_buffer = \ - decode_workspace_buffer - graph_runner.flashinfer_decode_wrapper = \ - decode_wrapper - - graph_runner.capture( - input_tokens[:batch_size], - input_positions[:batch_size], - hidden_states[:batch_size] - if hidden_states is not None else None, - kv_caches, - attn_metadata, - memory_pool=self.graph_memory_pool, - stream=graph_capture_context.stream, - ) - self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[batch_size] = graph_runner + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[virtual_engine][batch_size] = ( + graph_runner) end_time = time.perf_counter() elapsed_time = end_time - start_time @@ -1047,6 +1083,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + 
virtual_engine: int = 0, ) -> ModelInputForGPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. @@ -1072,15 +1109,17 @@ def prepare_model_input( if seq_group_metadata_list else None) return dataclasses.replace(model_input, sampling_metadata=sampling_metadata, - is_prompt=is_prompt) + is_prompt=is_prompt, + virtual_engine=virtual_engine) @torch.inference_mode() def execute_model( self, model_input: ModelInputForGPUWithSamplingMetadata, kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, - ) -> Optional[List[SamplerOutput]]: + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") @@ -1124,27 +1163,34 @@ def execute_model( assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata decode_meta = model_input.attn_metadata.decode_metadata + # TODO(andoorve): We can remove this once all + # virtual engines share the same kv cache. + virtual_engine = model_input.virtual_engine if prefill_meta is None and decode_meta.use_cuda_graph: assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[graph_batch_size] + model_executable = self.graph_runners[virtual_engine][ + graph_batch_size] else: model_executable = self.model multi_modal_kwargs = model_input.multi_modal_kwargs or {} - hidden_states = model_executable( + hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, **multi_modal_kwargs, ) - # Compute the logits. - logits = self.model.compute_logits(hidden_states, + # Compute the logits in the last pipeline stage. + if not get_pp_group().is_last_rank: + return hidden_or_intermediate_states + + logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) - # Only perform sampling in the driver worker. if not self.is_driver_worker: return [] @@ -1159,9 +1205,12 @@ def execute_model( assert model_input.sampling_metadata is not None indices = model_input.sampling_metadata.selected_token_indices if model_input.is_prompt: - hidden_states = hidden_states.index_select(0, indices) + hidden_states = hidden_or_intermediate_states.index_select( + 0, indices) elif decode_meta.use_cuda_graph: - hidden_states = hidden_states[:len(indices)] + hidden_states = hidden_or_intermediate_states[:len(indices)] + else: + hidden_states = hidden_or_intermediate_states output.hidden_states = hidden_states @@ -1195,13 +1244,15 @@ def capture( self, input_ids: torch.Tensor, positions: torch.Tensor, - hidden_states: Optional[torch.Tensor], + hidden_or_intermediate_states: Optional[Union[IntermediateTensors, + torch.Tensor]], + intermediate_inputs: Optional[IntermediateTensors], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, memory_pool: Optional[Tuple[int, int]], stream: torch.cuda.Stream, **kwargs, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: assert self._graph is None # Run the model a few times without capturing the graph. 
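With pipeline parallelism, execute_model above branches by rank: captured graphs are looked up per (virtual engine, batch size), non-last ranks return their activations to be shipped downstream, and only the last stage computes logits and samples (with a further driver-worker check omitted here). The sketch below is a simplified, hypothetical control flow; PPGroup, the graph_runners layout, and the callables are stand-ins that only mirror the structure of the diff.

from dataclasses import dataclass
from typing import Callable, Dict, List

@dataclass
class PPGroup:
    rank: int
    world_size: int
    @property
    def is_first_rank(self) -> bool:
        return self.rank == 0
    @property
    def is_last_rank(self) -> bool:
        return self.rank == self.world_size - 1

def execute_model_sketch(pp: PPGroup,
                         virtual_engine: int,
                         graph_batch_size: int,
                         use_cuda_graph: bool,
                         graph_runners: List[Dict[int, Callable]],
                         eager_model: Callable,
                         compute_logits: Callable,
                         sample: Callable,
                         inputs,
                         intermediate_tensors=None):
    # Pick the CUDA graph captured for this (virtual engine, batch size),
    # or fall back to eager execution.
    if use_cuda_graph:
        model = graph_runners[virtual_engine][graph_batch_size]
    else:
        model = eager_model

    hidden_or_intermediate = model(inputs, intermediate_tensors)

    # Non-last pipeline ranks hand activations to the next stage instead
    # of computing logits or sampling.
    if not pp.is_last_rank:
        return hidden_or_intermediate

    logits = compute_logits(hidden_or_intermediate)
    return [sample(logits)]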
# This is to make sure that the captured graph does not include the @@ -1213,6 +1264,7 @@ def capture( positions, kv_caches, attn_metadata, + intermediate_inputs, **kwargs, ) torch.cuda.synchronize() @@ -1220,18 +1272,27 @@ def capture( # Capture the graph. self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): - output_hidden_states = self.model( + output_hidden_or_intermediate_states = self.model( input_ids, positions, kv_caches, attn_metadata, + intermediate_inputs, **kwargs, ) - if hidden_states is not None: - hidden_states.copy_(output_hidden_states) + if hidden_or_intermediate_states is not None: + if get_pp_group().is_last_rank: + hidden_or_intermediate_states.copy_( + output_hidden_or_intermediate_states) + else: + for key in hidden_or_intermediate_states.tensors: + hidden_or_intermediate_states[key].copy_( + output_hidden_or_intermediate_states[key]) else: - hidden_states = output_hidden_states - del output_hidden_states + hidden_or_intermediate_states = ( + output_hidden_or_intermediate_states) + + del output_hidden_or_intermediate_states # make sure `output_hidden_states` is deleted # in the graph's memory pool gc.collect() @@ -1255,8 +1316,15 @@ def capture( attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, } - self.output_buffers = {"hidden_states": hidden_states} - return hidden_states + if intermediate_inputs is not None: + self.input_buffers.update(intermediate_inputs.tensors) + if get_pp_group().is_last_rank: + self.output_buffers = { + "hidden_states": hidden_or_intermediate_states + } + else: + self.output_buffers = hidden_or_intermediate_states + return hidden_or_intermediate_states def forward( self, @@ -1264,6 +1332,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], **kwargs, ) -> torch.Tensor: # KV caches are fixed tensors, so we don't need to copy them. @@ -1280,11 +1349,18 @@ def forward( non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) + if intermediate_tensors is not None: + for key in intermediate_tensors.tensors: + self.input_buffers[key].copy_(intermediate_tensors[key], + non_blocking=True) # Run the graph. self.graph.replay() # Return the output tensor. 
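The capture and replay path above depends on CUDA graphs replaying against fixed memory addresses, so each step's inputs, now including any received intermediate tensors, are copied into pre-registered buffers before replay, and outputs are read back from pre-registered buffers afterwards. A reduced sketch of that static-buffer pattern follows; buffer names, shapes, and the ToyGraphRunner class are assumptions for illustration only.

import torch
from typing import Optional

class ToyGraphRunner:
    def __init__(self, batch_size: int, hidden_size: int):
        self.input_buffers = {
            "input_ids": torch.zeros(batch_size, dtype=torch.long),
            "hidden_states_in": torch.zeros(batch_size, hidden_size),
        }
        self.output_buffers = {
            "hidden_states": torch.zeros(batch_size, hidden_size),
        }
        self.graph = None  # would hold a torch.cuda.CUDAGraph after capture

    def forward(self, input_ids: torch.Tensor,
                intermediate: Optional[dict]) -> torch.Tensor:
        # Refresh the captured input tensors in place; replay reuses their
        # storage, so in-place copy_ (not reassignment) is required.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        if intermediate is not None:
            for key, value in intermediate.items():
                self.input_buffers[key].copy_(value, non_blocking=True)
        if self.graph is not None:
            self.graph.replay()
        # The captured graph wrote its result into this fixed buffer.
        return self.output_buffers["hidden_states"]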
- return self.output_buffers["hidden_states"] + if get_pp_group().is_last_rank: + return self.output_buffers["hidden_states"] + + return self.output_buffers def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 959cfc0b9cac5..f66bb466228be 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -5,7 +5,8 @@ import torch -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) if TYPE_CHECKING: from vllm.attention import AttentionMetadata @@ -137,6 +138,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, ) -> T: """ Prepare the inputs to ModelRunnerBase.execute_model from an execution @@ -150,6 +152,7 @@ def execute_model( self, model_input: T, kv_caches: Optional[List[torch.Tensor]], + intermediate_tensors: Optional[IntermediateTensors], num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: """ diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 2ccf4a50a87bd..ab8e485281293 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -9,7 +9,8 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -175,6 +176,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, ) -> ModelInputForNeuron: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. @@ -207,6 +209,7 @@ def execute_model( self, model_input: ModelInputForNeuron, kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 307c107ddef71..f7525e049ee30 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -80,7 +80,7 @@ def do_metadata_broadcast(self) -> bool: return False @property - def kv_cache(self) -> Optional[List[torch.Tensor]]: + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: return None @torch.inference_mode() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index cc27d06b511f5..5b57282909914 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -59,9 +59,9 @@ def __init__( self.lora_config = lora_config self.load_config = load_config self.is_driver_worker = is_driver_worker - if self.is_driver_worker: - assert self.rank == 0, "The driver worker must have rank 0." - + if parallel_config and is_driver_worker: + assert rank % parallel_config.tensor_parallel_size == 0, \ + "Driver worker should be rank 0 of tensor parallel group." 
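The relaxed assertion above reflects that, with pipeline parallelism, there is one driver worker per tensor-parallel group rather than a single global rank 0. A small worked example of which global ranks become drivers under that rank % tensor_parallel_size == 0 rule; the helper is purely illustrative and assumes the tensor-parallel ranks are laid out contiguously, as the assertion implies.

from typing import List

def tp_driver_ranks(tensor_parallel_size: int,
                    pipeline_parallel_size: int) -> List[int]:
    # One driver per tensor-parallel group, i.e. every rank whose index is a
    # multiple of the tensor-parallel size.
    world_size = tensor_parallel_size * pipeline_parallel_size
    return [r for r in range(world_size)
            if r % tensor_parallel_size == 0]

assert tp_driver_ranks(2, 2) == [0, 2]        # TP=2, PP=2: two drivers
assert tp_driver_ranks(4, 1) == [0]           # pure TP: single driver, rank 0
assert tp_driver_ranks(1, 4) == [0, 1, 2, 3]  # pure PP: every rank drives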
if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules @@ -99,9 +99,9 @@ def __init__( ) # Uninitialized cache engine. Will be initialized by # initialize_cache. - self.cache_engine: CacheEngine + self.cache_engine: List[CacheEngine] # Initialize gpu_cache as embedding models don't initialize kv_caches - self.gpu_cache: Optional[List[torch.tensor]] = None + self.gpu_cache: Optional[List[List[torch.tensor]]] = None def init_device(self) -> None: if self.device_config.device.type == "cuda": @@ -217,10 +217,15 @@ def initialize_cache(self, num_gpu_blocks: int, def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None - self.cache_engine = CacheEngine(self.cache_config, self.model_config, - self.parallel_config, - self.device_config) - self.gpu_cache = self.cache_engine.gpu_cache + self.cache_engine = [ + CacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.gpu_cache = [ + self.cache_engine[ve].gpu_cache + for ve in range(self.parallel_config.pipeline_parallel_size) + ] def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: @@ -234,12 +239,13 @@ def do_metadata_broadcast(self) -> bool: return self.parallel_config.tensor_parallel_size > 1 @property - def kv_cache(self) -> Optional[List[torch.Tensor]]: + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: return self.gpu_cache @torch.inference_mode() def prepare_worker_input( self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + virtual_engine = execute_model_req.virtual_engine num_seq_groups = len(execute_model_req.seq_group_metadata_list) # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # they contain parameters to launch cudamemcpyasync. @@ -261,20 +267,24 @@ def prepare_worker_input( blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, ) @torch.inference_mode() def execute_worker(self, worker_input: WorkerInput) -> None: + virtual_engine = worker_input.virtual_engine # Issue cache operations. 
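_init_cache_engine above now builds one CacheEngine per pipeline_parallel_size, so each virtual engine (one in-flight microbatch slot per pipeline stage) owns its own KV cache, and cache operations are dispatched by virtual engine index. The reduced sketch below shows only that indexing pattern; ToyCacheEngine and its methods are stand-ins, not vLLM's CacheEngine.

from typing import List

class ToyCacheEngine:
    def __init__(self, num_blocks: int):
        self.gpu_cache = [f"kv-block-{i}" for i in range(num_blocks)]
    def swap_in(self, blocks):
        # Placeholder for the real asynchronous CPU->GPU block copy.
        print(f"swap_in {list(blocks)}")

pipeline_parallel_size = 2
cache_engine: List[ToyCacheEngine] = [
    ToyCacheEngine(num_blocks=4) for _ in range(pipeline_parallel_size)
]
gpu_cache = [cache_engine[ve].gpu_cache
             for ve in range(pipeline_parallel_size)]

# Every cache operation is routed to the cache of the requesting virtual
# engine, mirroring cache_engine[virtual_engine].swap_in(...).
virtual_engine = 1
cache_engine[virtual_engine].swap_in([0, 2])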
if (worker_input.blocks_to_swap_in is not None and worker_input.blocks_to_swap_in.numel() > 0): - self.cache_engine.swap_in(worker_input.blocks_to_swap_in) + self.cache_engine[virtual_engine].swap_in( + worker_input.blocks_to_swap_in) if (worker_input.blocks_to_swap_out is not None and worker_input.blocks_to_swap_out.numel() > 0): - self.cache_engine.swap_out(worker_input.blocks_to_swap_out) + self.cache_engine[virtual_engine].swap_out( + worker_input.blocks_to_swap_out) if (worker_input.blocks_to_copy is not None and worker_input.blocks_to_copy.numel() > 0): - self.cache_engine.copy(worker_input.blocks_to_copy) + self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index d867e15bdf82d..118173a4ca94b 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -6,10 +6,11 @@ import torch -from vllm.distributed import broadcast_tensor_dict +from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SamplerOutput) from vllm.utils import (enable_trace_function_call_for_thread, is_hip, update_environment_variables) from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -124,6 +125,7 @@ class WorkerInput: blocks_to_swap_in: Optional[torch.Tensor] = None blocks_to_swap_out: Optional[torch.Tensor] = None blocks_to_copy: Optional[torch.Tensor] = None + virtual_engine: int = 0 @classmethod def from_broadcasted_tensor_dict( @@ -139,6 +141,7 @@ def from_broadcasted_tensor_dict( blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"), + virtual_engine=tensor_dict["virtual_engine"], ) def as_broadcastable_tensor_dict( @@ -151,6 +154,7 @@ def as_broadcastable_tensor_dict( "blocks_to_swap_in": self.blocks_to_swap_in, "blocks_to_swap_out": self.blocks_to_swap_out, "blocks_to_copy": self.blocks_to_copy, + "virtual_engine": self.virtual_engine, } return tensor_dict @@ -181,11 +185,13 @@ def do_metadata_broadcast(self) -> bool: @property @abstractmethod - def kv_cache(self) -> Optional[List[torch.Tensor]]: + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: """ - Get the kv cache to pass to the worker's model runner. Used by the - default `execute_model`. If the worker's model runner does not follow - the ModelRunnerBase interface, then inherit from WorkerBase instead. + Gets the list of kv caches to pass to the worker's model runner. Each + element in the list is a kv cache corresponding to a particular virtual + engine (PP stream). Used by the default `execute_model`. If the worker's + model runner does not follow the ModelRunnerBase interface, then inherit + from WorkerBase instead. 
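WorkerInput gains a virtual_engine field above so that non-driver workers learn, via the usual metadata broadcast, which microbatch's KV cache to operate on. A toy round-trip through a broadcastable dict, with no real collective involved, shows where the field travels; the dataclass below only mirrors a subset of the fields in the diff.

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class ToyWorkerInput:
    num_seq_groups: Optional[int] = None
    blocks_to_copy: Optional[torch.Tensor] = None
    virtual_engine: int = 0

    def as_broadcastable_tensor_dict(self) -> dict:
        return {
            "num_seq_groups": self.num_seq_groups,
            "blocks_to_copy": self.blocks_to_copy,
            "virtual_engine": self.virtual_engine,
        }

    @classmethod
    def from_broadcasted_tensor_dict(cls, d: dict) -> "ToyWorkerInput":
        return cls(num_seq_groups=d["num_seq_groups"],
                   blocks_to_copy=d["blocks_to_copy"],
                   virtual_engine=d["virtual_engine"])

sent = ToyWorkerInput(num_seq_groups=3, virtual_engine=1)
received = ToyWorkerInput.from_broadcasted_tensor_dict(
    sent.as_broadcastable_tensor_dict())
assert received.virtual_engine == 1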
""" raise NotImplementedError @@ -227,7 +233,8 @@ def execute_model( execute_model_req=execute_model_req) model_input: ModelRunnerInputBase = ( self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list)) + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine)) num_steps = execute_model_req.num_steps if self.do_metadata_broadcast: @@ -255,8 +262,23 @@ def execute_model( if worker_input.num_seq_groups == 0: return [] - return self.model_runner.execute_model(model_input, self.kv_cache, - num_steps) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = IntermediateTensors( + get_pp_group().recv_tensor_dict()) + + output = self.model_runner.execute_model( + model_input, self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, intermediate_tensors, + num_steps) + + if not get_pp_group().is_last_rank: + get_pp_group().send_tensor_dict(output.tensors) + return [None] + + # Worker only supports single-step execution. Wrap the output in a + # list to conform to interface. + return output class WorkerWrapperBase: diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 99fd7da5edda5..73b771c4395f8 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -12,7 +12,8 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData, + SequenceGroupMetadata) from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner_base import ( @@ -190,6 +191,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, ) -> ModelInputForXPU: multi_modal_input = None if self.is_driver_worker: @@ -334,6 +336,7 @@ def execute_model( self, model_input: ModelInputForXPU, kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 773ee9f8159e1..7a51f2b2c729b 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -85,8 +85,8 @@ def __init__( ) # Uninitialized cache engine. Will be initialized by # initialize_cache. - self.cache_engine: CacheEngine - self.gpu_cache: List[torch.Tensor] + self.cache_engine: List[CacheEngine] + self.gpu_cache: Optional[List[List[torch.Tensor]]] def init_device(self) -> None: if self.device_config.device.type == "xpu" and is_xpu():