[Performance] Introducing Prefix-Cached Chunked Prefill with the flash-attn backend, with a 10% throughput gain for prompts <1K #6819

Closed · wants to merge 1 commit
5 changes: 5 additions & 0 deletions vllm/core/block_manager_v1.py
@@ -669,6 +669,11 @@ def compute_full_blocks_in_seq(self, seq: Sequence):
if max_full_block == -1:
return
for i in reversed(range(max_full_block)):
# [help wanted]
# Indexing up to max_full_block should stay within block_table, but combining prefix caching
# with chunked prefill can make max_full_block exceed the table length (root cause unknown),
# so the following bounds check is needed; the performance cost is negligible.
if i >= len(block_table):
continue
if block_table[i].computed:
break
block_table[i].computed = True
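
For context, here is a minimal, standalone sketch of the guarded marking loop above. The `Block` dataclass and the name `mark_full_blocks` are illustrative stand-ins rather than vLLM code; only the loop structure and the bounds check mirror the diff.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Block:
    computed: bool = False


def mark_full_blocks(block_table: List[Block], max_full_block: int) -> None:
    """Mark fully-filled blocks as computed, walking backwards and stopping
    at the first block already marked (a sketch of the diff's logic)."""
    if max_full_block == -1:
        return
    for i in reversed(range(max_full_block)):
        # Guard from the diff: with prefix caching + chunked prefill the block
        # table can be shorter than max_full_block, so skip out-of-range indices.
        if i >= len(block_table):
            continue
        if block_table[i].computed:
            break
        block_table[i].computed = True
```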
35 changes: 27 additions & 8 deletions vllm/worker/model_runner.py
@@ -246,6 +246,9 @@ def __init__(

self.multi_modal_inputs = multi_modal_inputs
self.prefix_cache_hit = prefix_cache_hit

# Work-around: remember the prefix-cached length for this sequence across prefill chunks.
self.cached_len = 0

self.__post_init__()

@@ -365,20 +368,36 @@ def _compute_for_prefix_cache_hit(
and self.sliding_window is None
and inter_data.is_prompt)
inter_data.prefix_cache_hit = prefix_cache_hit
if self.chunked_prefill_enabled and prefix_cache_hit:
raise RuntimeError(
"chunked prefill cannot be used with prefix caching now.")
# if self.chunked_prefill_enabled and prefix_cache_hit:
# raise RuntimeError(
# "chunked prefill cannot be used with prefix caching now.")

Collaborator: remove it?

# If prefix cache is hit, advance context length to bypass
# hit blocks. Accordingly, input tokens, position and query length
# have to be updated.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][context_len:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][context_len:]

# [help wanted]
# inter_data.cached_len is a work-around for tracking the cached length when chunked prefill
# is enabled together with prefix caching: it pins the cached length for a sequence so that
# every chunk of that sequence sees the same value.
context_len = inter_data.context_lens[seq_idx]
if context_len == 0:
inter_data.cached_len = len(computed_block_nums) * self.block_size
context_len = min(inter_data.cached_len, seq_group_metadata.token_chunk_size - 1)
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][context_len:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][context_len:]
else:
if inter_data.cached_len > context_len:
delta_len = min(inter_data.cached_len - context_len, seq_group_metadata.token_chunk_size - 1)
context_len += delta_len
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][delta_len:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][delta_len:]
Contributor: Hi @Juelianqvq, I have read the diff and, if I understand it correctly, the key difference between this PR and #6144 is leaving at least one token to prefill for each sequence. I could add such logic into #6144.

Contributor Author: @sighingnow Actually it is not only that: you just mentioned the correctness of keeping at least 1 token, which you had missed before. Moreover, the modification in block_manager.py matters too. Give it a try and see whether you get an inference speed-up by modifying only the keep-at-least-one-token logic. I already know the answer, because I have tested many cases while developing this and pointed out the existing problem in my PR, which certainly runs faster using this work-around.
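
To make the at-least-one-token point from this thread concrete, here is a small worked example with assumed numbers (the block size and chunk size are illustrative, not taken from the PR):

```python
# Illustrative numbers only (not from the PR).
block_size = 16
computed_blocks = 3      # prefix-cache hit covers 3 full blocks
token_chunk_size = 32    # tokens scheduled for this prefill chunk

cached_len = computed_blocks * block_size             # 48 tokens already cached
# Capping the skipped context at token_chunk_size - 1 guarantees the chunk
# still prefills at least one token instead of being skipped entirely.
context_len = min(cached_len, token_chunk_size - 1)   # min(48, 31) == 31
assert context_len < token_chunk_size                 # >= 1 token left to prefill
```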


inter_data.context_lens[seq_idx] = context_len
inter_data.query_lens[
seq_idx] = inter_data.seq_lens[seq_idx] - context_len
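
For readers who want the whole flow in one place, below is a self-contained sketch of the per-chunk adjustment the diff above implements. The `InterData` class and the function name `apply_prefix_cache_to_chunk` are illustrative stand-ins, not vLLM APIs; only the field names and the `min(cached_len, token_chunk_size - 1)` cap mirror the diff.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class InterData:
    """Simplified stand-in for the per-sequence intermediate data (illustrative only)."""
    input_tokens: List[List[int]]
    input_positions: List[List[int]]
    context_lens: List[int]
    seq_lens: List[int]
    query_lens: List[int]
    cached_len: int = 0  # prefix-cached length, fixed on the first chunk of a sequence


def apply_prefix_cache_to_chunk(inter_data: InterData, seq_idx: int,
                                num_computed_blocks: int, block_size: int,
                                token_chunk_size: int) -> None:
    """Advance context_len past prefix-cache hits one prefill chunk at a time,
    always leaving at least one token to prefill (a sketch of the diff's logic)."""
    context_len = inter_data.context_lens[seq_idx]
    if context_len == 0:
        # First chunk of this sequence: record how much of the prompt is cached.
        inter_data.cached_len = num_computed_blocks * block_size
        skip = min(inter_data.cached_len, token_chunk_size - 1)
        context_len = skip
    else:
        # Later chunks: keep skipping cached tokens until cached_len is consumed.
        skip = min(max(inter_data.cached_len - context_len, 0), token_chunk_size - 1)
        context_len += skip
    if skip:
        inter_data.input_tokens[seq_idx] = inter_data.input_tokens[seq_idx][skip:]
        inter_data.input_positions[seq_idx] = inter_data.input_positions[seq_idx][skip:]
    inter_data.context_lens[seq_idx] = context_len
    inter_data.query_lens[seq_idx] = inter_data.seq_lens[seq_idx] - context_len
```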