[2/N] Chunked prefill data update #3538

Merged
Changes from 1 commit

Commits (127)
06fe872
[1/n] Support efficient reshape caching.
rkooo567 Feb 28, 2024
9a0b6be
[2/n] support flash attention kernel
rkooo567 Feb 28, 2024
6947167
oss flash attention works
rkooo567 Feb 28, 2024
4769a26
in progress
rkooo567 Feb 28, 2024
963db44
flash attn enabled.
rkooo567 Feb 29, 2024
2b9c36b
ip
rkooo567 Feb 29, 2024
2c1bb6c
support every model
rkooo567 Feb 29, 2024
2bb5e62
Fixed broken tests.
rkooo567 Feb 29, 2024
4d6a05f
[2/n] scheduler changes
rkooo567 Feb 29, 2024
0831f84
[2/n] ip
rkooo567 Feb 29, 2024
f31371f
[2/n]ip
rkooo567 Feb 29, 2024
78bb887
ip
rkooo567 Feb 29, 2024
b9d93c5
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Feb 29, 2024
42dd362
[2/n] ip
rkooo567 Mar 1, 2024
74ac900
seems to work.
rkooo567 Mar 1, 2024
e3afc25
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 1, 2024
6141885
[2/n] ip
rkooo567 Mar 1, 2024
71bdada
.
rkooo567 Mar 1, 2024
d4c3b5d
ip?
rkooo567 Mar 1, 2024
baef7c6
block tables updated correctly
rkooo567 Mar 1, 2024
d503a22
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 1, 2024
a12ec68
hopefully tests pass
rkooo567 Mar 1, 2024
85760db
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 3, 2024
e40bc45
[2/n] update sequence data
rkooo567 Mar 3, 2024
d85670f
[2/n] add prefill range apis
rkooo567 Mar 3, 2024
0d8785f
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 3, 2024
08c8541
.
rkooo567 Mar 3, 2024
3bac9af
ip
rkooo567 Mar 3, 2024
0ca1284
add data.
rkooo567 Mar 3, 2024
2487bda
ip
rkooo567 Mar 3, 2024
81151e8
ip
rkooo567 Mar 3, 2024
31aa920
ip
rkooo567 Mar 4, 2024
2049b35
.
rkooo567 Mar 4, 2024
ef679d7
.
rkooo567 Mar 4, 2024
71bda97
.
rkooo567 Mar 4, 2024
4e00e7f
done?
rkooo567 Mar 4, 2024
c5f3a0d
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 4, 2024
7fd70f2
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 5, 2024
9bbb04e
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler-data-…
rkooo567 Mar 5, 2024
9177d54
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 6, 2024
5e47c1e
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler-data-…
rkooo567 Mar 6, 2024
c0384a4
Refactor 2d query to 1d query
rkooo567 Mar 6, 2024
6032edf
.,
rkooo567 Mar 6, 2024
c1ab0b0
done
rkooo567 Mar 6, 2024
f48dc72
Addressed code review.
rkooo567 Mar 7, 2024
769b2b4
working
rkooo567 Mar 7, 2024
4a20f4a
Merge branch 'main' into 1dquery
rkooo567 Mar 7, 2024
f7347b8
working
rkooo567 Mar 7, 2024
d931725
Merge branch 'main' into 1dquery
rkooo567 Mar 7, 2024
f91d73e
fix lora
rkooo567 Mar 8, 2024
f7d79da
fixed
rkooo567 Mar 8, 2024
851c018
Merge branch 'main' into 1dquery
rkooo567 Mar 8, 2024
406f1d4
fix
rkooo567 Mar 8, 2024
c66ec36
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 11, 2024
c067a4c
working.
rkooo567 Mar 11, 2024
e1f244a
clean up.
rkooo567 Mar 11, 2024
d09eaf5
.
rkooo567 Mar 11, 2024
4a8ab3c
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 11, 2024
a08e65e
Merge branch 'main' into 1dquery
rkooo567 Mar 11, 2024
d9532f8
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 11, 2024
93a7b90
.
rkooo567 Mar 12, 2024
b4b94c6
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 12, 2024
647d8cc
.
rkooo567 Mar 12, 2024
65ac6ce
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 12, 2024
b2f4b3e
ip
rkooo567 Mar 12, 2024
cc8419f
.
rkooo567 Mar 12, 2024
76e7ca8
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 12, 2024
d3d0336
Merge branch 'main' into 1dquery
rkooo567 Mar 15, 2024
11ec167
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 15, 2024
3cb8093
ip addressing comments.
rkooo567 Mar 16, 2024
5391129
Alibi slopes working now.
rkooo567 Mar 18, 2024
6b04443
Merge branch 'main' into 1dquery
rkooo567 Mar 18, 2024
fe344f6
add new fieflds
rkooo567 Mar 18, 2024
e619c4e
Flash attn works now
rkooo567 Mar 18, 2024
9c86aa3
Linting
rkooo567 Mar 18, 2024
5b4aa09
temporary
rkooo567 Mar 18, 2024
03dd155
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 18, 2024
4cced78
fix tests
rkooo567 Mar 18, 2024
cdb7a2c
Fixed
rkooo567 Mar 18, 2024
276be06
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 18, 2024
d87b651
Pass unit tests.
rkooo567 Mar 18, 2024
2c18896
experiment
rkooo567 Mar 18, 2024
b46f902
.
rkooo567 Mar 18, 2024
07b22f8
.
rkooo567 Mar 18, 2024
9bd7ea1
.
rkooo567 Mar 18, 2024
c55402f
trial
rkooo567 Mar 18, 2024
a13cf7e
remove --fork
rkooo567 Mar 18, 2024
c5c5581
Merge branch 'main' into 1dquery
rkooo567 Mar 18, 2024
ec91304
fixed
rkooo567 Mar 19, 2024
4977e53
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 19, 2024
4a54688
Merge branch 'main' into 1dquery
rkooo567 Mar 19, 2024
2e6e919
Addressed code review.
rkooo567 Mar 19, 2024
1f6f6b0
Merge branch 'main' into 1dquery
rkooo567 Mar 19, 2024
ac7828c
revert removing forked
rkooo567 Mar 19, 2024
3d7f1a1
done
rkooo567 Mar 19, 2024
bcdd74a
Merge branch 'main' into 1dquery
rkooo567 Mar 20, 2024
fa3ce4e
final code review.
rkooo567 Mar 20, 2024
a83b235
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 20, 2024
7205ef9
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 21, 2024
8bc0af5
.
rkooo567 Mar 21, 2024
97bcb6f
ip
rkooo567 Mar 21, 2024
df34350
working except tests.
rkooo567 Mar 21, 2024
e70e03d
.
rkooo567 Mar 21, 2024
f89f428
ip
rkooo567 Mar 21, 2024
bf02f8e
done
rkooo567 Mar 21, 2024
ad43095
done
rkooo567 Mar 21, 2024
16b6196
Addressed code review.
rkooo567 Mar 22, 2024
916abc8
merge conflict fixed
rkooo567 Mar 25, 2024
5002e61
update
rkooo567 Mar 25, 2024
80f51ea
test fix
rkooo567 Mar 25, 2024
3cc5e99
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 25, 2024
fa7ba35
lint
rkooo567 Mar 25, 2024
51cf7f2
fix broken tests.
rkooo567 Mar 25, 2024
cdee1c6
.
rkooo567 Mar 26, 2024
16e3a7d
done
rkooo567 Mar 26, 2024
e0d301c
remove num chunked prefill from seq group metadata
rkooo567 Mar 27, 2024
5e0f87e
change apis
rkooo567 Mar 27, 2024
6e72648
cleaned
rkooo567 Mar 27, 2024
4f869be
now working
rkooo567 Mar 27, 2024
4f63c57
update with new apis
rkooo567 Mar 27, 2024
5c3abf4
working!
rkooo567 Mar 27, 2024
66f3fcf
fixed
rkooo567 Mar 27, 2024
9c12d8e
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 27, 2024
9d4b65c
Addressed code review.
rkooo567 Mar 28, 2024
54a58b2
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 28, 2024
9bdb9dc
fix tests.
rkooo567 Mar 28, 2024
88126a9
fixed a bug
rkooo567 Mar 28, 2024
working except tests.
rkooo567 committed Mar 21, 2024
commit df343509a5568afe5bd5894c1a46bf701bd2f0cc
28 changes: 5 additions & 23 deletions benchmarks/benchmark_latency.py
@@ -10,13 +10,6 @@

from vllm import LLM, SamplingParams

SAMPLE_PROMPTS = [
"The president of the United States is",
"Hello, my name is",
"The capital of France is",
"The future of AI is",
]


def main(args: argparse.Namespace):
print(args)
@@ -35,7 +28,6 @@ def main(args: argparse.Namespace):
device=args.device,
block_size=args.block_size,
max_chunked_prefill_len=args.max_chunked_prefill_len,
max_num_prompt_seqs=args.max_num_prompt_seqs,
ray_workers_use_nsight=args.ray_workers_use_nsight,
)

@@ -68,25 +60,16 @@ def run_to_completion(profile_dir: Optional[str] = None):
print(p.key_averages())
else:
start_time = time.perf_counter()
if args.use_sample:
batch = (SAMPLE_PROMPTS *
(args.batch_size // len(SAMPLE_PROMPTS) +
1))[:args.batch_size]
outputs = llm.generate(prompts=batch,
sampling_params=sampling_params,
use_tqdm=False)
else:
outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)
outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
if args.verbose:
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(
f"Prompt: {prompt!r}, Generated text: {generated_text!r}"
)
print(f"Prompt: {prompt!r}, Generated text: "
f"{generated_text!r}")
latency = end_time - start_time
return latency

@@ -182,7 +165,6 @@ def run_to_completion(profile_dir: Optional[str] = None):
action='store_true',
help='print generated text')
parser.add_argument('--max-chunked-prefill-len', type=int, default=-1)
parser.add_argument('--max-num-prompt-seqs', type=int, default=1000)
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',
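With the --use-sample branch removed above, benchmark_latency.py always drives generation with dummy prompt token IDs. A minimal sketch of the remaining call path, assuming the LLM and SamplingParams imports already at the top of the file; the model name, batch size, and input length below are illustrative placeholders rather than values from this PR:

import time

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # assumed model choice, not part of this PR
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)

# The benchmark now always builds dummy token IDs instead of sample prompts.
batch_size, input_len = 8, 32
dummy_prompt_token_ids = [[0] * input_len for _ in range(batch_size)]

start_time = time.perf_counter()
outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                       sampling_params=sampling_params,
                       use_tqdm=False)
latency = time.perf_counter() - start_time
print(f"Batch latency: {latency:.3f} s")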
1 change: 0 additions & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
@@ -7,7 +7,6 @@

from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
from vllm._C import ops
from vllm.model_executor.layers.attention import flash_attn_with_kvcache_paged

NUM_BLOCKS = 1024
PARTITION_SIZE = 512
99 changes: 0 additions & 99 deletions tests/chunked_prefill/test_correctness.py

This file was deleted.

2 changes: 0 additions & 2 deletions tests/conftest.py
@@ -168,7 +168,6 @@ def __init__(
tensor_parallel_size: int = 1,
block_size: int = 16,
max_chunked_prefill_len: int = -1,
max_num_prompt_seqs: int = 1000,
max_num_batched_tokens: int = 4096,
**kwargs,
) -> None:
@@ -182,7 +181,6 @@ def __init__(
tensor_parallel_size=tensor_parallel_size,
block_size=block_size,
max_chunked_prefill_len=max_chunked_prefill_len,
max_num_prompt_seqs=max_num_prompt_seqs,
max_num_batched_tokens=max_num_batched_tokens,
**kwargs,
)
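The test wrapper above keeps forwarding max_chunked_prefill_len to the engine while dropping max_num_prompt_seqs. A hedged sketch of the resulting constructor call, using only keyword arguments that appear in this diff; the model name and numeric values are assumptions for illustration:

from vllm import LLM

# Illustrative engine construction mirroring the updated test wrapper.
llm = LLM(
    model="facebook/opt-125m",   # assumed model, not specified in this PR
    tensor_parallel_size=1,
    block_size=16,
    max_chunked_prefill_len=-1,  # -1 leaves chunked prefill disabled
    max_num_batched_tokens=4096,
)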
27 changes: 11 additions & 16 deletions tests/core/test_scheduler.py
@@ -134,34 +134,28 @@ def test_scheduler_schedule_chunked_prefill():
num_seq_group = 2
max_model_len = 16
max_chunked_prefill_len = 2
max_num_prompt_seqs = 1
scheduler_config = SchedulerConfig(
64,
num_seq_group,
max_model_len,
flash_style=True,
max_chunked_prefill_len=max_chunked_prefill_len,
max_num_prompt_seqs=max_num_prompt_seqs)
cache_config = CacheConfig(block_size, 1.0, 1)
max_chunked_prefill_len=max_chunked_prefill_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
scheduler = Scheduler(scheduler_config, cache_config, None)

# Add seq groups to scheduler.
seq_groups: List[SequenceGroup] = []
for i in range(num_seq_group):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=block_size,
num_processed_token_ids=block_size -
1)
_, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
scheduler.add_seq_group(seq_group)
seq_groups.append(seq_group)

# Schedule chunk prefill. Only the first seq_group should be scheduled.
seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(seq_groups[:1])
seq_groups[0].get_num_unprefilled() == 2
seq_groups[1].get_num_unprefilled() == 4
assert seq_groups[0].get_num_unprefilled() == 2
assert seq_groups[1].get_num_unprefilled() == 4
assert out.num_batched_tokens == 2
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
@@ -170,11 +164,12 @@ def test_scheduler_schedule_chunked_prefill():
assert seq_group_meta[0].is_chunked_prefill
assert seq_group_meta[0].is_prompt

# Schedule chunk prefill. Still Only the first seq_group should be scheduled.
# Schedule chunk prefill. Still Only the first seq_group should be
# scheduled.
seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(seq_groups[:1])
seq_groups[0].get_num_unprefilled() == 0
seq_groups[1].get_num_unprefilled() == 4
assert seq_groups[0].get_num_unprefilled() == 0
assert seq_groups[1].get_num_unprefilled() == 4
assert out.num_batched_tokens == 2
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
@@ -187,8 +182,8 @@ def test_scheduler_schedule_chunked_prefill():
# for chunk prefill, and the first seq_group should be select for decoding.
seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(seq_groups)
seq_groups[0].get_num_unprefilled() == 0
seq_groups[1].get_num_unprefilled() == 2
assert seq_groups[0].get_num_unprefilled() == 0
assert seq_groups[1].get_num_unprefilled() == 2
assert out.num_batched_tokens == 3
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
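The assertions added above track how many prompt tokens remain unprefilled after each scheduler step: with the test's 4-token prompts (prompt_length == block_size) and max_chunked_prefill_len = 2, every schedule() call prefills at most 2 tokens of the selected sequence. A toy sketch of that bookkeeping; the helper below is illustrative and not the scheduler's actual implementation:

def remaining_unprefilled(prompt_len: int, chunk_size: int, steps: int) -> int:
    # Tokens still waiting to be prefilled after `steps` chunked-prefill passes.
    return max(prompt_len - chunk_size * steps, 0)

# Mirrors the test: 4-token prompt, chunk size 2.
assert remaining_unprefilled(4, 2, 1) == 2  # after the first schedule()
assert remaining_unprefilled(4, 2, 2) == 0  # after the second schedule()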
43 changes: 0 additions & 43 deletions tests/samplers/test_sampler.py
@@ -285,49 +285,6 @@ def test_sampling(model_runner: ModelRunner):
del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_logits_processors(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
logits[len(token_ids)] = float("inf")
return logits

seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
is_chunked_prefill=False,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
for _, sequence_output in enumerate(sampler_output):
for idx, nth_output in enumerate(sequence_output.samples):
assert nth_output.output_token == idx

del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
4 changes: 0 additions & 4 deletions vllm/config.py
@@ -539,8 +539,6 @@ class SchedulerConfig:
requests. Longer requests will be chunked into multiple chunks.
-1 means no chunking (disabled). This feature is only supported
for flash style attention.
max_num_prompt_seqs: The maximum number of prompt sequences that can be
processed in a single iteration.
"""

def __init__(
@@ -549,7 +547,6 @@ def __init__(
max_num_seqs: int,
max_model_len: int,
max_chunked_prefill_len: int = -1,
max_num_prompt_seqs: int = 1024,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
@@ -561,7 +558,6 @@ def __init__(
self.max_model_len = max_model_len
self.chunked_prefill_enabled = max_chunked_prefill_len != -1
self.max_chunked_prefill_len = max_chunked_prefill_len
self.max_num_prompt_seqs = max_num_prompt_seqs
self._verify_args()

def _verify_args(self) -> None:
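After this change SchedulerConfig no longer accepts max_num_prompt_seqs, and chunked_prefill_enabled is derived from max_chunked_prefill_len. A hedged construction sketch using the positional argument order seen in test_scheduler.py; the numeric values are illustrative assumptions:

from vllm.config import SchedulerConfig

# Illustrative values; passing max_chunked_prefill_len=-1 would disable chunking.
scheduler_config = SchedulerConfig(
    4096,   # max_num_batched_tokens
    256,    # max_num_seqs
    2048,   # max_model_len
    max_chunked_prefill_len=512,
)
assert scheduler_config.chunked_prefill_enabled
assert scheduler_config.max_chunked_prefill_len == 512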
1 change: 0 additions & 1 deletion vllm/core/scheduler.py
@@ -371,7 +371,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
block_tables[seq_id] = self.block_manager.get_block_table(seq)
self.block_manager.access_all_blocks_in_seq(seq, now)

# SANG-TODO Update chunked prefill related info.
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=scheduler_outputs.prompt_run,
4 changes: 1 addition & 3 deletions vllm/engine/arg_utils.py
@@ -52,7 +52,6 @@ class EngineArgs:
device: str = 'auto'
ray_workers_use_nsight: bool = False
max_chunked_prefill_len: int = -1
max_num_prompt_seqs: int = 256

def __post_init__(self):
if self.tokenizer is None:
@@ -357,8 +356,7 @@ def create_engine_configs(
self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
max_chunked_prefill_len=self.max_chunked_prefill_len,
max_num_prompt_seqs=self.max_num_prompt_seqs)
max_chunked_prefill_len=self.max_chunked_prefill_len)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
2 changes: 0 additions & 2 deletions vllm/engine/llm_engine.py
@@ -615,8 +615,6 @@ def step(self) -> List[RequestOutput]:
>>> break
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
# print("SANG-TODO step seq_group_metadata_list length: ",
# len(seq_group_metadata_list))
if not scheduler_outputs.is_empty():
output = self.model_executor.execute_model(
seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
1 change: 0 additions & 1 deletion vllm/entrypoints/llm.py
@@ -146,7 +146,6 @@ def generate(
A list of `RequestOutput` objects containing the generated
completions in the same order as the input prompts.
"""
# print("SANG-TODO generate: ", prompts, prompt_token_ids)
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
1 change: 0 additions & 1 deletion vllm/utils.py
@@ -322,7 +322,6 @@ def create_kv_caches_with_random(
key_caches.append(key_cache)

value_cache_shape = (num_blocks, num_heads, head_size, block_size)

value_caches = []
for _ in range(num_layers):
value_cache = torch.empty(size=value_cache_shape,