【Feature】 dp chuncked prefill balance. #937
base: main
Conversation
Summary of Changes
Hello @hiworldwzj, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a new feature to balance chunked prefill operations in Data Parallel (DP) mode, aiming to reduce computational waste while allowing for configurable control over first token latency. My changes centralize the state management for chunked prefill scheduling and refactor core inference logic into reusable components, improving overall efficiency and maintainability of the system.
Highlights
- New Configuration for DP Prefill Pacing: I've introduced a new command-line argument, `--dp_prefill_wait_step`, to provide fine-grained control over the pacing of chunked prefill operations in Data Parallel (DP) mode. This parameter reduces computational waste during prefill, at the cost of potentially higher first token latency.
- Centralized Chunked Prefill State Management: A new `ChunkedPrefillState` class has been added to centralize the scheduling logic for chunked prefill, including mechanisms to decide when to perform prefill in DP setups by considering the "use ratio" of requests across distributed ranks (a rough illustrative sketch follows this list).
- Refactored Inference Logic for Reusability: I've extracted common prefill and decode inference steps into reusable `normal_prefill_reqs` and `normal_decode` methods within `ContinuesBatchBackend`, which significantly reduces code duplication and improves maintainability across the chunked prefill backend implementations.
- Enhanced Distributed Communication for Scheduling: New helper methods (`_dp_all_gather_prefill_req_num` and `_dp_all_reduce_decode_req_num`) have been added to `ModeBackend`. They collect prefill and decode request counts across distributed processes, enabling smarter and more balanced scheduling decisions in DP environments.
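To make the pacing idea concrete, here is a minimal sketch of a wait-step based prefill decision. This is an illustrative reconstruction under assumptions, not the PR's actual `ChunkedPrefillState` code: the class name, the `need_prefill` method, the `wait_count` field, and the 50% "use ratio" threshold are all hypothetical.

```python
# Illustrative sketch only -- not the actual ChunkedPrefillState implementation.
from dataclasses import dataclass
from typing import Sequence


@dataclass
class ChunkedPrefillStateSketch:
    dp_prefill_wait_step: int  # value of --dp_prefill_wait_step
    wait_count: int = 0        # scheduler steps this rank has already waited

    def need_prefill(self, dp_prefill_req_nums: Sequence[int], max_prefill_num: int) -> bool:
        """Decide whether to run prefill now or keep waiting for other DP ranks."""
        if max_prefill_num == 0:
            return False  # no rank has prefill work
        # Fraction of DP ranks that currently have prefill requests ("use ratio").
        use_ratio = sum(1 for n in dp_prefill_req_nums if n > 0) / len(dp_prefill_req_nums)
        if use_ratio >= 0.5 or self.wait_count >= self.dp_prefill_wait_step:
            # Enough ranks are ready, or we have waited long enough: prefill now.
            self.wait_count = 0
            return True
        # Otherwise hold off this step to reduce wasted padding compute on idle ranks.
        self.wait_count += 1
        return False
```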
Code Review
The code changes introduce a new parameter `dp_prefill_wait_step` to control the pacing of DP chunked prefill mode, aiming to reduce computational waste during prefill. The changes also include modifications to several backend implementations to support this new parameter and improve the efficiency of prefill and decode operations in DP mode. The code has several long comments that could be improved, and some missing error handling that should be added.
"--dp_prefill_wait_step", | ||
type=int, | ||
default=0, | ||
help="""dp_prefill_wait_step is used to control the pacing of dp chunked prefill mode, aiming to reduce | ||
computational waste during prefill. However, higher values can negatively impact the | ||
first token latency. It is generally recommended to set this value between 0 and 6.""", | ||
) |
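As a quick, self-contained illustration of how the flag behaves, the snippet below exercises the same argument definition; the bare `argparse.ArgumentParser()` is only a stand-in for the project's real argument builder.

```python
# Standalone illustration of the new flag (the parser setup is a stand-in).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dp_prefill_wait_step",
    type=int,
    default=0,
    help="pacing control for dp chunked prefill; recommended range is 0-6",
)

assert parser.parse_args([]).dp_prefill_wait_step == 0   # default: no extra waiting
assert parser.parse_args(["--dp_prefill_wait_step", "4"]).dp_prefill_wait_step == 4
```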
# During chunked prefill inference, several modes apply delayed-step scheduling controls,
# either to keep inter-token pacing smooth or to improve prefill efficiency in DP mode.
# When estimating token GPU-memory usage, however, chunking must account for the longer
# request lifetime that chunking itself introduces. Concretely, this is the
# xxx * (max_waiting_token + 1) term in the b_len calculation: by simulating a longer
# output length, it extends the request's lifetime during the estimation phase.
# max_waiting_token is computed conservatively: the maximum number of steps a chunked
# prefill can be delayed is taken as the sum of the two modes. Since this does not
# significantly inflate the estimated token usage, it is safe to use.
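A rough numeric sketch of that padding idea follows; the function name, parameters, and exact formula are assumptions made for illustration, with only the `* (max_waiting_token + 1)` pattern taken from the comment above.

```python
# Hypothetical illustration of the conservative lifetime padding described above.
def estimate_b_len(cur_len: int, remaining_output_len: int, max_waiting_token: int) -> int:
    # Budget as if the request emitted (max_waiting_token + 1) times the remaining
    # output, so a chunked/delayed request reserves memory for the extra scheduler
    # steps it may live through.
    return cur_len + remaining_output_len * (max_waiting_token + 1)


# Example: 1024 tokens so far, 128 left to generate, worst-case delay of 6 steps:
# budgeted as 1024 + 128 * 7 = 1920 token slots instead of 1152.
print(estimate_b_len(1024, 128, 6))  # -> 1920
```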
req.mtp_gen_token_ids.append(token_id)
return
def _dp_all_gather_prefill_req_num(self, prefill_reqs: List[InferReq]) -> Tuple[np.ndarray, int]:
    """
    Gather the number of prefill requests across all DP ranks.
    """
    current_dp_prefill_num = len(prefill_reqs)
    # Place this rank's count into a one-element tensor and gather one slot per rank.
    self.dp_gather_item_tensor.fill_(current_dp_prefill_num)
    dist.all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
    dp_prefill_req_nums = self.dp_all_gather_tensor.cpu().numpy()
    max_prefill_num = np.max(dp_prefill_req_nums)
    return dp_prefill_req_nums, max_prefill_num
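The helper relies on two pre-allocated communication buffers, `self.dp_gather_item_tensor` and `self.dp_all_gather_tensor`, which are not shown in this hunk. A minimal sketch of how they could be created is below; the dtype, device placement, and use of the default process group are assumptions rather than the PR's actual setup.

```python
# Hypothetical buffer setup matching the all_gather call above (illustrative only).
import torch
import torch.distributed as dist


def _init_dp_comm_buffers(self) -> None:
    world_size = dist.get_world_size()
    # One integer slot holding this rank's prefill request count ...
    self.dp_gather_item_tensor = torch.zeros((1,), dtype=torch.int64, device="cuda")
    # ... gathered into one slot per rank, so index i holds rank i's count.
    self.dp_all_gather_tensor = torch.zeros((world_size,), dtype=torch.int64, device="cuda")
```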
current_dp_decode_num = len(decode_reqs)
self.dp_reduce_tensor.fill_(current_dp_decode_num)
# MAX-reduce so every DP rank sees the largest decode request count among all ranks.
dist.all_reduce(self.dp_reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False)
max_decode_num = self.dp_reduce_tensor.item()
for req_obj in run_reqs:
    # Only restore top_k for requests that recorded an original value.
    if hasattr(req_obj, "origin_topk"):
        req_obj.sampling_param.shm_param.top_k = req_obj.origin_topk