diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index c2d1c5769619b..92de545acd53d 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -58,19 +58,25 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices.append(selected_token_start_idx + seq_len - 1) selected_token_start_idx += seq_len - (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, - _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = model_input.slot_mapping assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device - assert attn_metadata.is_prompt is True + assert attn_metadata.num_prefills > 0 + assert attn_metadata.num_decode_tokens == 0 assert torch.allclose( attn_metadata.seq_lens_tensor, torch.tensor(seq_lens, device=device, dtype=torch.int)) assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_seq_len == max(seq_lens) + assert attn_metadata.max_prefill_seq_len == max(seq_lens) + assert attn_metadata.max_decode_seq_len == 0 # Test subquery start locs. start_idx = 0 @@ -79,11 +85,11 @@ def test_prepare_prompt(batch_size): start_idx += seq_len start_loc.append(start_idx) assert torch.allclose( - attn_metadata.subquery_start_loc, + attn_metadata.query_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) # Test seq start locs. Note that for normal prefill it is - # equivalent to subquery_start_loc. + # equivalent to query_start_loc. start_idx = 0 seq_start_loc = [start_idx] for seq_len in seq_lens: @@ -123,7 +129,7 @@ def test_prepare_prompt(batch_size): device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) - assert input_tokens == input_positions + torch.allclose(input_tokens, input_positions) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, @@ -144,14 +150,18 @@ def test_prepare_decode_cuda_graph(batch_size): enable_chunked_prefill=False, ) - seq_lens = [] + context_lens = [] seq_group_metadata_list = [] + # Assume each seq group finishes prefill. for i in range(batch_size): # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) - seq_data = list(range(seq_len)) + context_len = i % (model_runner.block_size - 1) + 1 + context_lens.append(context_len) + seq_data = list(range(context_len)) seq_data = SequenceData(seq_data) + seq_data.update_num_computed_tokens(context_len) + # Append one token ID since prefill is finished. + seq_data.append_token_id(1, 0) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, @@ -162,18 +172,45 @@ def test_prepare_decode_cuda_graph(batch_size): assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) - input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( - model_runner._prepare_decode(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, slot_mapping = ( + model_input.input_tokens, model_input.input_positions, + model_input.attn_metadata, model_input.slot_mapping) assert len(slot_mapping) == len(input_tokens) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. device = model_runner.device - assert attn_metadata.is_prompt is False - assert attn_metadata.seq_lens is None - assert attn_metadata.subquery_start_loc is None - assert attn_metadata.seq_start_loc is None - assert attn_metadata.max_seq_len == max(seq_lens) + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_prefill_tokens == 0 + seq_lens = [context_len + 1 for context_len in context_lens] + # seq_lens are padded to expected_bs + for _ in range(expected_bs - len(seq_lens)): + seq_lens.append(1) + assert attn_metadata.seq_lens == seq_lens + start_idx = 0 + start_loc = [start_idx] + for _ in context_lens: + # decode has only 1 token for query. + start_idx += 1 + start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.query_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + + start_idx = 0 + seq_start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + seq_start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.seq_start_loc, + torch.tensor(seq_start_loc, dtype=torch.int32, device=device)) + + assert torch.allclose( + attn_metadata.context_lens_tensor, + torch.tensor(context_lens, dtype=torch.int, device=device)) + assert attn_metadata.max_decode_seq_len == max(seq_lens) assert torch.allclose( attn_metadata.seq_lens_tensor[:len(seq_lens)], torch.tensor(seq_lens, dtype=torch.int, device=device)) @@ -185,23 +222,23 @@ def test_prepare_decode_cuda_graph(batch_size): # It is padded up to assert attn_metadata.block_tables.shape[1] == ( model_runner.get_max_block_per_batch()) - # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is True assert len(input_tokens) == expected_bs assert len(input_positions) == expected_bs - assert input_tokens == input_positions + torch.allclose(input_tokens, input_positions) # Verify Sampling expected_selected_token_indices = [] selected_token_start_idx = 0 - for seq_len in seq_lens: + for _ in context_lens: expected_selected_token_indices.append(selected_token_start_idx) selected_token_start_idx += 1 sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, - query_lens=seq_lens, + # query lens is all 1 for decode. + query_lens=[1 for _ in range(len(context_lens))], device=model_runner.device, pin_memory=model_runner.pin_memory) actual = sampling_metadata.selected_token_indices @@ -220,15 +257,27 @@ def test_empty_seq_group(): enforce_eager=False, ) seq_group_metadata_list = [] - input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( - model_runner._prepare_decode(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, slot_mapping = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + ) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None assert len(slot_mapping) == 0 - (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, - _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, slot_mapping, + return_seq_lens) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + model_input.seq_lens, + ) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None @@ -285,9 +334,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # Add decode requests for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - prompt_toks = list(range(seq_len)) + context_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(context_len)) seq_data = SequenceData(prompt_toks) + seq_data.append_token_id(1, 0) + seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, @@ -308,23 +359,17 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert len(attn_metadata.slot_mapping) == len(input_tokens) assert len(input_positions) == len(input_tokens) assert attn_metadata.num_prefills == prefill_batch_size - if enforce_eager: - assert attn_metadata.num_decode_tokens == decode_batch_size - else: - assert attn_metadata.num_decode_tokens == _get_graph_batch_size( - decode_batch_size) + assert attn_metadata.num_decode_tokens == decode_batch_size assert attn_metadata.num_prefill_tokens == sum(seq_lens) # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. - prefill_meta = model_runner._prepare_prompt( - prefill_metadata_list).attn_metadata - decode_meta = model_runner._prepare_decode( - decode_metadata_list).attn_metadata + attn_metadata = model_runner._prepare_model_input( + seq_group_metadata_list).attn_metadata - for attr_expected, attr_actual in zip(vars(prefill_meta), + for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata), vars(prefill_meta_actual)): assert attr_expected[1] == attr_actual[1] - for attr_expected, attr_actual in zip(vars(decode_meta), + for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata), vars(decode_meta_actual)): assert attr_expected[1] == attr_actual[1] diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 088f48def7668..f6bce9a187c64 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,6 +1,5 @@ from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -8,6 +7,6 @@ "Attention", "AttentionBackend", "AttentionMetadata", - "AttentionMetadataPerStage", + "Attention", "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 98d70fcab1a18..94ab64de30a94 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -21,7 +21,7 @@ def get_impl_cls() -> Type["AttentionImpl"]: @staticmethod @abstractmethod - def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage": + def make_metadata(*args, **kwargs) -> "AttentionMetadata": raise NotImplementedError @staticmethod @@ -53,8 +53,34 @@ def copy_blocks( @dataclass -class AttentionMetadataPerStage: - """Attention metadata for a specific stage. I.e., prefill or decode.""" +class AttentionMetadata: + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + + @property + @abstractmethod + def prefill_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run prefill + attention.""" + pass + + @property + @abstractmethod + def decode_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run decode + attention.""" + pass def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None @@ -70,40 +96,10 @@ def asdict_zerocopy(self, } -T = TypeVar("T", bound=AttentionMetadataPerStage) - - -@dataclass -class AttentionMetadata(Generic[T]): - """Attention metadata for prefill and decode batched together.""" - # Total number of prefill requests. - num_prefills: int - # Number of prefill tokens. - num_prefill_tokens: int - # Number of decode tokens. Note that it is equivalent to the number of - # decode requests. - num_decode_tokens: int - # The attention metadata for prefill requests in a batch. - # None if there's no prefill requests in a batch. - prefill_metadata: Optional[T] - # The attention metadata for decode requests in a batch. - # None if there's no decode requests in a batch. - decode_metadata: Optional[T] - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor - - def __post_init__(self): - if self.num_prefill_tokens > 0: - assert self.num_prefills > 0 - assert self.prefill_metadata is not None - if self.num_decode_tokens > 0: - assert self.decode_metadata is not None +T = TypeVar("T", bound=AttentionMetadata) -class AttentionImpl(ABC): +class AttentionImpl(ABC, Generic[T]): @abstractmethod def __init__( @@ -125,7 +121,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + attn_metadata: T, kv_scale: float = 1.0, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f59715bd76ede..5d1f65819ed4e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -11,8 +11,7 @@ from vllm_flash_attn import flash_attn_varlen_func from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -58,8 +57,7 @@ def copy_blocks( @dataclass -class FlashAttentionMetadata(AttentionMetadataPerStage, - PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -67,9 +65,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] @@ -84,14 +79,18 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. @@ -105,6 +104,70 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = FlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = FlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + class FlashAttentionImpl(AttentionImpl): """ @@ -168,7 +231,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata[FlashAttentionMetadata], + attn_metadata: FlashAttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -228,8 +291,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seq_len, - max_seqlen_k=prefill_meta.max_seq_len, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -249,7 +312,7 @@ def forward( key_cache, value_cache, prefill_meta.block_tables, - prefill_meta.subquery_start_loc, + prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, @@ -264,7 +327,7 @@ def forward( value_cache, decode_meta.block_tables, decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, + decode_meta.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 92d0fe0487516..5f9fd586fb70e 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -8,8 +8,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) class FlashInferBackend(AttentionBackend): @@ -56,9 +55,10 @@ def get_supported_head_sizes() -> List[int]: @dataclass -class FlashInferMetadata(AttentionMetadataPerStage): - - is_prompt: bool +class FlashInferMetadata(AttentionMetadata): + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int use_cuda_graph: bool = False @@ -67,7 +67,6 @@ class FlashInferMetadata(AttentionMetadataPerStage): # Metadata for the prefill stage since we still # use flash attention for prefill. seq_start_loc: Optional[torch.Tensor] = None - max_seq_len: Optional[int] = None block_tables: Optional[torch.Tensor] = None # Metadata for the decode stage @@ -113,7 +112,8 @@ def __post_init__(self): # When using flashinfer, we are also creating the FlashInferMetadata, # which will also call post_init by default, here we want to skip the # post_init if it's the prefill phase. - if not self.is_prompt: + if self.num_prefills == 0: + assert self.num_decode_tokens > 0 self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, "NHD") self.decode_wrapper.begin_forward( @@ -138,6 +138,24 @@ def asdict_zerocopy(self, skip_fields.add('decode_wrapper') return super().asdict_zerocopy(skip_fields) + @property + def prefill_metadata(self) -> Optional["FlashInferMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["FlashInferMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + class FlashInferImpl(AttentionImpl): @@ -172,7 +190,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[FlashInferMetadata], + attn_metadata: FlashInferMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: assert kv_scale == 1.0 @@ -208,8 +226,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seq_len, - max_seqlen_k=prefill_meta.max_seq_len, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 539585b46c7aa..1a94dc3596358 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -6,8 +6,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -56,8 +55,7 @@ def copy_blocks( @dataclass -class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, - PagedAttentionMetadata): +class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -65,9 +63,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] @@ -82,14 +77,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. @@ -102,6 +101,69 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] + _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = ROCmFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = ROCmFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata class ROCmFlashAttentionImpl(AttentionImpl): @@ -198,7 +260,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata], + attn_metadata: ROCmFlashAttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -266,8 +328,8 @@ def forward( None, prefill_meta.seq_start_loc, prefill_meta.seq_start_loc, - prefill_meta.max_seq_len, - prefill_meta.max_seq_len, + prefill_meta.max_prefill_seq_len, + prefill_meta.max_prefill_seq_len, True, self.scale, ) @@ -290,8 +352,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seq_len, - max_seqlen_k=prefill_meta.max_seq_len, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, ) @@ -308,7 +370,7 @@ def forward( key_cache, value_cache, prefill_meta.block_tables, - prefill_meta.subquery_start_loc, + prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, @@ -324,7 +386,7 @@ def forward( value_cache, decode_meta.block_tables, decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, + decode_meta.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 2dd72a00c6e30..a3f72b9c94566 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,8 +7,7 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -54,8 +53,7 @@ def copy_blocks( @dataclass -class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, - AttentionMetadataPerStage): +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts @@ -72,8 +70,26 @@ def __post_init__(self): # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[torch.Tensor]] = None + @property + def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self -class TorchSDPABackendImpl(AttentionImpl): + return None + + @property + def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + +class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): def __init__( self, @@ -200,7 +216,7 @@ def forward( value_cache, attn_metadata.block_tables, attn_metadata.seq_lens_tensor, - attn_metadata.max_seq_len, + attn_metadata.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index cb2028553461f..fc46af054de4f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -9,8 +9,7 @@ LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -59,7 +58,7 @@ def copy_blocks( @dataclass -class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): +class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for XFormersbackend. NOTE: Any python object stored here is not updated when it is @@ -67,9 +66,6 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] @@ -83,15 +79,19 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] # FIXME: It is for flash attn. - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # FIXME: It is for flash attn. # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is @@ -105,6 +105,8 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + _cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cached_decode_metadata: Optional["XFormersMetadata"] = None def __post_init__(self): # Set during the execution of the first attention op. @@ -114,8 +116,68 @@ def __post_init__(self): # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[AttentionBias]] = None - -class XFormersImpl(AttentionImpl): + @property + def prefill_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + + self._cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None, + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class XFormersImpl(AttentionImpl[XFormersMetadata]): """ If the input tensors contain prompt tokens, the layout is as follows: |<--------------- num_prefill_tokens ----------------->| @@ -176,7 +238,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[XFormersMetadata], + attn_metadata: "XFormersMetadata", kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -244,7 +306,7 @@ def forward( key_cache, value_cache, prefill_meta.block_tables, - prefill_meta.subquery_start_loc, + prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, @@ -261,7 +323,7 @@ def forward( value_cache, decode_meta.block_tables, decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, + decode_meta.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 8a872dba8c877..126692d8c9b40 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,8 +4,7 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import (AttentionMetadata, - AttentionMetadataPerStage) +from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig @@ -57,7 +56,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[AttentionMetadataPerStage], + attn_metadata: AttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 3c010b67b3120..30feaa4da254d 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -16,8 +16,8 @@ class PagedAttentionMetadata: # (batch_size,). The length of sequences (entire tokens seen so far) per # sequence. seq_lens_tensor: Optional[torch.Tensor] - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. + max_decode_seq_len: int # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks @@ -166,7 +166,7 @@ def forward_prefix( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - subquery_start_loc: torch.Tensor, + query_start_loc: torch.Tensor, seq_lens_tensor: torch.Tensor, context_lens: torch.Tensor, max_query_len: int, @@ -182,8 +182,8 @@ def forward_prefix( key_cache, value_cache, block_tables, - # subquery_start_loc is (batch_size + 1,) - subquery_start_loc[:-1], + # query_start_loc is (batch_size + 1,) + query_start_loc[:-1], seq_lens_tensor, context_lens, max_query_len, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 195d9e1b33e3c..bd44c2470182b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -618,6 +618,11 @@ def create_engine_config(self, ) -> EngineConfig: decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) + if (model_config.get_sliding_window() is not None + and scheduler_config.chunked_prefill_enabled): + raise ValueError( + "Chunked prefill is not supported with sliding window.") + return EngineConfig(model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index b5f1e55d0e839..1f2ab7e2870ca 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -122,6 +122,7 @@ def forward( draft_token_ids, bonus_token_ids, ) + return output_token_ids def _batch_modified_rejection_sampling( diff --git a/vllm/sequence.py b/vllm/sequence.py index 12e930c27173e..aa759448d82b1 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -654,8 +654,9 @@ def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @property - def token_chunk_size(self) -> Optional[int]: + def token_chunk_size(self) -> int: """Return the number of tokens to be processed (chunk size).""" + assert self._token_chunk_size is not None return self._token_chunk_size diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index d5fd96907ddd7..7792f3a3425cc 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -293,21 +293,30 @@ def _create_single_target_seq_group_metadata( prompt_token_ids = seq_data.get_prompt_token_ids() new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] + new_seq_data_dict = { + target_seq_id: + SequenceData( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ), + } + # This is a hack. Technically, spec decoding should compute + # num_lookahead slots at one shot, but instead, it expands the batch + # and evaluate one by one right now. context_len is seq_len - 1 because + # the kv cache is filled by a previous batch in the batch expansion. + for data in new_seq_data_dict.values(): + data.update_num_computed_tokens(data.get_len() - 1) + return SequenceGroupMetadata( request_id=seq_group_metadata.request_id, is_prompt=seq_group_metadata.is_prompt, - seq_data={ - target_seq_id: - SequenceData( - prompt_token_ids=prompt_token_ids, - output_token_ids=new_output_token_ids, - ), - }, + seq_data=new_seq_data_dict, sampling_params=seq_group_metadata.sampling_params, block_tables={ target_seq_id: seq_group_metadata.block_tables[seq_id], }, lora_request=None, + token_chunk_size=1, ) def _split_scoring_output( diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 20098ebaeea32..b5a805278d273 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -114,6 +114,7 @@ def _append_new_tokens( token_logprob = seq_output.logprobs[token_id] seq.append_token_id(token_id, token_logprob.logprob) + seq.update_num_computed_tokens(1) def _shallow_copy_inputs( self, seq_group_metadata_list: List[SequenceGroupMetadata] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 0a0b0d70cfe21..bc88f2c5bed6c 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -159,12 +159,10 @@ def _prepare_prompt( is_prompt=True, seq_lens=seq_lens, seq_lens_tensor=None, - max_seq_len=None, + max_decode_seq_len=None, num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, - prefill_metadata=None, - decode_metadata=None, block_tables=torch.tensor([]), slot_mapping=slot_mapping, ) @@ -213,7 +211,7 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) - max_seq_len = max(seq_lens) + max_decode_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, @@ -243,12 +241,10 @@ def _prepare_decode( slot_mapping=slot_mapping, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_seq_len=max_seq_len, + max_decode_seq_len=max_decode_seq_len, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), num_prefills=0, - prefill_metadata=None, - decode_metadata=None, block_tables=block_tables, ) return ( diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index d04bebbdc31b6..91f30978ead87 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -13,7 +13,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import BatchType, ModelRunner +from vllm.worker.model_runner import ModelRunner logger = init_logger(__name__) @@ -88,85 +88,24 @@ def prepare_input_tensors( ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: - prefill_reqs = [] - decode_reqs = [] - for seq_group_meta in seq_group_metadata_list: - if seq_group_meta.is_prompt: - prefill_reqs.append(seq_group_meta) - else: - decode_reqs.append(seq_group_meta) - # Prepare input tensors. ( input_tokens, input_positions, - prefill_attn_metadata, - prompt_lens, - subquery_lens, - lora_index_mapping, - lora_prompt_mapping, + attn_metadata, + seq_lens, + _, + lora_mapping, lora_requests, multi_modal_input, slot_mapping, - ) = self._prepare_prompt(prefill_reqs) - ( - decode_input_tokens, - decode_input_positions, - decode_attn_metadata, - decode_lora_index_mapping, - decode_lora_prompt_mapping, - decode_lora_requests, - decode_slot_mapping, - ) = self._prepare_decode(decode_reqs) - + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) # Prepare PoolingMetadata pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - prompt_lens) - - if not self.scheduler_config.chunked_prefill_enabled: - assert (len(prefill_reqs) and len(decode_reqs)) == 0 - - num_prefills = len(prompt_lens) - num_prefill_tokens = len(input_tokens) - num_decode_tokens = len(decode_input_tokens) - - # Coalesce tensors. Note that attn_metadata is currently not - # coalesced for simplicity. - input_tokens.extend(decode_input_tokens) - input_positions.extend(decode_input_positions) - slot_mapping.extend(decode_slot_mapping) - lora_index_mapping.extend(decode_lora_index_mapping) - lora_prompt_mapping.extend(decode_lora_prompt_mapping) - lora_requests.update(decode_lora_requests) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None - - # Broadcast the metadata. - # If batch contains both prefill and decode, it sends 2 broadcasts. - # If it only contains 1 type, it triggers a single broadcast. - if (prefill_attn_metadata is not None - and decode_attn_metadata is not None): - batch_type = BatchType.MIXED - elif prefill_attn_metadata is not None: - batch_type = BatchType.PREFILL - else: - batch_type = BatchType.DECODE + seq_lens) metadata_dict = { "input_tokens": input_tokens, @@ -178,65 +117,26 @@ def prepare_input_tensors( "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, - "batch_type": batch_type, } - if prefill_attn_metadata is not None: - metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) - else: - assert decode_attn_metadata is not None - metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) + if attn_metadata: + metadata_dict.update(attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) - - # Broadcast decode attn metadata for mixed batch type. - # The additional broadcast costs 300us overhead on 4 A10 GPUs. - # We can potentially reduce the overhead by coelescing tensors. - if batch_type == BatchType.MIXED: - assert decode_attn_metadata is not None - metadata_dict = decode_attn_metadata.asdict_zerocopy() - broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") - slot_mapping = metadata_dict.pop("slot_mapping") - num_prefills = metadata_dict.pop("num_prefills") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") - num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") - num_decode_tokens = metadata_dict.pop("num_decode_tokens") - batch_type = metadata_dict.pop("batch_type") - - # Create an attention metadata. - prefill_attn_metadata = None - decode_attn_metadata = None - if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: - prefill_attn_metadata = self.attn_backend.make_metadata( + if metadata_dict: + attn_metadata = self.attn_backend.make_metadata( **metadata_dict) else: - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - + attn_metadata = None pooling_metadata = PoolingMetadata(seq_groups=None, seq_data=None, prompt_lens=None) - # if it is a mixed batch, decode attn_metadata is broadcasted - # separately. - if batch_type == BatchType.MIXED: - metadata_dict = broadcast_tensor_dict(src=0) - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - - attn_metadata = AttentionMetadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - prefill_metadata=prefill_attn_metadata, - decode_metadata=decode_attn_metadata, - ) - return (input_tokens, input_positions, attn_metadata, pooling_metadata, lora_requests, lora_mapping, multi_modal_input) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b5e1991717b13..dcdd4b962454e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,13 +1,11 @@ import time -from enum import IntEnum from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union import numpy as np import torch import torch.nn as nn -from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, - get_attn_backend) +from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) @@ -37,66 +35,38 @@ ] -class PreparePromptMetadata(NamedTuple): - input_tokens: List[int] - input_positions: List[int] - attn_metadata: Optional[AttentionMetadataPerStage] +class ModelInput(NamedTuple): + input_tokens: torch.Tensor + input_positions: torch.Tensor + attn_metadata: Optional[AttentionMetadata] seq_lens: List[int] query_lens: List[int] - lora_index_mapping: List[int] - lora_prompt_mapping: List[int] + lora_mapping: Optional[LoRAMapping] lora_requests: Set[LoRARequest] multi_modal_input: Optional[torch.Tensor] - slot_mapping: List[int] + slot_mapping: torch.Tensor + num_prefill_tokens: int + num_decode_tokens: int + num_prefills: int @classmethod - def empty(cls): - return PreparePromptMetadata( - input_tokens=[], - input_positions=[], + def empty(cls, device): + return ModelInput( + input_tokens=torch.empty(0, device=device), + input_positions=torch.empty(0, device=device), attn_metadata=None, seq_lens=[], query_lens=[], - lora_index_mapping=[], - lora_prompt_mapping=[], + lora_mapping=None, lora_requests=set(), multi_modal_input=None, - slot_mapping=[], - ) - - -class PrepareDecodeMetadata(NamedTuple): - input_tokens: List[int] - input_positions: List[int] - attn_metadata: Optional[AttentionMetadata] - lora_index_mapping: List[int] - lora_prompt_mapping: List[int] - lora_requests: Set[LoRARequest] - slot_mapping: List[int] - - @classmethod - def empty(cls): - return PrepareDecodeMetadata( - input_tokens=[], - input_positions=[], - attn_metadata=None, - lora_index_mapping=[], - lora_prompt_mapping=[], - lora_requests=set(), - slot_mapping=[], + slot_mapping=torch.empty(0, device=device), + num_prefill_tokens=0, + num_decode_tokens=0, + num_prefills=0, ) -# How batches are constructed. -class BatchType(IntEnum): - # Every batch is prefill. - PREFILL = 0 - # Every batch is decode. - DECODE = 1 - # Batch is a mixture of prefill and decode. - MIXED = 2 - - class ModelRunner: def __init__( @@ -216,10 +186,22 @@ def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size - def _prepare_prompt( + def _prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> PreparePromptMetadata: + ) -> ModelInput: + """Prepare the model input based on a given sequence group. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -228,212 +210,16 @@ def _prepare_prompt( lora_requests: Set[LoRARequest] = set() seq_lens: List[int] = [] + prefill_seq_lens: List[int] = [] + decode_seq_lens: List[int] = [] context_lens: List[int] = [] query_lens: List[int] = [] - prefix_block_tables: List[List[int]] = [] - multi_modal_input_list: List[torch.Tensor] = [] - - if len(seq_group_metadata_list) == 0: - return PreparePromptMetadata.empty() - - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - computed_block_nums = seq_group_metadata.computed_block_nums - if (self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and not (computed_block_nums is None - or computed_block_nums == [])): - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") - - token_chunk_size = seq_group_metadata.token_chunk_size - seq_data = seq_group_metadata.seq_data[seq_id] - context_len = seq_data.get_num_computed_tokens() - # We should use get_len here because in case of preemption - # it contains output tokens. - seq_len = min(seq_data.get_len(), context_len + token_chunk_size) - prompt_tokens = seq_data.get_token_ids()[context_len:seq_len] - seq_lens.append(seq_len) - - # NOTE: This only works for oooooooxxx style attention. - if computed_block_nums is not None and len( - computed_block_nums) > 0 and self.sliding_window is None: - # Prefix is not supported with sliding_window - context_len = len(computed_block_nums) * self.block_size - prompt_tokens = prompt_tokens[context_len:] - prefix_block_tables.append(computed_block_nums) - elif self.scheduler_config.chunked_prefill_enabled: - if seq_group_metadata.block_tables is not None: - # Prefill has chunked before. - block_table = seq_group_metadata.block_tables[seq_id] - prefix_block_tables.append(block_table) - else: - # The first prefill. - prefix_block_tables.append([]) - else: - prefix_block_tables.append([]) - # Right now, prefill start is always 0. However, this - # assumption can be changed once chunked prefill is introduced. - assert context_len == 0 - - # actual prompt lens - context_lens.append(context_len) - query_lens.append(seq_len - context_len) - - input_tokens.extend(prompt_tokens) - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. - input_positions.extend(list(range(context_len, seq_len))) - lora_id = seq_group_metadata.lora_int_id - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * (seq_len - context_len) - lora_prompt_mapping.extend([lora_id] * ( - seq_len - context_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs else 1)) - - if seq_group_metadata.multi_modal_data: - multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) - - if _is_block_tables_empty(seq_group_metadata.block_tables): - # During memory profiling, the block tables are not initialized - # yet. In this case, we just use a dummy slot mapping. - # In embeddings, the block tables are {seq_id: None}. - slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - - # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seq_len - sliding_window). - # For example, if the prompt len is 10, sliding window is 8, and - # block size is 4, the first two tokens are masked and the slot - # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - assert context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention") - start_idx = max(0, seq_len - self.sliding_window) - - for i in range(context_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - max_query_len = max(query_lens) - max_seq_len = max(seq_lens) - assert max_query_len > 0 - - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - - if multi_modal_input_list: - assert self.vision_language_config, ( - "Multi-modal inputs are only supported by " - "vision language models.") - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(self.device) - else: - multi_modal_input = None - - # Prepare prefix block tables - max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) - block_tables = make_tensor_with_pad( - prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) - - # Query length can be shorter than key (i.e., prompt) when prefill - # is chunked or prefix cached. - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - torch.cumsum(query_lens_tensor, - dim=0, - dtype=subquery_start_loc.dtype, - out=subquery_start_loc[1:]) - - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - - if self.attn_backend.get_name() == "flashinfer": - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - use_cuda_graph=False, - seq_start_loc=seq_start_loc, - max_seq_len=max_seq_len, - block_tables=block_tables) - else: - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_seq_len=max_seq_len, - subquery_start_loc=subquery_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - ) - - return PreparePromptMetadata( - input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_index_mapping=lora_index_mapping, - lora_prompt_mapping=lora_prompt_mapping, - lora_requests=lora_requests, - multi_modal_input=multi_modal_input, - slot_mapping=slot_mapping, - ) - - def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> PrepareDecodeMetadata: - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - seq_lens: List[int] = [] block_tables: List[List[int]] = [] - lora_index_mapping: List[int] = [] - lora_prompt_mapping: List[int] = [] - lora_requests: Set[LoRARequest] = set() + multi_modal_input_list: List[torch.Tensor] = [] + decode_only = True + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = 0 # The following fields are only for flashinfer # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout @@ -454,60 +240,186 @@ def _prepare_decode( paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: - return PrepareDecodeMetadata.empty() + return ModelInput.empty(self.device) for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - assert seq_group_metadata.token_chunk_size == 1 - seq_ids = list(seq_group_metadata.seq_data.keys()) - lora_id = seq_group_metadata.lora_int_id - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) + is_prompt = seq_group_metadata.is_prompt for seq_id in seq_ids: + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append(generation_token) + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_data.get_len() - 1 + + seq_len = min( + seq_data.get_len(), + context_len + seq_group_metadata.token_chunk_size) + if is_prompt: + tokens = seq_data.get_token_ids()[context_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + + # Prefix cache was hit. + # Prefix is not supported with sliding_window + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and is_prompt) + + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + context_len = len(computed_block_nums) * self.block_size + tokens = tokens[context_len:] + if self.attn_backend.get_name() == "flash-attn": + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + # TODO(woosuk): This is a temporary fix. We should + # provide a unified interface for different backends. + block_table = seq_group_metadata.block_tables[seq_id] + else: + block_table = computed_block_nums + elif (self.scheduler_config.chunked_prefill_enabled + or not is_prompt): + if seq_group_metadata.block_tables is not None: + # chunked prefill or decode + block_table = seq_group_metadata.block_tables[seq_id] + if self.sliding_window is not None: + # chunked prefill doesn't support sliding window. + assert (not self.scheduler_config. + chunked_prefill_enabled) + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + + if self.attn_backend.get_name() == "flashinfer": + paged_kv_indices.extend(block_table) + paged_kv_indptr.append(paged_kv_indptr[-1] + + len(block_table)) + last_page_len = seq_data.get_len( + ) % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + paged_kv_last_page_len.append(last_page_len) + else: + # Only happens when memory profiling runs. + block_table = [] + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + block_tables.append(block_table) - seq_len = seq_data.get_len() - position = seq_len - 1 - input_positions.append(position) + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if (self.sliding_window is not None and not is_prompt): + seq_len = min(seq_len, self.sliding_window) + context_len = seq_len - 1 - seq_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) seq_lens.append(seq_len) + context_lens.append(context_len) + query_len = seq_len - context_len + query_lens.append(query_len) + input_tokens.extend(tokens) + input_positions.extend(list(range(context_len, seq_len))) + lora_id = seq_group_metadata.lora_int_id + + if is_prompt: + assert len(seq_ids) == 1 + num_prefills += 1 + num_prefill_tokens += len(tokens) + decode_only = False + prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + num_decode_tokens += query_len + decode_seq_lens.append(seq_len) + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * (seq_len - context_len) + lora_prompt_mapping.extend( + [lora_id] * + (seq_len - + context_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + + if seq_group_metadata.multi_modal_data: + multi_modal_input_list.append( + seq_group_metadata.multi_modal_data.data) + + if _is_block_tables_empty(seq_group_metadata.block_tables): + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - lora_index_mapping.append(lora_id) - lora_prompt_mapping.append(lora_id) + # Mask the [0, start_idx) tokens of the prompt with + # _PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - block_tables.append(block_table) + if is_prompt: + assert context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention") + # It is an optimization. When it is decoding, it is always + # 0. When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + for i in range(context_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) - paged_kv_indices.extend(block_table) - paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table)) - last_page_len = seq_data.get_len() % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - paged_kv_last_page_len.append(last_page_len) + batch_size = len(input_tokens) + max_query_len = max(query_lens) + max_prefill_seq_len = max(prefill_seq_lens, default=0) + max_decode_seq_len = max(decode_seq_lens, default=0) - # vLLM uses cuda graph only for decoding requests. + # If cuda graph can be used, pad tensors accordingly. # See `capture_model` API for more details. - # For decoding requests, batch_size == input_tokens. - batch_size = len(input_tokens) - max_seq_len = max(seq_lens) - use_captured_graph = (not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_seq_len <= self.max_seq_len_to_capture) + # vLLM uses cuda graph only for decoding requests. + use_captured_graph = ( + decode_only and not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_decode_seq_len <= self.max_seq_len_to_capture) if use_captured_graph: graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size @@ -519,18 +431,9 @@ def _prepare_decode( block_tables.append([]) lora_index_mapping.append(0) batch_size = graph_batch_size - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) + num_decode_tokens = batch_size if use_captured_graph: - # When using cuda-graph all these tensors should be - # padded. - assert seq_lens_tensor.shape[0] == len(input_tokens) - assert seq_lens_tensor.shape[0] == len(input_positions) - assert seq_lens_tensor.shape[0] == len(slot_mapping) - # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.graph_block_tables[:batch_size] @@ -548,6 +451,57 @@ def _prepare_decode( dtype=torch.int, device=self.device, ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + + if multi_modal_input_list: + assert self.vision_language_config, ( + "Multi-modal inputs are only supported by " + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(self.device) + else: + multi_modal_input = None + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) if self.attn_backend.get_name() == "flashinfer": if not hasattr(self, "flashinfer_workspace_buffer"): @@ -555,53 +509,75 @@ def _prepare_decode( # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html self.flashinfer_workspace_buffer = torch.empty( 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) - paged_kv_indptr = torch.tensor(paged_kv_indptr, - dtype=torch.int, - device=self.device) - paged_kv_indices = torch.tensor(paged_kv_indices, - dtype=torch.int, - device=self.device) - paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len, + paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, dtype=torch.int, device=self.device) + paged_kv_indices_tensor = torch.tensor(paged_kv_indices, + dtype=torch.int, + device=self.device) + paged_kv_last_page_len_tensor = torch.tensor( + paged_kv_last_page_len, dtype=torch.int, device=self.device) kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, self.model_config.dtype) - attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, use_cuda_graph=False, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, workspace_buffer=self.flashinfer_workspace_buffer, - paged_kv_indptr=paged_kv_indptr, - paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len=paged_kv_last_page_len, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, num_qo_heads=self.model_config.get_num_attention_heads( self.parallel_config), num_kv_heads=self.model_config.get_num_kv_heads( self.parallel_config), head_dim=self.model_config.get_head_size(), - page_size=self.block_size, + page_size=16, + seq_start_loc=seq_start_loc, data_type=kv_cache_dtype) else: attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, - seq_lens=None, + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_query_len=None, - max_seq_len=max_seq_len, - subquery_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, ) - return PrepareDecodeMetadata( - input_tokens=input_tokens, - input_positions=input_positions, + + if self.lora_config: + lora_mapping = LoRAMapping( + lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + + return ModelInput( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, attn_metadata=attn_metadata, - lora_index_mapping=lora_index_mapping, - lora_prompt_mapping=lora_prompt_mapping, + seq_lens=seq_lens, + query_lens=query_lens, + lora_mapping=lora_mapping, lora_requests=lora_requests, - slot_mapping=slot_mapping, + multi_modal_input=multi_modal_input, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, ) def prepare_input_tensors( @@ -610,85 +586,25 @@ def prepare_input_tensors( ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: - prefill_reqs = [] - decode_reqs = [] - for seq_group_meta in seq_group_metadata_list: - if seq_group_meta.is_prompt: - prefill_reqs.append(seq_group_meta) - else: - decode_reqs.append(seq_group_meta) - # Prepare input tensors. ( input_tokens, input_positions, - prefill_attn_metadata, + attn_metadata, seq_lens, query_lens, - lora_index_mapping, - lora_prompt_mapping, + lora_mapping, lora_requests, multi_modal_input, slot_mapping, - ) = self._prepare_prompt(prefill_reqs) - ( - decode_input_tokens, - decode_input_positions, - decode_attn_metadata, - decode_lora_index_mapping, - decode_lora_prompt_mapping, - decode_lora_requests, - decode_slot_mapping, - ) = self._prepare_decode(decode_reqs) + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory) - if not self.scheduler_config.chunked_prefill_enabled: - assert (len(prefill_reqs) and len(decode_reqs)) == 0 - - num_prefills = len(seq_lens) - num_prefill_tokens = len(input_tokens) - num_decode_tokens = len(decode_input_tokens) - - # Coalesce tensors. Note that attn_metadata is currently not - # coalesced for simplicity. - input_tokens.extend(decode_input_tokens) - input_positions.extend(decode_input_positions) - slot_mapping.extend(decode_slot_mapping) - lora_index_mapping.extend(decode_lora_index_mapping) - lora_prompt_mapping.extend(decode_lora_prompt_mapping) - lora_requests.update(decode_lora_requests) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None - - # Broadcast the metadata. - # If batch contains both prefill and decode, it sends 2 broadcasts. - # If it only contains 1 type, it triggers a single broadcast. - if (prefill_attn_metadata is not None - and decode_attn_metadata is not None): - batch_type = BatchType.MIXED - elif prefill_attn_metadata is not None: - batch_type = BatchType.PREFILL - else: - batch_type = BatchType.DECODE - metadata_dict = { "input_tokens": input_tokens, "input_positions": input_positions, @@ -701,46 +617,24 @@ def prepare_input_tensors( "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, - "batch_type": batch_type, } - if prefill_attn_metadata is not None: - metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) - else: - assert decode_attn_metadata is not None - metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) + if attn_metadata: + metadata_dict.update(attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) - - # Broadcast decode attn metadata for mixed batch type. - # The additional broadcast costs 300us overhead on 4 A10 GPUs. - # We can potentially reduce the overhead by coelescing tensors. - if batch_type == BatchType.MIXED: - assert decode_attn_metadata is not None - metadata_dict = decode_attn_metadata.asdict_zerocopy() - broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") - slot_mapping = metadata_dict.pop("slot_mapping") - num_prefills = metadata_dict.pop("num_prefills") selected_token_indices = metadata_dict.pop( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") - num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") - num_decode_tokens = metadata_dict.pop("num_decode_tokens") - batch_type = metadata_dict.pop("batch_type") - - # Create an attention metadata. - prefill_attn_metadata = None - decode_attn_metadata = None - if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: - prefill_attn_metadata = self.attn_backend.make_metadata( + if metadata_dict: + attn_metadata = self.attn_backend.make_metadata( **metadata_dict) else: - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) + attn_metadata = None sampling_metadata = SamplingMetadata( seq_groups=None, selected_token_indices=selected_token_indices, @@ -748,22 +642,6 @@ def prepare_input_tensors( num_prompts=0, ) - # if it is a mixed batch, decode attn_metadata is broadcasted - # separately. - if batch_type == BatchType.MIXED: - metadata_dict = broadcast_tensor_dict(src=0) - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - - attn_metadata = AttentionMetadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - prefill_metadata=prefill_attn_metadata, - decode_metadata=decode_attn_metadata, - ) - return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_input) @@ -954,26 +832,22 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): # Create dummy attn_metadata. - decode_metadata = self.attn_backend.make_metadata( - is_prompt=False, + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], seq_lens=None, seq_lens_tensor=seq_lens[:batch_size], max_query_len=None, - max_seq_len=self.max_seq_len_to_capture, - subquery_start_loc=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_seq_len_to_capture, + query_start_loc=None, seq_start_loc=None, context_lens_tensor=None, block_tables=block_tables[:batch_size], use_cuda_graph=True, ) - attn_metadata = AttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping[:batch_size], - prefill_metadata=None, - decode_metadata=decode_metadata, - ) if self.lora_config: lora_mapping = LoRAMapping(