Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core][5/N] Fully working chunked prefill e2e #3884

Merged
merged 16 commits into from
Apr 11, 2024

Conversation

rkooo567
Copy link
Collaborator

@rkooo567 rkooo567 commented Apr 6, 2024

This PR is a part of the RFC #3130.

This PR enables chunked prefill e2e. Note that chunked prefill is an experimental feature now (though it is actively used within Anyscale), and I will start serious benchmark with this PR.

The feature can be enabled by using enable_chunked_prefill. The chunking is done based on max_num_batched_tokens (a.k.a token budget). This is the same way as described in SARATHI paper https://arxiv.org/abs/2308.16369. It also means we can have up to maximum 2 chunked prefill at any given time.

This PR

  • Allow to put 2 attention metadata, one for prefill and one for decode.
    • it is done by broadcasting metadata twice. It can be more optimized by coelescing tensors better, but I made this way for simplicity. The normal path should just use 1 broadcast as usual.
  • Allow to run attention backend when there are prefill and decode mixed up.
  • Ignore the generated token if chunked prefill is enabled.
  • Add a chunked prefill test with various chunk size, cuda graph, and tp settings.

PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@rkooo567 rkooo567 changed the title [WIP] Chunked prefill e2e oss [WIP] Chunked prefill e2e Apr 6, 2024
@rkooo567 rkooo567 changed the title [WIP] Chunked prefill e2e [Core][5/N] Fully working chunked prefill e2e Apr 6, 2024
# higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
if enable_chunked_prefill:
# For chunked prefill, choose the well-tuned batch size.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is arbitrary value. We should profile again.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would imagine this is hardware/model dependent. @AgrawalAmey any suggestion for a good default value?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, there is a no good default -- all depends on application requirements, model, hw, etc. I have an experimental auto-tuner which sets the maximum chunk size that can satisfy the application requirements. It is pretty straight forward, it just takes tbt target and find out the max token budget that can satisfy it.

vllm/core/scheduler.py Outdated Show resolved Hide resolved
@@ -49,17 +50,14 @@ def copy_blocks(


@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: Currently, do we run all attention backend in tests?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR welcomed!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will create a PR for it

assert len(return_prompt_lens) == 0


# SANG-TODO Test chunked prefill case.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is WIP

@rkooo567
Copy link
Collaborator Author

rkooo567 commented Apr 6, 2024

Initial throughput benchmark on 7B, 1 GPU

# not enabled
Throughput: 1.98 requests/s, 974.93 tokens/s
# enabled
Throughput: 2.02 requests/s, 995.76 tokens/s

I am going to start further benchmark at better machines and higher TP next week

@scv119
Copy link
Contributor

scv119 commented Apr 6, 2024

for the benchmarks, does it enable cuda-graph? also worth noting the detail of benchmarks (i.e. type of GPU and benchmark setup)

@ywang96
Copy link
Member

ywang96 commented Apr 7, 2024

@rkooo567 If this PR is ready, I'm happy to test it with some benchmarking as well!

@rkooo567
Copy link
Collaborator Author

rkooo567 commented Apr 7, 2024

@ywang96 I think it may actually work. The test failures now may be just test specific issues. The basic_correctness_test seems pass when I ran it (including cuda graph + tp > 1).

@scv119 actually the original RFC doesn't include cuda graph yet after talking with Woosuk (he mentioned he didn't see much improvement in cuda graph if batch size > 256). I will run our internal with decode-only cuda graph (oss status quo) vs whole cuda graph to see the actual impact and do a follow up. Have you observed big perf difference here?

@rkooo567
Copy link
Collaborator Author

rkooo567 commented Apr 7, 2024

also want to emphaisze the benchmark above runs 1 A10 GPU + llama 7B. It was just a dry run. I will do more serious benchmark next week, and we may need cuda graph for chunked prefill to (now it is enabled only when batch only contains decode)

@scv119
Copy link
Contributor

scv119 commented Apr 7, 2024

actually the original RFC doesn't include cuda graph yet after talking with Woosuk (he mentioned he didn't see much improvement in cuda graph if batch size > 256). I will run our internal with decode-only cuda graph (oss status quo) vs whole cuda graph to see the actual impact and do a follow up. Have you observed big perf difference here

my sense is cuda-graph will help with tp > 1 cases regardless of chunk-prefill enabled or not; but better just run some benchmarks.

@@ -49,17 +50,14 @@ def copy_blocks(


@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR welcomed!

Comment on lines 229 to 231
if num_prefill_tokens > 0:
prefill_meta = attn_metadata.prefill_metadata
assert prefill_meta is not None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we the invariant check in the AttentionMetadata data class?

Suggested change
if num_prefill_tokens > 0:
prefill_meta = attn_metadata.prefill_metadata
assert prefill_meta is not None
if prefill_meta: = attn_metadata.prefill_metadata:
assert prefill_meta is not None

Python 3.8's valrus operator is very useful here. But it does need num_prefill_tokens invariant to be verified somewhere else?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved invariant check and used warlus operator

Copy link
Member

@zhuohan123 zhuohan123 Apr 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a fan of Walrus operator in python and I believe it's pretty controversial. I would vote for limiting it's use in vLLM. @simon-mo what do you think?

But this is a small thing. I don't think we should block any PR with this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WoosukKwon @zhuohan123 would love your feedback on this interface change where the new AttentionMetadata can contains both prefill and decode stage metadata. I think this is a necessary change but would like to hear your thought on the interface design.

Comment on lines 80 to 81
slot_mapping: torch.Tensor
kv_cache_dtype: str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

group these two separately from the prefill/decode. i guess it make sense here as a general arguments instead of the per stage ones.

# higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
if enable_chunked_prefill:
# For chunked prefill, choose the well-tuned batch size.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would imagine this is hardware/model dependent. @AgrawalAmey any suggestion for a good default value?

vllm/config.py Outdated Show resolved Hide resolved
vllm/core/scheduler.py Outdated Show resolved Hide resolved
vllm/worker/model_runner.py Outdated Show resolved Hide resolved
vllm/worker/model_runner.py Outdated Show resolved Hide resolved
@zhuohan123 zhuohan123 self-assigned this Apr 9, 2024
attn_metadata.prompt_lens_tensor,
attn_metadata.context_lens,
attn_metadata.max_subquery_len,
prefill_meta.block_tables,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rkooo567 you mentioned this kernel is really slow right? should we show some warning if someone uses this backend with chunked prefills?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm feel like that's a bit weird because there's no alternative option. I think we should probably just mark it as experimental feature for now?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can also do some additional e2e benchmark. When I ran it with 7B, the perf was actually not very different

else:
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
output = PagedAttention.forward_prefix(
output[:num_prefill_tokens] = PagedAttention.forward_prefix(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with flash_attn_varlen_func, we can do both prefill and decode attention together right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! I heard that it doesn't make much perf change, but we can iterate on this once the API supports larger page size

# for higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
if enable_chunked_prefill:
logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rkooo567
Copy link
Collaborator Author

All comments are addressed, and I believe the tests are passing now

Copy link
Member

@zhuohan123 zhuohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rkooo567 Let's chat offline about this PR? My high-level comment: with chunked prefill, we should no longer distinguish prefill and decode outside the attention layer, which should simplify the code significantly.

vllm/attention/backends/abstract.py Show resolved Hide resolved
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens

if prefill_meta := attn_metadata.prefill_metadata:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small Python style question: is := widely used now? I personally prefer the more old-school way.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was based on simon's feedback. #3884 (comment). Lmk if you want to just revert this

Comment on lines 174 to 185
async_handles = []
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = tensor_dict[key]
torch.distributed.broadcast(tensor, src=src, group=group)
async_handles.append(
torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True))
for async_handle in async_handles:
async_handle.wait()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for optimizing this!

Comment on lines 44 to 64
class PreparePromptMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadataPerStage]
prompt_lens: List[int]
subquery_lens: List[int]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
multi_modal_input: Optional[torch.Tensor]
slot_mapping: List[int]


class PrepareDecodeMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadata]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
slot_mapping: List[int]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is chunked prefill handled here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This addresses #3884 (comment) comment. It is combined inside prepare_input_tensors

vllm/worker/model_runner.py Outdated Show resolved Hide resolved
vllm/worker/model_runner.py Show resolved Hide resolved
vllm/worker/model_runner.py Show resolved Hide resolved
@rkooo567
Copy link
Collaborator Author

rkooo567 commented Apr 10, 2024

@zhuohan123 I believe the current style is the best iterative step. (it is also the way it is done in our internal repo). I'd love to do unification as a followup, but I feel like we should do it step by step instead of all in the current PR;

  • Some of internal metadata is not consistent between decode/prefill. For example, context_len. We should fix these tech debt.
  • This requires to refactor existing code to be prefill/decode agonistic because chunked prefill co-exist with existing code at least for some time.

Let's talk more in details tmrw!

@rkooo567
Copy link
Collaborator Author

Discussed offline.

  • We agreed on what's "clean" state. It requires some thoughts, and I will do it as a follow up
  • The sequence is logprob fix -> cuda graph enablement for prefill -> refactor this part to be clean

@simon-mo simon-mo merged commit 67b4221 into vllm-project:main Apr 11, 2024
35 checks passed
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
)
assert out.shape == output[num_prefill_tokens:].shape
output[num_prefill_tokens:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be output[num_prefill_tokens:] = out here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops that;s correct. I don't know how the test passes...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh it is handled 8afca50

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We didn't enable much test for cpu path, we will try to add step by step later. thanks for your confirming!:)

SageMoore pushed a commit to neuralmagic/nm-vllm that referenced this pull request Apr 11, 2024
@casper-hansen
Copy link
Contributor

It also means we can have up to maximum 2 chunked prefill at any given time.

Can we optimize this to allow many more chunks at a time? I thought the main optimization of DeepSpeed-MII was the chunked prefilling, which greatly speeds up the context processing and therefore the throughput.

@rkooo567
Copy link
Collaborator Author

rkooo567 commented Apr 12, 2024

@casper-hansen that's totally doable. Do you happen to know the perf benefits it can bring or a reference sections in a paper that explaisn it? I followed the policy from https://arxiv.org/abs/2308.16369. Also cc @AgrawalAmey for thoughts

andy-neuma pushed a commit to neuralmagic/nm-vllm that referenced this pull request Apr 12, 2024
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request Apr 22, 2024
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants