diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py
index a33fd71459455..cb2de97a4af94 100644
--- a/tests/spec_decode/test_multi_step_worker.py
+++ b/tests/spec_decode/test_multi_step_worker.py
@@ -5,13 +5,12 @@
 import torch
 
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
 
 from .utils import (assert_logprobs_dict_allclose, create_batch,
-                    create_execute_model_data,
                     create_seq_group_metadata_from_prompts, create_worker,
                     patch_execute_model_with_seeds, zero_kv_cache)
 
@@ -105,31 +104,32 @@ def test_same_output_for_single_step():
     final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
 
-    multi_step_execute_model_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    single_step_execute_model_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
+    multi_step_seq_group = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
 
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
     actual_output, _ = multi_step_worker.sampler_output(
-        **multi_step_execute_model_data.to_dict(), sample_len=num_steps)
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=multi_step_seq_group),
+        sample_len=num_steps)
     assert len(actual_output) == num_steps
     actual_output = actual_output[0]
 
+    single_step_seq_group = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
     zero_kv_cache(worker.cache_engine)
     set_random_seed(seed)
     expected_output = worker.execute_model(
-        **single_step_execute_model_data.to_dict(), )[0]
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=single_step_seq_group))[0]
 
     actual_token_ids = [
         output.samples[0].output_token for output in actual_output
@@ -193,19 +193,20 @@ def test_same_output_for_multi_step():
     worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
 
     continuations = [[1] for _ in prompts]
-    execute_model_data = create_execute_model_data(
-        create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            continuations=continuations,
-            final_prompt_lens=final_prompt_lens), )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=continuations,
+        final_prompt_lens=final_prompt_lens)
 
     # Run multi-step.
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
     multi_step_output, _ = multi_step_worker.sampler_output(
-        **execute_model_data.to_dict(), sample_len=num_steps)
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=num_steps)
 
     # Run single-step repeatedly.
     zero_kv_cache(worker.cache_engine)
@@ -215,16 +216,16 @@ def test_same_output_for_multi_step():
     for _ in multi_step_output:
 
-        execute_model_data = create_execute_model_data(
-            create_seq_group_metadata_from_prompts(
-                prompts,
-                num_gpu_blocks,
-                block_size,
-                continuations=continuations,
-                final_prompt_lens=final_prompt_lens))
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=continuations,
+            final_prompt_lens=final_prompt_lens)
 
         single_step_output.extend(
-            worker.execute_model(**execute_model_data.to_dict(), ))
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))
 
         # Append output tokens to new sequence data.
         for i, seq_group_output in enumerate(single_step_output[-1]):
@@ -304,12 +305,11 @@ def test_draft_proposals_full_speculation_len():
         ) for _ in range(k)
     ], True
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -340,14 +340,13 @@ def test_draft_proposals_no_speculations():
         max_proposal_len=prompt_len + k - 1,
     )
 
-    execute_model_data, _, _ = create_batch(batch_size,
-                                            k,
-                                            prompt_len=prompt_len)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prompt_len=prompt_len)
 
-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -409,17 +408,16 @@ def test_draft_proposals_mixed_k():
         ) for _ in range(k)
     ], True
 
-    execute_model_data, _, _ = create_batch(
+    seq_group_metadata_list, _, _ = create_batch(
         batch_size,
         k,
         prompt_len=prompt_len,
         prev_output_token_len=prev_output_token_len,
     )
 
-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py
index e7e2e87f599dd..de305c4030aa9 100644
--- a/tests/spec_decode/test_ngram_worker.py
+++ b/tests/spec_decode/test_ngram_worker.py
@@ -1,10 +1,10 @@
 import torch
 
+from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 
-from .utils import (create_execute_model_data,
-                    create_seq_group_metadata_from_prompts, create_worker)
+from .utils import create_seq_group_metadata_from_prompts, create_worker
 
 
 def test_ngram_algo_correctness_for_single_no_match():
@@ -44,17 +44,15 @@ def test_ngram_algo_correctness_for_single_no_match():
     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
 
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -113,17 +111,15 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
 
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -185,17 +181,15 @@ def test_ngram_algo_correctness_for_batches_match_all():
     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
 
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py
index 6763583aa85cc..ef9d32f73d668 100644
--- a/tests/spec_decode/test_spec_decode_worker.py
+++ b/tests/spec_decode/test_spec_decode_worker.py
@@ -7,7 +7,7 @@
 
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                       SpecDecodeWorkerMetrics)
@@ -15,8 +15,7 @@
 from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
                                                  split_num_cache_blocks_evenly)
 
-from .utils import (ExecuteModelData, create_batch, create_sampler_output_list,
-                    mock_worker)
+from .utils import create_batch, create_sampler_output_list, mock_worker
 
 
 @pytest.mark.parametrize('k', [1, 2, 6])
@@ -36,24 +35,19 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
     exception_secret = 'artificial stop'
     draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
 
     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=execute_model_req)
 
     call_args_list = draft_worker.get_spec_proposals.call_args_list
     assert len(call_args_list) == 1
 
     for args, _ in call_args_list:
-        (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
-         blocks_to_copy, actual_k) = args
-        actual_execute_model_data = ExecuteModelData(seq_group_metadata_list,
-                                                     blocks_to_swap_in,
-                                                     blocks_to_swap_out,
-                                                     blocks_to_copy)
-        assert actual_execute_model_data == execute_model_data
-        assert actual_k == k
+        actual_execute_model_data = args[0]
+        assert actual_execute_model_data == execute_model_req
 
 
 @pytest.mark.parametrize('k', [1, 2, 6])
@@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, prompts, prev_output_tokens = create_batch(
+    seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
         batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
@@ -105,20 +99,20 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
     target_worker.execute_model.side_effect = ValueError(exception_secret)
 
     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k))
 
     seen_contexts = []
 
     call_args_list = target_worker.execute_model.call_args_list
     assert len(call_args_list) == 1
-    for args, kwargs in call_args_list:
-        target_execute_model_data = ExecuteModelData.from_dict(kwargs)
+    for _, kwargs in call_args_list:
+        seq_group_metadata_list = kwargs[
+            "execute_model_req"].seq_group_metadata_list
 
-        assert len(target_execute_model_data.seq_group_metadata_list) == (
-            k + 1) * batch_size
-        for seq_group_metadata in (
-                target_execute_model_data.seq_group_metadata_list):
+        assert len(seq_group_metadata_list) == (k + 1) * batch_size
+        for seq_group_metadata in seq_group_metadata_list:
             for seq_data in seq_group_metadata.seq_data.values():
                 seen_contexts.append(seq_data.get_token_ids())
 
@@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -207,8 +201,9 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     rejection_sampler.side_effect = ValueError(exception_secret)
 
     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k))
 
     assert len(rejection_sampler.call_args_list) == 1
     _, kwargs = rejection_sampler.call_args_list[0]
@@ -262,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -302,8 +297,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
 
     rejection_sampler.return_value = rejection_sampler_output
 
-    output = worker.execute_model(**execute_model_data.to_dict(),
-                                  num_lookahead_slots=k)
+    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k))
 
     expected_output = create_sampler_output_list(
         token_ids=rejection_sampler_output.transpose(0, 1),
@@ -312,7 +308,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
 
     seq_ids = [
         next(iter(seq_group_metadata.seq_data.keys()))
-        for seq_group_metadata in execute_model_data.seq_group_metadata_list
+        for seq_group_metadata in seq_group_metadata_list
     ]
     actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
     expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}
@@ -383,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -428,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
     metrics_collector.maybe_collect_rejsample_metrics.return_value = (
         mock_rejsample_metrics)
 
-    output = worker.execute_model(**execute_model_data.to_dict(),
-                                  num_lookahead_slots=k)
+    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k))
     assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
 
     call_args_list = (
@@ -462,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int):
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)
 
-    execute_model_data, prompts, prev_output_tokens = create_batch(
-        batch_size, k, prev_output_token_len=0)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prev_output_token_len=0)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
 
-    out = worker.execute_model(**execute_model_data.to_dict(),
-                               num_lookahead_slots=k)
+    out = worker.execute_model(execute_model_req=execute_model_req)
 
     assert len(out) == 1, f"expected only one token output when {k=}"
     assert out[0].probs is None, "expect gpu tensor references to be None"
     assert out[
         0].sampled_tokens is None, "expect gpu tensor references to be None"
 
-    draft_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
-    target_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
+    draft_worker.execute_model.assert_called_once_with(execute_model_req)
+    target_worker.execute_model.assert_called_once_with(execute_model_req)
 
 
 @pytest.mark.parametrize('k', [0, 5])
@@ -503,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int):
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
 
-    execute_model_data, prompts, prev_output_tokens = create_batch(
-        batch_size, k, prev_output_token_len=0)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prev_output_token_len=0)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
 
-    out = worker.execute_model(**execute_model_data.to_dict(),
-                               num_lookahead_slots=k)
+    out = worker.execute_model(execute_model_req=execute_model_req)
 
     assert len(out) == 1, f"expected only one token output when {k=}"
     assert out[0].probs is None, "expect gpu tensor references to be None"
     assert out[
         0].sampled_tokens is None, "expect gpu tensor references to be None"
 
-    draft_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
-    target_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
+    draft_worker.execute_model.assert_called_once_with(execute_model_req)
+    target_worker.execute_model.assert_called_once_with(execute_model_req)
 
 
 @pytest.mark.skip_global_cleanup
diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py
index f0f0d09106a00..f288652d51556 100644
--- a/tests/spec_decode/utils.py
+++ b/tests/spec_decode/utils.py
@@ -1,4 +1,3 @@
-from dataclasses import dataclass, fields
 from itertools import count
 from typing import Dict, Iterable, List, Optional, Union
 from unittest.mock import MagicMock
@@ -16,50 +15,10 @@
 from vllm.worker.worker import Worker
 
 
-@dataclass
-class ExecuteModelData:
-    """Helper data structure which facilitates cleaner tests.
-    """
-    seq_group_metadata_list: List[SequenceGroupMetadata]
-    blocks_to_swap_in: Dict[int, int]
-    blocks_to_swap_out: Dict[int, int]
-    blocks_to_copy: Dict[int, List[int]]
-
-    def to_dict(self):
-        return dict(
-            (field.name, getattr(self, field.name)) for field in fields(self))
-
-    @classmethod
-    def from_dict(cls, d):
-        cleaned = dict((field.name, d[field.name]) for field in fields(cls))
-        return cls(**cleaned)
-
-
 def round_up_to_next_block(seq_len: int, block_size: int) -> int:
     return (seq_len + block_size - 1) // block_size
 
 
-def create_execute_model_data(
-    seq_group_metadata_list: List[SequenceGroupMetadata],
-    blocks_to_swap_in: Optional[Dict[int, int]] = None,
-    blocks_to_swap_out: Optional[Dict[int, int]] = None,
-    blocks_to_copy: Optional[Dict[int, int]] = None,
-) -> ExecuteModelData:
-    if blocks_to_swap_in is None:
-        blocks_to_swap_in = {}
-    if blocks_to_swap_out is None:
-        blocks_to_swap_out = {}
-    if blocks_to_copy is None:
-        blocks_to_copy = {}
-
-    return ExecuteModelData(
-        seq_group_metadata_list=seq_group_metadata_list,
-        blocks_to_swap_in=blocks_to_swap_in,
-        blocks_to_swap_out=blocks_to_swap_out,
-        blocks_to_copy=blocks_to_copy,
-    )
-
-
 def mock_worker(cls=None,
                 vocab_size: int = 30_000,
                 max_model_len: int = 2048,
@@ -258,8 +217,7 @@ def create_batch(batch_size,
         for prompt, prev_output_token in zip(prompts, prev_output_tokens)
     ]
 
-    execute_model_data = create_execute_model_data(
-        create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
-                                               block_size, final_prompt_lens,
-                                               prev_output_tokens, seq_ids), )
-    return execute_model_data, prompts, prev_output_tokens
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts, num_gpu_blocks, block_size, final_prompt_lens,
+        prev_output_tokens, seq_ids)
+    return seq_group_metadata_list, prompts, prev_output_tokens
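The spec-decode tests above now build the request object from `vllm.sequence` directly instead of going through the removed `ExecuteModelData` helper. A minimal sketch of the new test-side pattern, assuming a `SpecDecodeWorker` bound to `worker` and arbitrary `batch_size`/`k` values (none of these literals come from the diff):

```python
from vllm.sequence import ExecuteModelRequest

from .utils import create_batch

# create_batch() now returns a plain seq_group_metadata_list rather than an
# ExecuteModelData wrapper, so the test wraps it itself.
seq_group_metadata_list, _, _ = create_batch(batch_size=8, k=4)
execute_model_req = ExecuteModelRequest(
    seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=4)

# worker is assumed to be a SpecDecodeWorker constructed elsewhere in the test.
worker.execute_model(execute_model_req=execute_model_req)
```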
diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py
index 1804cf78d8003..07bcd343a96a6 100644
--- a/tests/worker/test_swap.py
+++ b/tests/worker/test_swap.py
@@ -1,6 +1,7 @@
 import torch
 
 from vllm.engine.arg_utils import EngineArgs
+from vllm.sequence import ExecuteModelRequest
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.worker import Worker
 
@@ -54,10 +55,14 @@ def test_swap() -> None:
 
     # Test swap out.
     blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
-    worker.execute_model(seq_group_metadata_list=[],
-                         blocks_to_swap_in={},
-                         blocks_to_swap_out=blocks_to_swap_out,
-                         blocks_to_copy={})
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=[],
+        blocks_to_swap_in={},
+        blocks_to_swap_out=blocks_to_swap_out,
+        blocks_to_copy={},
+    )
+    worker.execute_model(execute_model_req=execute_model_req)
+
     for i in range(num_layers):
         gpu_key_cache, gpu_value_cache = gpu_cache[i]
         cpu_key_cache, cpu_value_cache = cpu_cache[i]
@@ -66,14 +71,19 @@ def test_swap() -> None:
             assert allclose(gpu_value_cache[src], cpu_value_cache[dst])
 
     # Test swap in.
-    blocks_to_swap_in = {19: 45, 67: 23, 12: 78, 40: 99, 1: 71}
-    worker.execute_model(seq_group_metadata_list=[],
-                         blocks_to_swap_in=blocks_to_swap_in,
-                         blocks_to_swap_out={},
-                         blocks_to_copy={})
+    execute_model_req.blocks_to_swap_out = {}
+    execute_model_req.blocks_to_swap_in = {
+        19: 45,
+        67: 23,
+        12: 78,
+        40: 99,
+        1: 71
+    }
+    worker.execute_model(execute_model_req=execute_model_req)
+
     for i in range(num_layers):
         gpu_key_cache, gpu_value_cache = gpu_cache[i]
         cpu_key_cache, cpu_value_cache = cpu_cache[i]
-        for src, dst in blocks_to_swap_in.items():
+        for src, dst in execute_model_req.blocks_to_swap_in.items():
             assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
             assert allclose(gpu_value_cache[dst], cpu_value_cache[src])
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 7c55b08d4857d..a9e0b05b8db67 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -128,6 +128,8 @@ class SchedulerOutputs:
     ignored_seq_groups: List[SequenceGroup]
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int
+    # The number of requests in the running queue
+    running_queue_size: int
 
     def __post_init__(self):
         # Swap in and swap out should never happen at the same time.
@@ -797,6 +799,7 @@ def _schedule_default(self) -> SchedulerOutputs:
             ignored_seq_groups=prefills.ignored_seq_groups +
             swapped_in.infeasible_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
+            running_queue_size=len(self.running),
         )
 
     def _schedule_chunked_prefill(self):
@@ -883,6 +886,7 @@ def _schedule_chunked_prefill(self):
                                swapped_in.blocks_to_copy),
             ignored_seq_groups=prefills.ignored_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
+            running_queue_size=len(self.running),
        )
 
     def _schedule(self) -> SchedulerOutputs:
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index cf5053bba1d48..9f72a0d11974f 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -16,7 +16,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import MultiModalData, SamplerOutput
+from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
 from vllm.usage.usage_lib import UsageContext
 
 logger = init_logger(__name__)
@@ -210,12 +210,16 @@ async def step_async(self) -> List[RequestOutput]:
 
         if not scheduler_outputs.is_empty():
             # Execute the model.
+            execute_model_req = ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list,
+                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+                blocks_to_copy=scheduler_outputs.blocks_to_copy,
+                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
+                running_queue_size=scheduler_outputs.running_queue_size,
+            )
             output = await self.model_executor.execute_model_async(
-                seq_group_metadata_list,
-                scheduler_outputs.blocks_to_swap_in,
-                scheduler_outputs.blocks_to_swap_out,
-                scheduler_outputs.blocks_to_copy,
-                num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
+                execute_model_req)
         else:
             output = []
 
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 94a5b397a4d43..342f2c796d6fb 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -22,8 +22,8 @@
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
-                           SequenceGroup, SequenceGroupMetadata,
+from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput,
+                           Sequence, SequenceGroup, SequenceGroupMetadata,
                            SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
@@ -583,12 +583,16 @@ def step(self) -> List[RequestOutput]:
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
 
         if not scheduler_outputs.is_empty():
-            output = self.model_executor.execute_model(
+            execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                 blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                 blocks_to_copy=scheduler_outputs.blocks_to_copy,
-                num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
+                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
+                running_queue_size=scheduler_outputs.running_queue_size,
+            )
+            output = self.model_executor.execute_model(
+                execute_model_req=execute_model_req)
         else:
             output = []
 
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index 733eef828adc4..a2212459f034e 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Set, Tuple
+from typing import List, Set, Tuple
 
 import torch
 
@@ -7,7 +7,7 @@
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         make_async)
 
@@ -72,18 +72,10 @@ def initialize_cache(self, num_gpu_blocks: int,
         logger.info("# CPU blocks: %d", num_gpu_blocks)
         self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
 
-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]],
-                      num_lookahead_slots: int) -> List[SamplerOutput]:
-        output = self.driver_worker.execute_model(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-        )
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        output = self.driver_worker.execute_model(execute_model_req)
         return output
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
@@ -104,19 +96,10 @@ def check_health(self) -> None:
 class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
 
     async def execute_model_async(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
-    ) -> List[SamplerOutput]:
-        output = await make_async(self.driver_worker.execute_model)(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=num_lookahead_slots)
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        output = await make_async(self.driver_worker.execute_model
+                                  )(execute_model_req=execute_model_req, )
         return output
 
     async def check_health_async(self) -> None:
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 96cd18250bb37..08aa58999b1ec 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -1,11 +1,11 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          SpeculativeConfig, VisionLanguageConfig)
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 
 
 class ExecutorBase(ABC):
@@ -68,12 +68,9 @@ def initialize_cache(self, num_gpu_blocks: int,
         raise NotImplementedError
 
     @abstractmethod
-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]],
-                      num_lookahead_slots: int) -> List[SamplerOutput]:
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Executes at least one model step on the given sequences."""
         raise NotImplementedError
 
@@ -107,13 +104,8 @@ class ExecutorAsyncBase(ExecutorBase):
 
     @abstractmethod
     async def execute_model_async(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
-    ) -> List[SamplerOutput]:
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Executes one model step on the given sequences."""
         raise NotImplementedError
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index a58856a12f0c8..1af3bcf380843 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -3,7 +3,7 @@
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         make_async)
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -117,20 +117,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
         self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
 
     def execute_model(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
-    ) -> List[SamplerOutput]:
-        output = self.driver_worker.execute_model(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=num_lookahead_slots,
-        )
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        output = self.driver_worker.execute_model(execute_model_req)
         return output
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
@@ -154,16 +143,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
 
     async def execute_model_async(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> List[SamplerOutput]:
-        output = await make_async(self.driver_worker.execute_model)(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=num_lookahead_slots)
+        output = await make_async(self.driver_worker.execute_model
+                                  )(execute_model_req=execute_model_req, )
         return output
diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py
index 8a3b9cde84311..e7f0e887921b7 100644
--- a/vllm/executor/neuron_executor.py
+++ b/vllm/executor/neuron_executor.py
@@ -1,9 +1,9 @@
-from typing import Dict, List, Set, Tuple
+from typing import List, Set, Tuple
 
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import make_async
 
 logger = init_logger(__name__)
@@ -45,20 +45,18 @@ def initialize_cache(self, num_gpu_blocks: int,
         """
         self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
 
-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]],
-                      num_lookahead_slots: int) -> List[SamplerOutput]:
-        assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
-                and blocks_to_copy == {}), (
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        assert (execute_model_req.blocks_to_swap_in == {}
+                and execute_model_req.blocks_to_swap_out == {}
+                and execute_model_req.blocks_to_copy == {}), (
                     "Cache operations are not supported for Neuron backend.")
-        assert num_lookahead_slots == 0, (
+        assert execute_model_req.num_lookahead_slots == 0, (
             "lookahead not supported for Neuron backend.")
 
         output = self.driver_worker.execute_model(
-            seq_group_metadata_list=seq_group_metadata_list)
+            execute_model_req.seq_group_metadata_list)
         return output
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
@@ -80,14 +78,11 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
 
     async def execute_model_async(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> List[SamplerOutput]:
-        output = await make_async(self.driver_worker.execute_model)(
-            seq_group_metadata_list=seq_group_metadata_list, )
+        output = await make_async(
+            self.driver_worker.execute_model
+        )(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, )
         return output
 
     async def check_health_async(self) -> None:
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 4684b857ccd39..afc1c886722e6 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -10,7 +10,7 @@
     DistributedGPUExecutor, DistributedGPUExecutorAsync)
 from vllm.executor.ray_utils import RayWorkerWrapper, ray
 from vllm.logger import init_logger
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         get_vllm_instance_id, make_async)
 
@@ -166,21 +166,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 max_concurrent_workers=self.parallel_config.
                 max_parallel_loading_workers)
 
-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]],
-                      num_lookahead_slots: int = 0) -> List[SamplerOutput]:
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         all_outputs = self._run_workers(
             "execute_model",
-            driver_kwargs={
-                "seq_group_metadata_list": seq_group_metadata_list,
-                "blocks_to_swap_in": blocks_to_swap_in,
-                "blocks_to_swap_out": blocks_to_swap_out,
-                "blocks_to_copy": blocks_to_copy,
-                "num_lookahead_slots": num_lookahead_slots,
-            },
+            driver_kwargs={"execute_model_req": execute_model_req},
            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
 
         # Only the driver worker returns the sampling results.
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 35ac59d69f117..f2939eff7959b 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -1,7 +1,7 @@
 """Sequence and its related classes."""
 import copy
 import enum
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from vllm.block import LogicalTokenBlock
@@ -734,3 +734,33 @@ def __repr__(self) -> str:
                 f"sampled_token_probs={sampled_token_probs_repr}, "
                 f"sampled_token_ids={sampled_token_ids_repr}, "
                 f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
+
+
+@dataclass
+class ExecuteModelRequest:
+    """The model execution request."""
+    # The sequence group metadata list.
+    seq_group_metadata_list: List[SequenceGroupMetadata]
+    # Blocks to swap in. Dict of CPU -> GPU block number.
+    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
+    # Blocks to swap out. Dict of GPU -> CPU block number.
+    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
+    # Blocks to copy. Source to a list of dest blocks.
+    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
+    # The number of slots for lookahead decoding.
+    num_lookahead_slots: int = 0
+    # The number of requests in the running queue.
+    running_queue_size: int = 0
+
+    def clone(
+        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+    ) -> "ExecuteModelRequest":
+        """Clone the request with a new sequence group metadata list."""
+        return ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
+            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
+            blocks_to_copy=self.blocks_to_copy.copy(),
+            num_lookahead_slots=self.num_lookahead_slots,
+            running_queue_size=self.running_queue_size,
+        )
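To make the new dataclass easier to scan, here is a small usage sketch. Only `seq_group_metadata_list` is required; the block maps default to empty dicts, the counters to zero, and `clone()` swaps in a different metadata list while copying the block maps. The variable names below (`running_seq_groups`, `expanded_seq_group_metadata_list`) are illustrative and not part of the diff:

```python
from vllm.sequence import ExecuteModelRequest

# seq_group_metadata_list is assumed to come from Scheduler.schedule().
request = ExecuteModelRequest(
    seq_group_metadata_list=seq_group_metadata_list,
    blocks_to_swap_in={},      # CPU -> GPU block numbers
    blocks_to_swap_out={},     # GPU -> CPU block numbers
    blocks_to_copy={},         # src block -> list of dst blocks
    num_lookahead_slots=0,
    running_queue_size=len(running_seq_groups),  # illustrative value
)

# The batch-expansion scorer below uses clone() to keep the cache-operation
# fields while scoring an expanded batch of sequences.
expanded_request = request.clone(
    seq_group_metadata_list=expanded_seq_group_metadata_list)
```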
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
index 8b302ba1aabeb..d5fd96907ddd7 100644
--- a/vllm/spec_decode/batch_expansion.py
+++ b/vllm/spec_decode/batch_expansion.py
@@ -1,9 +1,10 @@
 from itertools import chain, count
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Iterator, List, Tuple
 
 import torch
 
-from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
+                           SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
@@ -40,11 +41,7 @@ def __init__(self, scorer_worker: WorkerBase, device: str,
     @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
     def score_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-        k: int,
+        execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
     ) -> SpeculativeScores:
         """Score the proposed tokens via the scorer model.
@@ -57,11 +54,7 @@ def score_proposals(
             no speculation is produced for that sequence.
 
         Args:
-            seq_group_metadata_list: The input sequence group metadata.
-            blocks_to_swap_in: This is passed to the worker during scoring.
-            blocks_to_swap_out: This is passed to the worker during scoring.
-            blocks_to_copy: This is passed to the worker during scoring.
-            k: The fixed proposal length.
+            execute_model_req: The execution request.
             proposals: The speculative proposals to score.
 
         Returns:
             SpeculativeScores: The scores of each speculative token, along with
@@ -80,28 +73,25 @@ def score_proposals(
 
         (spec_indices, non_spec_indices, target_seq_group_metadata_list,
          num_scoring_tokens) = self._expand_batch(
-            seq_group_metadata_list=seq_group_metadata_list,
+            seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )
 
         target_sampler_output = self._scorer_worker.execute_model(
-            seq_group_metadata_list=target_seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-        )
+            execute_model_req=execute_model_req.clone(
+                seq_group_metadata_list=target_seq_group_metadata_list, ))
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]
 
         all_tokens, all_probs, spec_logprobs = self._contract_batch(
-            contracted_bs=len(seq_group_metadata_list),
+            contracted_bs=len(execute_model_req.seq_group_metadata_list),
             target_sampler_output=target_sampler_output,
             proposals=proposals,
             num_scoring_tokens=num_scoring_tokens,
             non_spec_indices=non_spec_indices,
             spec_indices=spec_indices,
-            k=k,
+            k=execute_model_req.num_lookahead_slots,
         )
 
         return SpeculativeScores(
diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py
index 489d940a88856..d311bfe984cbc 100644
--- a/vllm/spec_decode/interfaces.py
+++ b/vllm/spec_decode/interfaces.py
@@ -1,10 +1,9 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Dict, List, Optional
 
 import torch
 
-from vllm.sequence import SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest
 
 
 @dataclass
@@ -58,11 +57,7 @@ class SpeculativeProposer(ABC):
     @abstractmethod
     def get_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        max_proposal_len: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
         raise NotImplementedError
 
@@ -72,11 +67,7 @@ class SpeculativeScorer(ABC):
     @abstractmethod
     def score_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-        k: int,
+        execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
     ) -> SpeculativeScores:
         raise NotImplementedError
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index d031bc85af160..5044cc1ef85fd 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -1,9 +1,10 @@
 import copy
-from typing import Dict, List, Tuple
+from typing import List, Tuple
 
 import torch
 
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
+                           SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
@@ -44,10 +45,7 @@ def set_include_gpu_probs_tensor(self):
     @torch.inference_mode()
     def sampler_output(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
+        execute_model_req: ExecuteModelRequest,
         sample_len: int,
     ) -> Tuple[List[SamplerOutput], bool]:
         """Run the model forward pass sample_len times. Returns the list of
         sampler output, one per model forward pass, along with indicator of
         whether torch tensor in sampler output need to be transposed in latter
         sampler_output_to_torch logic.
 
         For multi step worker, this indicator shall be True.
         """
-        self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
-                                   blocks_to_swap_out, blocks_to_copy)
+        self._raise_if_unsupported(execute_model_req)
 
         # Shallow copy input data so modifications (such as appending tokens)
         # do not cause side-effects.
         copied_seq_group_metadata_list = self._shallow_copy_inputs(
-            seq_group_metadata_list)
+            execute_model_req.seq_group_metadata_list)
+        copied_execute_model_req = execute_model_req.clone(
+            copied_seq_group_metadata_list)
 
         # Assert enough KV space for sample_len tokens per sequence.
-        self._assert_enough_kv_space(seq_group_metadata_list, sample_len)
+        self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
+                                     sample_len)
 
         # Run model sample_len times.
         model_outputs = []
         for _ in range(sample_len):
             model_output = super().execute_model(
-                seq_group_metadata_list=copied_seq_group_metadata_list,
-                blocks_to_swap_in=blocks_to_swap_in,
-                blocks_to_swap_out=blocks_to_swap_out,
-                blocks_to_copy=blocks_to_copy,
-            )
+                execute_model_req=copied_execute_model_req)
             assert (len(model_output) == 1
                     ), "composing multistep workers not supported"
             model_output = model_output[0]
@@ -89,23 +85,13 @@ def sampler_output(
 
     def get_spec_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        max_proposal_len: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.
         """
 
-        return self._proposer.get_proposals(
-            seq_group_metadata_list,
-            blocks_to_swap_in,
-            blocks_to_swap_out,
-            blocks_to_copy,
-            max_proposal_len,
-        )
+        return self._proposer.get_proposals(execute_model_req)
 
     def _append_new_tokens(
             self, model_output: SamplerOutput,
@@ -196,20 +182,22 @@ def _assert_enough_kv_space(
 
     def _raise_if_unsupported(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
+        execute_model_req: ExecuteModelRequest,
     ) -> None:
         """MultiStepWorker does not yet implement support for cache swap
         operations or beam search.
""" - if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): raise NotImplementedError( "MultiStepWorker does not support cache operations") if any( len(seq_group_metadata.seq_data.keys()) != 1 - for seq_group_metadata in seq_group_metadata_list): + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): raise NotImplementedError( "MultiStepWorker does not support beam search.") diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index cacaca697526c..fed8be42054a5 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -1,8 +1,8 @@ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.worker.worker_base import LoraNotSupportedWorkerBase @@ -46,13 +46,7 @@ def set_include_gpu_probs_tensor(self): # NGram don't need gpu sampler pass - def execute_model( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - ) -> None: + def execute_model(self, execute_model_req: ExecuteModelRequest) -> None: """NGram doesn't depend on model execution, just pass this function""" pass @@ -71,10 +65,7 @@ def get_cache_block_size_bytes(self): def sampler_output( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + execute_model_req: ExecuteModelRequest, sample_len: int, ) -> Tuple[Optional[List[SamplerOutput]], bool]: """NGram match algo to pick proposal candidate. Returns the list of @@ -83,16 +74,11 @@ def sampler_output( For ngram worker, we already done needed transposed internal, so the indicator pass to sampler_output_to_torch shall be False. 
""" - self._raise_if_unsupported( - seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, - ) + self._raise_if_unsupported(execute_model_req) arr = [] has_spec_out = False - for seq_group_metadata in seq_group_metadata_list: + for seq_group_metadata in execute_model_req.seq_group_metadata_list: seq_data = next(iter(seq_group_metadata.seq_data.values())) input_ids = torch.as_tensor(seq_data.get_token_ids(), @@ -135,17 +121,19 @@ def sampler_output( indices = token_ids.unsqueeze(2) token_probs = torch.zeros( - (len(seq_group_metadata_list), sample_len, self.vocab_size), + (len(execute_model_req.seq_group_metadata_list), sample_len, + self.vocab_size), dtype=torch.float32, device=self.device, ) token_probs.scatter_(2, indices, 1) token_logprobs = torch.zeros( - (len(seq_group_metadata_list), sample_len, self.vocab_size), + (len(execute_model_req.seq_group_metadata_list), sample_len, + self.vocab_size), dtype=torch.float32, device=self.device, ) - for i in range(len(seq_group_metadata_list)): + for i in range(len(execute_model_req.seq_group_metadata_list)): outputs.append( SamplerOutput( outputs=None, @@ -157,40 +145,32 @@ def sampler_output( def get_spec_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - max_proposal_len: int, + execute_model_req: ExecuteModelRequest, ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. """ - return self._proposer.get_proposals( - seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, - max_proposal_len, - ) + return self._proposer.get_proposals(execute_model_req) def _raise_if_unsupported( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + execute_model_req: ExecuteModelRequest, ) -> None: """NGramWorker does not yet implement support for cache swap operations or beam search. 
""" - if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): raise NotImplementedError( "NGramWorker does not support cache operations") if any( len(seq_group_metadata.seq_data.keys()) != 1 - for seq_group_metadata in seq_group_metadata_list): + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): raise NotImplementedError( "NGramWorker does not support beam search.") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 503519a0dfc4b..c2b119fbd5036 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,11 +1,12 @@ from functools import cached_property -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) @@ -189,69 +190,37 @@ def initialize_cache(self, num_gpu_blocks: int, @torch.inference_mode() def execute_model( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - num_lookahead_slots: int, - ) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ - assert seq_group_metadata_list is not None, ( + assert execute_model_req.seq_group_metadata_list is not None, ( "speculative decoding " "requires non-None seq_group_metadata_list") - #logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d", - # num_lookahead_slots) - # If no spec tokens, call the proposer and scorer workers normally. # Used for prefill. - if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0: - return self._run_no_spec( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) - - return self._run_speculative_decoding_step( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - k=num_lookahead_slots, - ) + if execute_model_req.num_lookahead_slots == 0 or len( + execute_model_req.seq_group_metadata_list) == 0: + return self._run_no_spec(execute_model_req) + + return self._run_speculative_decoding_step(execute_model_req) @nvtx_range("spec_decode_worker._run_no_spec") def _run_no_spec( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - ) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Run a prefill step, without any speculation. The input is sent to the proposer and scorer model so that the KV cache is consistent between the two. 
""" #logger.info("run proposer worker no spec") - self.proposer_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + self.proposer_worker.execute_model(execute_model_req) #logger.info("run target worker no spec") - sampler_output = self.scorer_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 sampler_output = sampler_output[0] @@ -264,13 +233,8 @@ def _run_no_spec( @nvtx_range("spec_decode_worker._run_speculative_decoding_step") def _run_speculative_decoding_step( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - k: int, - ) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Execute a single step of speculative decoding. This invokes the proposer worker to get k speculative tokens for each @@ -282,33 +246,25 @@ def _run_speculative_decoding_step( #logger.info("get spec proposals") # Generate proposals using draft worker. - assert blocks_to_swap_in is not None - assert blocks_to_swap_out is not None - assert blocks_to_copy is not None - proposals = self.proposer_worker.get_spec_proposals( - seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, - blocks_to_copy, k) + proposals = self.proposer_worker.get_spec_proposals(execute_model_req) #logger.info("score proposals") proposal_scores = self.scorer.score_proposals( - seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, - k, + execute_model_req, proposals, ) #logger.info("verify proposals") accepted_token_ids, target_logprobs = self._verify_tokens( - seq_group_metadata_list, proposal_scores, proposals, k) + execute_model_req.seq_group_metadata_list, proposal_scores, + proposals, execute_model_req.num_lookahead_slots) #logger.info("create output list") return self._create_output_sampler_list( - seq_group_metadata_list, + execute_model_req.seq_group_metadata_list, accepted_token_ids, target_logprobs=target_logprobs, - k=k) + k=execute_model_req.num_lookahead_slots) @nvtx_range("spec_decode_worker._verify_tokens") def _verify_tokens( diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 56c63887b0315..eb622a0e2e7f4 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -1,8 +1,9 @@ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) from vllm.spec_decode.util import sampler_output_to_torch @@ -40,17 +41,15 @@ def __init__( def get_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - proposal_len: int, + execute_model_req: ExecuteModelRequest, ) -> SpeculativeProposals: """Get speculative proposals given the input batch. 
         Sequences which would exceed the max model length are skipped during
         speculation.
         """
+        proposal_len = execute_model_req.num_lookahead_slots
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
 
         # Split speculative- and non-speculative- sequences.
         (
@@ -66,11 +65,12 @@ def get_proposals(
             # token_ids is like [batch] format in proposal_len size list,
             # while if it is false, the format would be [proposal_len]
             # in batch size list
-            maybe_sampler_output, transposed = self._worker.sampler_output(
+            nonzero_execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=nonzero_proposal_len_seqs,
-                blocks_to_swap_in=blocks_to_swap_in,
-                blocks_to_swap_out=blocks_to_swap_out,
-                blocks_to_copy=blocks_to_copy,
+                num_lookahead_slots=proposal_len,
+            )
+            maybe_sampler_output, transposed = self._worker.sampler_output(
+                execute_model_req=nonzero_execute_model_req,
                 sample_len=proposal_len,
             )
         else:
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 83ededd742533..4420d4cc9e12f 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -13,7 +13,7 @@
                               init_distributed_environment)
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.worker.cpu_model_runner import CPUModelRunner
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase
@@ -256,22 +256,24 @@ def cache_copy(
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
-        blocks_to_swap_in: Optional[Dict[int, int]] = None,
-        blocks_to_swap_out: Optional[Dict[int, int]] = None,
-        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
+        execute_model_req: Optional[ExecuteModelRequest] = None,
     ) -> List[SamplerOutput]:
+
+        if execute_model_req is None:
+            seq_group_metadata_list = None
+        else:
+            seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+
         if self.is_driver_worker:
             assert seq_group_metadata_list is not None
             num_seq_groups: int = len(seq_group_metadata_list)
-            assert blocks_to_swap_in is not None
-            assert blocks_to_swap_out is not None
-            assert blocks_to_copy is not None
-            assert len(blocks_to_swap_in) == 0
-            assert len(blocks_to_swap_out) == 0
+            assert execute_model_req is not None
+            blocks_to_copy = execute_model_req.blocks_to_copy
+            assert len(execute_model_req.blocks_to_swap_in) == 0
+            assert len(execute_model_req.blocks_to_swap_out) == 0
             data: Dict[str, Any] = {
                 "num_seq_groups": num_seq_groups,
-                "blocks_to_copy": blocks_to_copy,
+                "blocks_to_copy": execute_model_req.blocks_to_copy,
             }
             broadcast_tensor_dict(data, src=0)
         else:
@@ -279,7 +281,6 @@ def execute_model(
             num_seq_groups = data["num_seq_groups"]
             blocks_to_copy = data["blocks_to_copy"]
 
-        assert blocks_to_copy is not None
         self.cache_copy(blocks_to_copy)
 
         # If there is no input, we don't need to execute the model.
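Note: every call site above now follows the same pattern, so it is worth spelling out once in plain Python. The helper below is illustrative only and not part of this patch; it assumes the ExecuteModelRequest fields exercised in this diff (seq_group_metadata_list, num_lookahead_slots, blocks_to_swap_in/out, blocks_to_copy) and that the block fields default to empty when omitted, as the Top1Proposer call site implies.

from vllm.sequence import ExecuteModelRequest

# Hypothetical driver-side helper (not part of this patch): bundle the
# scheduler's per-step inputs into one request object and hand it to any
# worker that implements the new single-argument execute_model().
def run_one_step(worker, seq_group_metadata_list, num_lookahead_slots=0):
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=num_lookahead_slots,
    )
    return worker.execute_model(execute_model_req=execute_model_req)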
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 808261e47318b..4add36e94f723 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -18,7 +18,7 @@
     init_custom_ar)
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
 from vllm.worker.worker_base import WorkerBase
@@ -211,19 +211,21 @@ def cache_swap(
 
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
-        blocks_to_swap_in: Optional[Dict[int, int]] = None,
-        blocks_to_swap_out: Optional[Dict[int, int]] = None,
-        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
-        num_lookahead_slots: int = 0,
+        execute_model_req: Optional[ExecuteModelRequest] = None
     ) -> List[SamplerOutput]:
+        if execute_model_req is None:
+            seq_group_metadata_list = None
+        else:
+            seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+
         if self.is_driver_worker:
             assert seq_group_metadata_list is not None
+            assert execute_model_req is not None
             num_seq_groups = len(seq_group_metadata_list)
-            assert blocks_to_swap_in is not None
-            assert blocks_to_swap_out is not None
-            assert blocks_to_copy is not None
+            blocks_to_swap_in = execute_model_req.blocks_to_swap_in
+            blocks_to_swap_out = execute_model_req.blocks_to_swap_out
+            blocks_to_copy = execute_model_req.blocks_to_copy
             data: Dict[str, Any] = {
                 "num_seq_groups": num_seq_groups,
                 "blocks_to_swap_in": blocks_to_swap_in,
@@ -238,9 +240,6 @@ def execute_model(
             blocks_to_swap_out = data["blocks_to_swap_out"]
             blocks_to_copy = data["blocks_to_copy"]
 
-        assert blocks_to_swap_in is not None
-        assert blocks_to_swap_out is not None
-        assert blocks_to_copy is not None
         self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
 
         # If there is no input, we don't need to execute the model.
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 0a89e3a79769f..fb32feaca0c94 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -5,7 +5,7 @@
 
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (enable_trace_function_call_for_thread,
                         update_environment_variables)
 
@@ -48,10 +48,8 @@ def initialize_cache(self, num_gpu_blocks: int,
 
     @abstractmethod
     def execute_model(
-            self, seq_group_metadata_list: List[SequenceGroupMetadata],
-            blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int,
-                                                                        int],
-            blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Executes at least one model step on the given sequences, unless no
         sequences are provided."""
         raise NotImplementedError
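With the WorkerBase signature reduced to a single request object, a downstream worker only has to unpack the fields it needs. A minimal sketch of a conforming subclass follows, assuming only the interface shown in this diff; DummyWorker, its body, and the omitted abstract methods are illustrative and not part of this patch.

from typing import List

from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.worker.worker_base import WorkerBase


class DummyWorker(WorkerBase):
    """Illustrative only; the other WorkerBase abstract methods are omitted."""

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        # Unpack the request the same way worker.py does above.
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
        if not seq_group_metadata_list:
            # No input: nothing to execute this step.
            return []
        # A real worker would apply execute_model_req.blocks_to_swap_in,
        # blocks_to_swap_out and blocks_to_copy to its KV cache here, then
        # run the model and sample.
        raise NotImplementedError("model execution elided in this sketch")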