[Core][5/N] Fully working chunked prefill e2e #3884
tests/models/test_chunked_prefill.py (new file)
@@ -0,0 +1,68 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.

It tests chunked prefill. Chunked prefill can be enabled by
enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens,
prefill requests are chunked.

Run `pytest tests/models/test_chunked_prefill.py`.
"""
import pytest

MODELS = [
    "facebook/opt-125m",
    "meta-llama/Llama-2-7b-hf",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False, True])
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    chunked_prefill_token_size: int,
    enforce_eager: bool,
    tensor_parallel_size: int,
) -> None:
    if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16
            and not enforce_eager):
        pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
                    "for high TP to save testing time.")
    # To pass the small model tests, we need full precision.
    # assert dtype == "float"
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_batched_tokens = chunked_prefill_token_size

    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
    del hf_model

    vllm_model = vllm_runner(
        model,
        dtype=dtype,
        max_num_batched_tokens=max_num_batched_tokens,
        enable_chunked_prefill=enable_chunked_prefill,
        tensor_parallel_size=tensor_parallel_size,
        enforce_eager=enforce_eager,
    )
    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    del vllm_model
    print(vllm_outputs[0])

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
Review comment: @WoosukKwon @zhuohan123 would love your feedback on this interface change where the new …
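For context on what this test exercises, here is a minimal sketch of how chunked prefill could be enabled through the public LLM entrypoint. The assumption (not shown in this diff) is that enable_chunked_prefill and max_num_batched_tokens are forwarded to the engine arguments the same way vllm_runner forwards them in the test above:

from vllm import LLM, SamplingParams

# Assumed wiring: these keyword arguments reach the engine config, mirroring
# the vllm_runner(...) call in the test above.
llm = LLM(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,   # split long prefills across scheduler steps
    max_num_batched_tokens=2048,   # per-step token budget shared by prefill and decode
)

# Greedy sampling, matching the test's generate_greedy helper.
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)

With this configuration, a prompt whose prefill exceeds the 2048-token budget would be processed over multiple steps instead of a single one.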
Second changed file (the attention backend abstraction):

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar

 import torch

@@ -47,7 +47,8 @@ def copy_blocks(


 @dataclass
-class AttentionMetadata:
+class AttentionMetadataPerStage:
+    """Attention metadata for a specific stage. I.e., prefill or decode."""

     def asdict_zerocopy(self) -> Dict[str, Any]:
         """Similar to dataclasses.asdict, but avoids deepcopying."""
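The body of asdict_zerocopy is elided by the hunk above. As a point of reference, a minimal sketch of a zero-copy asdict (an assumption about the intent, not code taken from this diff) simply hands back references to the field values instead of deep-copying them:

from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class ExampleMetadata:
    """Stand-in dataclass used only to show the pattern."""
    num_prefill_tokens: int
    kv_cache_dtype: str

    def asdict_zerocopy(self) -> Dict[str, Any]:
        # Unlike dataclasses.asdict, this does not recurse or deep-copy,
        # so large tensors stored in fields are not duplicated.
        return {f.name: getattr(self, f.name) for f in fields(self)}


meta = ExampleMetadata(num_prefill_tokens=2, kv_cache_dtype="auto")
assert meta.asdict_zerocopy() == {"num_prefill_tokens": 2, "kv_cache_dtype": "auto"}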
@@ -59,6 +60,29 @@ def asdict_zerocopy(self) -> Dict[str, Any]:
         }


+T = TypeVar("T", bound=AttentionMetadataPerStage)
+
+
+@dataclass
+class AttentionMetadata(Generic[T]):
+    """Attention metadata for prefill and decode batched together."""
+    # Total number of prefill requests.
+    num_prefills: int
+    # Number of prefill tokens.
+    num_prefill_tokens: int
+    # Number of decode tokens. Note that it is equivalent to the number of
+    # decode requests.
+    num_decode_tokens: int
+    # (num_tokens,). The indices of the token slots that input tokens will be
+    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
+    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
+    # in block 0, and 1st slot in block 1, respectively.
+    slot_mapping: torch.Tensor
+    kv_cache_dtype: str
Review comment: Group these two separately from the prefill/decode fields? I guess it makes sense here as general arguments instead of the per-stage ones.
+
+    prefill_metadata: Optional[T]
+    decode_metadata: Optional[T]


 class AttentionImpl(ABC):

     @abstractmethod

@@ -80,7 +104,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
         kv_scale: float,
     ) -> torch.Tensor:
         raise NotImplementedError
Review comment: This one is WIP.
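To make the new two-level metadata concrete, here is a small, self-contained sketch. It mirrors the fields of the batched AttentionMetadata rather than importing it, the block size and token counts are illustrative assumptions, and it assumes prefill tokens are laid out before decode tokens in the flattened batch (which is what the num_prefill_tokens/num_decode_tokens split suggests). It also shows how the slot_mapping example from the docstring decodes into (block, offset) pairs:

import torch

# Illustrative step: one chunked prefill with 2 tokens scheduled this step,
# plus one decode request contributing 1 token.
num_prefills = 1
num_prefill_tokens = 2
num_decode_tokens = 1

block_size = 16                           # assumed block size, as in the comment above
slot_mapping = torch.tensor([35, 2, 17])  # the example values from the docstring

assert slot_mapping.numel() == num_prefill_tokens + num_decode_tokens

# Assumed layout: prefill tokens first, then decode tokens.
prefill_slots = slot_mapping[:num_prefill_tokens]  # tensor([35, 2])
decode_slots = slot_mapping[num_prefill_tokens:]   # tensor([17])

# Each slot index decodes into (block number, offset within the block).
for slot in slot_mapping.tolist():
    block_number, block_offset = divmod(slot, block_size)
    print(f"slot {slot:2d} -> block {block_number}, offset {block_offset}")
# slot 35 -> block 2, offset 3
# slot  2 -> block 0, offset 2
# slot 17 -> block 1, offset 1

A backend-specific per-stage class (a subclass of AttentionMetadataPerStage holding, for example, sequence lengths) would then be attached as prefill_metadata / decode_metadata, presumably left as None when a step contains only one of the two stages.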