[Core][5/N] Fully working chunked prefill e2e (vllm-project#3884)
Showing 26 changed files with 927 additions and 315 deletions.
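For context, chunked prefill is an opt-in engine feature. A minimal sketch of turning it on through the offline LLM API is shown below; the keyword arguments mirror the vllm_runner arguments exercised by the tests in this commit, but the model choice and token limits are illustrative assumptions rather than recommendations.

# Minimal sketch (assumed usage): enable chunked prefill in the offline API.
# The kwargs mirror the engine arguments used by the tests below; the concrete
# values here are illustrative only.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,   # schedule prefills in chunks
    max_num_batched_tokens=16,     # prompts longer than this are prefilled across steps
    max_num_seqs=16,               # cap on concurrently scheduled sequences
)
params = SamplingParams(temperature=0.0, max_tokens=32)  # greedy, as in the tests
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)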
New file: test_chunked_prefill.py (@@ -0,0 +1,70 @@)
"""Compare the outputs of HF and vLLM when using greedy sampling. | ||
It tests chunked prefill. Chunked prefill can be enabled by | ||
enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens, | ||
prefill requests are chunked. | ||
Run `pytest tests/models/test_chunked_prefill.py`. | ||
""" | ||
import pytest | ||
|
||
MODELS = [ | ||
"facebook/opt-125m", | ||
"meta-llama/Llama-2-7b-hf", | ||
] | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["half"]) | ||
@pytest.mark.parametrize("max_tokens", [32]) | ||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) | ||
@pytest.mark.parametrize("enforce_eager", [False, True]) | ||
# NOTE: Increasing this in this suite will fail CI because we currently cannot | ||
# reset distributed env properly. Use a value > 1 just when you test. | ||
@pytest.mark.parametrize("tensor_parallel_size", [1]) | ||
def test_models( | ||
hf_runner, | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
max_tokens: int, | ||
chunked_prefill_token_size: int, | ||
enforce_eager: bool, | ||
tensor_parallel_size: int, | ||
) -> None: | ||
if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16 | ||
and not enforce_eager): | ||
pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} " | ||
"for high TP to save testing time.") | ||
max_num_seqs = min(chunked_prefill_token_size, 256) | ||
enable_chunked_prefill = False | ||
max_num_batched_tokens = None | ||
if chunked_prefill_token_size != -1: | ||
enable_chunked_prefill = True | ||
max_num_batched_tokens = chunked_prefill_token_size | ||
|
||
hf_model = hf_runner(model, dtype=dtype) | ||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) | ||
del hf_model | ||
|
||
vllm_model = vllm_runner( | ||
model, | ||
dtype=dtype, | ||
max_num_batched_tokens=max_num_batched_tokens, | ||
enable_chunked_prefill=enable_chunked_prefill, | ||
tensor_parallel_size=tensor_parallel_size, | ||
enforce_eager=enforce_eager, | ||
max_num_seqs=max_num_seqs, | ||
) | ||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) | ||
del vllm_model | ||
print(vllm_outputs[0]) | ||
|
||
for i in range(len(example_prompts)): | ||
hf_output_ids, hf_output_str = hf_outputs[i] | ||
vllm_output_ids, vllm_output_str = vllm_outputs[i] | ||
assert hf_output_str == vllm_output_str, ( | ||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") | ||
assert hf_output_ids == vllm_output_ids, ( | ||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") |
New file: test_chunked_prefill_distributed.py (@@ -0,0 +1,66 @@)
"""Compare the outputs of HF and distributed vLLM when using greedy sampling. | ||
vLLM will allocate all the available memory, so we need to run the tests one | ||
by one. The solution is to pass arguments (model name) by environment | ||
variables. | ||
Run: | ||
```sh | ||
TEST_DIST_MODEL=facebook/opt-125m pytest \ | ||
test_chunked_prefill_distributed.py | ||
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ | ||
test_chunked_prefill_distributed.py | ||
``` | ||
""" | ||
import os | ||
|
||
import pytest | ||
import torch | ||
|
||
MODELS = [ | ||
os.environ["TEST_DIST_MODEL"], | ||
] | ||
|
||
|
||
@pytest.mark.skipif(torch.cuda.device_count() < 2, | ||
reason="Need at least 2 GPUs to run the test.") | ||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["half"]) | ||
@pytest.mark.parametrize("max_tokens", [5]) | ||
@pytest.mark.parametrize("chunked_prefill_token_size", [16]) | ||
def test_models( | ||
hf_runner, | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
max_tokens: int, | ||
chunked_prefill_token_size: int, | ||
) -> None: | ||
# Add a chunked prefill config. | ||
max_num_seqs = min(chunked_prefill_token_size, 256) | ||
assert chunked_prefill_token_size != -1 | ||
enable_chunked_prefill = True | ||
max_num_batched_tokens = chunked_prefill_token_size | ||
|
||
hf_model = hf_runner(model, dtype=dtype) | ||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) | ||
del hf_model | ||
|
||
vllm_model = vllm_runner( | ||
model, | ||
dtype=dtype, | ||
tensor_parallel_size=2, | ||
max_num_seqs=max_num_seqs, | ||
enable_chunked_prefill=enable_chunked_prefill, | ||
max_num_batched_tokens=max_num_batched_tokens, | ||
) | ||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) | ||
del vllm_model | ||
|
||
for i in range(len(example_prompts)): | ||
hf_output_ids, hf_output_str = hf_outputs[i] | ||
vllm_output_ids, vllm_output_str = vllm_outputs[i] | ||
assert hf_output_str == vllm_output_str, ( | ||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") | ||
assert hf_output_ids == vllm_output_ids, ( | ||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") |