Skip to content

Commit

Permalink
[Core][2/N] Model runner refactoring part 2. Combine prepare prefill …
Browse files Browse the repository at this point in the history
…/ decode to a single API (vllm-project#4681)

This PR combines prepare_prompt and prepare_decode into a single API. This PR also coelsce the attn metadata for prefill/decode to a single class and allow to slice them when running attn backend.

It also refactors subquery_start_loc which was not refactored in the previous PR
  • Loading branch information
rkooo567 authored and tjohnson31415 committed May 16, 2024
1 parent a69f3af commit 71cd938
Show file tree
Hide file tree
Showing 18 changed files with 777 additions and 730 deletions.
123 changes: 84 additions & 39 deletions tests/worker/test_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,25 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices.append(selected_token_start_idx +
seq_len - 1)
selected_token_start_idx += seq_len
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = model_input.slot_mapping
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)

# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.is_prompt is True
assert attn_metadata.num_prefills > 0
assert attn_metadata.num_decode_tokens == 0
assert torch.allclose(
attn_metadata.seq_lens_tensor,
torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.max_seq_len == max(seq_lens)
assert attn_metadata.max_prefill_seq_len == max(seq_lens)
assert attn_metadata.max_decode_seq_len == 0

# Test subquery start locs.
start_idx = 0
Expand All @@ -79,11 +85,11 @@ def test_prepare_prompt(batch_size):
start_idx += seq_len
start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.subquery_start_loc,
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))

# Test seq start locs. Note that for normal prefill it is
# equivalent to subquery_start_loc.
# equivalent to query_start_loc.
start_idx = 0
seq_start_loc = [start_idx]
for seq_len in seq_lens:
Expand Down Expand Up @@ -123,7 +129,7 @@ def test_prepare_prompt(batch_size):
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
assert input_tokens == input_positions
torch.allclose(input_tokens, input_positions)

actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
Expand All @@ -144,14 +150,18 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill=False,
)

seq_lens = []
context_lens = []
seq_group_metadata_list = []
# Assume each seq group finishes prefill.
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = list(range(seq_len))
context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len)
seq_data = list(range(context_len))
seq_data = SequenceData(seq_data)
seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
Expand All @@ -162,18 +172,45 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)

input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens, model_input.input_positions,
model_input.attn_metadata, model_input.slot_mapping)
assert len(slot_mapping) == len(input_tokens)

expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.is_prompt is False
assert attn_metadata.seq_lens is None
assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None
assert attn_metadata.max_seq_len == max(seq_lens)
assert attn_metadata.num_prefills == 0
assert attn_metadata.num_prefill_tokens == 0
seq_lens = [context_len + 1 for context_len in context_lens]
# seq_lens are padded to expected_bs
for _ in range(expected_bs - len(seq_lens)):
seq_lens.append(1)
assert attn_metadata.seq_lens == seq_lens
start_idx = 0
start_loc = [start_idx]
for _ in context_lens:
# decode has only 1 token for query.
start_idx += 1
start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))

start_idx = 0
seq_start_loc = [start_idx]
for seq_len in seq_lens:
start_idx += seq_len
seq_start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.seq_start_loc,
torch.tensor(seq_start_loc, dtype=torch.int32, device=device))

assert torch.allclose(
attn_metadata.context_lens_tensor,
torch.tensor(context_lens, dtype=torch.int, device=device))
assert attn_metadata.max_decode_seq_len == max(seq_lens)
assert torch.allclose(
attn_metadata.seq_lens_tensor[:len(seq_lens)],
torch.tensor(seq_lens, dtype=torch.int, device=device))
Expand All @@ -185,23 +222,23 @@ def test_prepare_decode_cuda_graph(batch_size):
# It is padded up to
assert attn_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch())
# Cuda graph should not be used for prerill.
assert attn_metadata.use_cuda_graph is True

assert len(input_tokens) == expected_bs
assert len(input_positions) == expected_bs
assert input_tokens == input_positions
torch.allclose(input_tokens, input_positions)

# Verify Sampling
expected_selected_token_indices = []
selected_token_start_idx = 0
for seq_len in seq_lens:
for _ in context_lens:
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
# query lens is all 1 for decode.
query_lens=[1 for _ in range(len(context_lens))],
device=model_runner.device,
pin_memory=model_runner.pin_memory)
actual = sampling_metadata.selected_token_indices
Expand All @@ -220,15 +257,27 @@ def test_empty_seq_group():
enforce_eager=False,
)
seq_group_metadata_list = []
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0

(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, slot_mapping,
return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
model_input.seq_lens,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
Expand Down Expand Up @@ -285,9 +334,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# Add decode requests
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(seq_len))
context_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(context_len))
seq_data = SequenceData(prompt_toks)
seq_data.append_token_id(1, 0)
seq_data.update_num_computed_tokens(context_len)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
Expand All @@ -308,23 +359,17 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
assert len(attn_metadata.slot_mapping) == len(input_tokens)
assert len(input_positions) == len(input_tokens)
assert attn_metadata.num_prefills == prefill_batch_size
if enforce_eager:
assert attn_metadata.num_decode_tokens == decode_batch_size
else:
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
decode_batch_size)
assert attn_metadata.num_decode_tokens == decode_batch_size
assert attn_metadata.num_prefill_tokens == sum(seq_lens)

# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
prefill_meta = model_runner._prepare_prompt(
prefill_metadata_list).attn_metadata
decode_meta = model_runner._prepare_decode(
decode_metadata_list).attn_metadata
attn_metadata = model_runner._prepare_model_input(
seq_group_metadata_list).attn_metadata

for attr_expected, attr_actual in zip(vars(prefill_meta),
for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
vars(prefill_meta_actual)):
assert attr_expected[1] == attr_actual[1]
for attr_expected, attr_actual in zip(vars(decode_meta),
for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
vars(decode_meta_actual)):
assert attr_expected[1] == attr_actual[1]
5 changes: 2 additions & 3 deletions vllm/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

__all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"AttentionMetadataPerStage",
"Attention",
"get_attn_backend",
]
68 changes: 32 additions & 36 deletions vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_impl_cls() -> Type["AttentionImpl"]:

@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
raise NotImplementedError

@staticmethod
Expand Down Expand Up @@ -53,8 +53,34 @@ def copy_blocks(


@dataclass
class AttentionMetadataPerStage:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
class AttentionMetadata:
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor

@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run prefill
attention."""
pass

@property
@abstractmethod
def decode_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run decode
attention."""
pass

def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
Expand All @@ -70,40 +96,10 @@ def asdict_zerocopy(self,
}


T = TypeVar("T", bound=AttentionMetadataPerStage)


@dataclass
class AttentionMetadata(Generic[T]):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# The attention metadata for prefill requests in a batch.
# None if there's no prefill requests in a batch.
prefill_metadata: Optional[T]
# The attention metadata for decode requests in a batch.
# None if there's no decode requests in a batch.
decode_metadata: Optional[T]
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor

def __post_init__(self):
if self.num_prefill_tokens > 0:
assert self.num_prefills > 0
assert self.prefill_metadata is not None
if self.num_decode_tokens > 0:
assert self.decode_metadata is not None
T = TypeVar("T", bound=AttentionMetadata)


class AttentionImpl(ABC):
class AttentionImpl(ABC, Generic[T]):

@abstractmethod
def __init__(
Expand All @@ -125,7 +121,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_metadata: T,
kv_scale: float = 1.0,
) -> torch.Tensor:
raise NotImplementedError
Loading

0 comments on commit 71cd938

Please sign in to comment.