[Core][5/N] Fully working chunked prefill e2e #3884

Merged
16 commits merged on Apr 11, 2024
62 changes: 38 additions & 24 deletions benchmarks/benchmark_throughput.py
@@ -74,25 +74,31 @@ def run_vllm(
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
gpu_memory_utilization: float = 0.9,
download_dir: Optional[str] = None,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir)
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
)

# Add the requests to the engine.
for prompt, _, output_len in requests:
@@ -213,15 +219,15 @@ def main(args: argparse.Namespace):
args.output_len)

if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching,
args.gpu_memory_utilization, args.download_dir)
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.gpu_memory_utilization,
args.download_dir)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -335,6 +341,14 @@ def main(args: argparse.Namespace):
"--enable-prefix-caching",
action='store_true',
help="enable automatic prefix caching for vLLM backend.")
parser.add_argument("--enable-chunked-prefill",
action='store_true',
help="enable chunked prefill for vLLM backend.")
parser.add_argument('--max-num-batched-tokens',
type=int,
default=None,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--download-dir',
type=str,
default=None,
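The new --enable-chunked-prefill and --max-num-batched-tokens flags map directly onto the LLM constructor arguments added above. As a rough illustration (not part of this PR), here is a minimal sketch of enabling chunked prefill through the Python API; the model name and token budget below are arbitrary placeholders:

from vllm import LLM, SamplingParams

# Placeholder model and budget; any supported model and positive token
# budget could be substituted here.
llm = LLM(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,   # allow prefills to be split into chunks
    max_num_batched_tokens=4096,   # per-iteration token budget
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
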
68 changes: 68 additions & 0 deletions tests/basic_correctness/test_chunked_prefill.py
@@ -0,0 +1,68 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.

It tests chunked prefill. Chunked prefill is enabled by passing
enable_chunked_prefill=True. If a prompt's prefill size exceeds
max_num_batched_tokens, the prefill request is split into chunks that are
processed over multiple iterations.

Run `pytest tests/basic_correctness/test_chunked_prefill.py`.
"""
import pytest

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False, True])
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
) -> None:
if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16
and not enforce_eager):
pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
"for high TP to save testing time.")
# To pass the small model tests, we need full precision.
# assert dtype == "float"
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_batched_tokens = chunked_prefill_token_size

hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model

vllm_model = vllm_runner(
model,
dtype=dtype,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
print(vllm_outputs[0])

for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
1 change: 1 addition & 0 deletions tests/distributed/test_basic_distributed_correctness.py
@@ -33,6 +33,7 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:

hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
60 changes: 50 additions & 10 deletions tests/worker/test_model_runner.py
@@ -36,8 +36,10 @@ def test_prepare_prompt(batch_size):
prompt_len - 1)
selected_token_start_idx += prompt_len
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
_, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
assert len(slot_mapping) == len(input_tokens)

# Verify input metadata is correct for prompts.
device = model_runner.device
@@ -85,21 +87,21 @@ def test_prepare_prompt(batch_size):
assert attn_metadata.use_cuda_graph is False
assert attn_metadata.kv_cache_dtype == "auto"

assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
assert len(input_tokens) == sum(prompt_lens)
assert len(input_positions) == sum(prompt_lens)
torch.testing.assert_close(input_tokens, input_positions)

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
assert len(input_tokens) == sum(prompt_lens)
assert len(input_positions) == sum(prompt_lens)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
torch.testing.assert_close(input_tokens, input_positions)
assert input_tokens == input_positions

actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
@@ -143,8 +145,9 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)

input_tokens, input_positions, attn_metadata, _, _, _ = (
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
assert len(slot_mapping) == len(input_tokens)

expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
@@ -172,9 +175,9 @@ def test_prepare_decode_cuda_graph(batch_size):
assert attn_metadata.use_cuda_graph is True
assert attn_metadata.kv_cache_dtype == "auto"

assert input_tokens.shape == (expected_bs, )
assert input_positions.shape == (expected_bs, )
torch.testing.assert_close(input_tokens, input_positions)
assert len(input_tokens) == expected_bs
assert len(input_positions) == expected_bs
assert input_tokens == input_positions

# Verify Sampling
expected_selected_token_indices = []
@@ -190,3 +193,40 @@
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)


def test_empty_seq_group():
"""Verify prepare prompt and decode returns empty output."""
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
enforce_eager=False,
)
model_runner = ModelRunner(model_config, None, None, None, None)
model_runner.set_block_size(16)
seq_group_metadata_list = []
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0

(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
assert len(return_prompt_lens) == 0


# SANG-TODO Test chunked prefill case.
Collaborator (Author) commented: This one is WIP.

4 changes: 3 additions & 1 deletion vllm/attention/__init__.py
@@ -1,5 +1,6 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

@@ -8,4 +9,5 @@
"AttentionMetadata",
"Attention",
"get_attn_backend",
"AttentionMetadataPerStage",
]
30 changes: 27 additions & 3 deletions vllm/attention/backends/abstract.py
Collaborator commented: @WoosukKwon @zhuohan123 would love your feedback on this interface change, where the new AttentionMetadata can contain both prefill- and decode-stage metadata. I think this is a necessary change, but I would like to hear your thoughts on the interface design.

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar

import torch

@@ -47,7 +47,8 @@ def copy_blocks(


@dataclass
class AttentionMetadata:
class AttentionMetadataPerStage:
"""Attention metadata for a specific stage. I.e., prefill or decode."""

def asdict_zerocopy(self) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
@@ -59,6 +60,29 @@ def asdict_zerocopy(self) -> Dict[str, Any]:
}


T = TypeVar("T", bound=AttentionMetadataPerStage)


@dataclass
class AttentionMetadata(Generic[T]):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. The block index is slot // block_size and the offset within
# the block is slot % block_size. E.g., if `slot_mapping` is [35, 2, 17] and
# the block size is 16, the three tokens are stored at offset 3 of block 2,
# offset 2 of block 0, and offset 1 of block 1, respectively.
slot_mapping: torch.Tensor
kv_cache_dtype: str
Collaborator commented: Group these two separately from the prefill/decode metadata. I guess it makes sense here as general arguments rather than per-stage ones.

prefill_metadata: Optional[T]
decode_metadata: Optional[T]


class AttentionImpl(ABC):

@abstractmethod
@@ -80,7 +104,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
kv_scale: float,
) -> torch.Tensor:
raise NotImplementedError
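
For illustration, a rough sketch of how a concrete backend's forward pass might consume the combined metadata, assuming the packed query tensor places all prefill tokens before decode tokens; _attention_stub is a hypothetical stand-in for the real prefill and decode kernels:

import torch

def _attention_stub(query: torch.Tensor, stage_metadata) -> torch.Tensor:
    # Hypothetical stand-in for a real prefill or decode attention kernel.
    return query

def forward_sketch(query: torch.Tensor, attn_metadata) -> torch.Tensor:
    # Assumption: prefill tokens come first in the packed batch, so
    # num_prefill_tokens splits the query into its two stages.
    prefill_query = query[:attn_metadata.num_prefill_tokens]
    decode_query = query[attn_metadata.num_prefill_tokens:]

    outputs = []
    if attn_metadata.prefill_metadata is not None:
        outputs.append(_attention_stub(prefill_query,
                                       attn_metadata.prefill_metadata))
    if attn_metadata.decode_metadata is not None:
        outputs.append(_attention_stub(decode_query,
                                       attn_metadata.decode_metadata))
    return torch.cat(outputs) if outputs else query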