Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
[Core][5/N] Fully working chunked prefill e2e (vllm-project#3884)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkooo567 authored and andy-neuma committed Apr 12, 2024
1 parent f2c4427 commit ffb2afc
Show file tree
Hide file tree
Showing 26 changed files with 927 additions and 315 deletions.
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ steps:
- pytest -v -s test_pynccl.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py

- label: Engine Test
command: pytest -v -s engine tokenization test_sequence.py test_config.py
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
help='block size of key/value cache')
parser.add_argument(
'--enable-chunked-prefill',
type=bool,
default=False,
action='store_true',
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument(
Expand Down
62 changes: 38 additions & 24 deletions benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,25 +74,31 @@ def run_vllm(
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
gpu_memory_utilization: float = 0.9,
download_dir: Optional[str] = None,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir)
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
)

# Add the requests to the engine.
for prompt, _, output_len in requests:
Expand Down Expand Up @@ -213,15 +219,15 @@ def main(args: argparse.Namespace):
args.output_len)

if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching,
args.gpu_memory_utilization, args.download_dir)
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.gpu_memory_utilization,
args.download_dir)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
Expand Down Expand Up @@ -335,6 +341,14 @@ def main(args: argparse.Namespace):
"--enable-prefix-caching",
action='store_true',
help="enable automatic prefix caching for vLLM backend.")
parser.add_argument("--enable-chunked-prefill",
action='store_true',
help="enable chunked prefill for vLLM backend.")
parser.add_argument('--max-num-batched-tokens',
type=int,
default=None,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--download-dir',
type=str,
default=None,
Expand Down
70 changes: 70 additions & 0 deletions tests/basic_correctness/test_chunked_prefill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
It tests chunked prefill. Chunked prefill can be enabled by
enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens,
prefill requests are chunked.
Run `pytest tests/basic_correctness/test_chunked_prefill.py`.
"""
import pytest

# Hugging Face model identifiers that the "model" test parameter ranges over.
MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False, True])
# NOTE: Increasing this in this suite will fail CI because the distributed
# environment currently cannot be reset properly. Use a value > 1 only when
# you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    chunked_prefill_token_size: int,
    enforce_eager: bool,
    tensor_parallel_size: int,
) -> None:
    """Check that vLLM with chunked prefill matches HF greedy decoding.

    The HF reference outputs are generated first and the HF model is dropped
    before the vLLM model is created. The same prompts are then run through
    vLLM with ``max_num_batched_tokens`` set to
    ``chunked_prefill_token_size``. For every prompt, both the decoded string
    and the generated token ids must be equal.

    A ``chunked_prefill_token_size`` of -1 runs vLLM with chunked prefill
    disabled and ``max_num_batched_tokens`` left as ``None``.
    """
    # At TP=2 only the 16-token chunk size and the eager-mode runs are kept.
    kept_at_high_tp = chunked_prefill_token_size == 16 or enforce_eager
    if tensor_parallel_size == 2 and not kept_at_high_tp:
        pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
                    "for high TP to save testing time.")

    # NOTE(review): presumably capped so that max_num_seqs never exceeds
    # max_num_batched_tokens -- confirm against the scheduler config checks.
    max_num_seqs = min(chunked_prefill_token_size, 256)

    # -1 is the sentinel meaning "do not chunk prefills".
    enable_chunked_prefill = chunked_prefill_token_size != -1
    max_num_batched_tokens = (chunked_prefill_token_size
                              if enable_chunked_prefill else None)

    # Reference run; the HF model is released before vLLM is instantiated.
    hf_reference = hf_runner(model, dtype=dtype)
    hf_outputs = hf_reference.generate_greedy(example_prompts, max_tokens)
    del hf_reference

    vllm_engine = vllm_runner(
        model,
        dtype=dtype,
        max_num_batched_tokens=max_num_batched_tokens,
        enable_chunked_prefill=enable_chunked_prefill,
        tensor_parallel_size=tensor_parallel_size,
        enforce_eager=enforce_eager,
        max_num_seqs=max_num_seqs,
    )
    vllm_outputs = vllm_engine.generate_greedy(example_prompts, max_tokens)
    del vllm_engine
    print(vllm_outputs[0])

    # Compare prompt by prompt: strings first, then the raw token ids.
    for i, _prompt in enumerate(example_prompts):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
16 changes: 8 additions & 8 deletions tests/core/test_chunked_prefill_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ def test_chunk():
# One chunked prefill, and one decoding.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
# The first one is decoding.
assert seq_group_meta[0].token_chunk_size == 1
# The first one is prefill. Scheduler guarantees ordering.
assert seq_group_meta[0].token_chunk_size == 56
# The second one is a chunked prefill.
assert seq_group_meta[1].token_chunk_size == 56
assert seq_group_meta[1].token_chunk_size == 1
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 57

Expand Down Expand Up @@ -157,12 +157,12 @@ def test_complex():
# Decoding & chunked prefill & first chunk of 3rd request is scheduled.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert len(get_sequence_groups(out)) == 3
# The first one is decoding.
assert seq_group_meta[0].token_chunk_size == 1
# The second one is a chunked prefill.
# The first one is the first chunked prefill.
assert seq_group_meta[0].token_chunk_size == 7
# The second one is the second new chunked prefill.
assert seq_group_meta[1].token_chunk_size == 56
# The third one is also chunked.
assert seq_group_meta[2].token_chunk_size == 7
# The last one is decode.
assert seq_group_meta[2].token_chunk_size == 1
# Two of them are in chunked prefill.
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 64
Expand Down
7 changes: 6 additions & 1 deletion tests/distributed/test_basic_distributed_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,16 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:

hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model

vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2)
vllm_model = vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model

Expand Down
66 changes: 66 additions & 0 deletions tests/distributed/test_chunked_prefill_distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (model name) by environment
variables.
Run:
```sh
TEST_DIST_MODEL=facebook/opt-125m pytest \
test_chunked_prefill_distributed.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest \
test_chunked_prefill_distributed.py
```
"""
import os

import pytest
import torch

# The single model under test is read from the TEST_DIST_MODEL environment
# variable. Indexing os.environ directly raises KeyError when this module is
# imported without the variable set.
MODELS = [
os.environ["TEST_DIST_MODEL"],
]


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    chunked_prefill_token_size: int,
) -> None:
    """Check that 2-way tensor-parallel vLLM with chunked prefill matches HF.

    The HF reference outputs are generated first and the HF model is dropped
    before the vLLM model is created. vLLM is then run with
    ``tensor_parallel_size=2``, chunked prefill enabled, and
    ``max_num_batched_tokens`` set to ``chunked_prefill_token_size``. For
    every prompt, both the decoded string and the generated token ids must
    be equal.
    """
    # Chunked prefill configuration.
    # NOTE(review): presumably capped so that max_num_seqs never exceeds
    # max_num_batched_tokens -- confirm against the scheduler config checks.
    max_num_seqs = min(chunked_prefill_token_size, 256)
    # This test always chunks, so the -1 value is rejected here.
    assert chunked_prefill_token_size != -1
    max_num_batched_tokens = chunked_prefill_token_size
    enable_chunked_prefill = True

    # Reference run; the HF model is released before vLLM is instantiated.
    hf_reference = hf_runner(model, dtype=dtype)
    hf_outputs = hf_reference.generate_greedy(example_prompts, max_tokens)
    del hf_reference

    vllm_engine = vllm_runner(
        model,
        dtype=dtype,
        tensor_parallel_size=2,
        max_num_seqs=max_num_seqs,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
    )
    vllm_outputs = vllm_engine.generate_greedy(example_prompts, max_tokens)
    del vllm_engine

    # Compare prompt by prompt: strings first, then the raw token ids.
    for i, _prompt in enumerate(example_prompts):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
2 changes: 1 addition & 1 deletion tests/entrypoints/test_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def server(zephyr_lora_files):
"--max-cpu-loras",
"2",
"--max-num-seqs",
"128"
"128",
])
ray.get(server_runner.ready.remote())
yield server_runner
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"gpt2",
"bigcode/tiny_starcoder_py",
"EleutherAI/pythia-70m",
"bigscience/bloom-560m",
"bigscience/bloom-560m", # Testing alibi slopes.
"microsoft/phi-2",
"stabilityai/stablelm-3b-4e1t",
# "allenai/OLMo-1B", # Broken
Expand Down
Loading

0 comments on commit ffb2afc

Please sign in to comment.