
[Frontend] Add prefix sorting as a precursor to BatchLLM optimization #13762


Open · fangtaosong wants to merge 5 commits into main

Conversation


@fangtaosong fangtaosong commented Feb 24, 2025

This PR implements request sorting to maximize prefix reuse when both chunked prefill and prefix caching are enabled. This serves as the first step towards the full BatchLLM optimization proposed in our RFC.

Motivation:

Currently, vLLM performs implicit (just-in-time) shared-prefix identification and metadata collection, and applies cascade attention when a single prefix is shared by all requests, as described in PR #11635. However, as WoosukKwon suggested, this approach does not fully exploit shared prefixes in offline scenarios with many requests spread across different shared prefixes.

In offline settings, all requests are available before inference begins, making implicit prefix identification suboptimal. By explicitly sorting requests based on their shared prefixes, we can better maximize prefix reuse, improve KV-cache management, and significantly enhance throughput for batched requests.

Changes:

  • Add a --enable-prefix-sorting flag to control prefix sorting
  • Implement prefix sorting using Python's built-in sort() in the LLM API layer (a minimal sketch follows this list)
  • Enable sorting only when both --enable-chunked-prefill and --enable-prefix-caching are set
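
For illustration, here is a minimal sketch of the sorting idea (the helper name and exact placement are hypothetical, not the code added by this PR): sorting prompts lexicographically by their token IDs places requests that share a common prefix next to each other, so each group hits a warm prefix cache.

# Illustrative sketch only; names are hypothetical.
def sort_requests_by_prefix(prompt_token_ids):
    # Return indices that order requests by their token-ID sequence, so
    # requests sharing a prefix become adjacent in the batch.
    return sorted(range(len(prompt_token_ids)),
                  key=lambda i: prompt_token_ids[i])

prompts = [[7, 8, 9, 1], [3, 4, 5, 6], [7, 8, 9, 2]]
order = sort_requests_by_prefix(prompts)      # [1, 0, 2]
sorted_prompts = [prompts[i] for i in order]  # shared prefix [7, 8, 9] is now adjacent

Outputs would still need to be mapped back to the original request order (e.g. via the returned indices) so that sorting stays transparent to callers.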

Performance improvement:

Test setup:

  • Model: Llama-3.1-8B
  • 3000 requests with context length 3000 and prompt length 50
  • VLLM_USE_V1=1 and FLASHINFER backend

Test Script:

# SPDX-License-Identifier: Apache-2.0

import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import time
import datetime
import argparse
import random


def parse_args():
    parser = argparse.ArgumentParser(description='vLLM performance test')
    parser.add_argument('--model-path', type=str, default='/path/to/Llama-3.1-8B-Instruct')
    parser.add_argument('--request-num', type=int, default=800)
    parser.add_argument('--context-len', type=int, default=2000)
    parser.add_argument('--prompt-len', type=int, default=200)
    parser.add_argument('--generate-len', type=int, default=100)
    parser.add_argument('--enable-chunked-prefill', action='store_true')
    parser.add_argument('--share-degree', type=int, default=8)
    parser.add_argument('--shuffle', action='store_true')
    parser.add_argument('--use-original-str', action='store_true')
    parser.add_argument('--enable-prefix-sorting', action='store_true',
                        help='Enable prefix sorting before generation')
    return parser.parse_args()


def prepare_token_ids(context_len, prompt_len, group_num, batch_size):
    share, all_t, group_idx = [], [], []
    for i in range(group_num):
        context_this_group = torch.randint(1, 20000, (context_len,))
        share.append(context_this_group.tolist())
        for _ in range(batch_size):
            prompt_this_request = torch.randint(1, 20000, (prompt_len,))
            all_t.append(torch.concat((context_this_group[0:context_len],
                                       prompt_this_request[0:prompt_len]), 0).tolist())
            group_idx.append(i)

    return all_t, group_num, group_num * batch_size, share


def prepare_tokens(context_len, prompt_len, group_num, batch_size):
    share, all_t, group_idx = [], [], []
    for i in range(group_num):
        # Random lowercase string as the shared context for this group.
        context_this_group = "".join(
            random.choice('abcdefghijklmnopqrstuvwxyz') for _ in range(context_len))
        for _ in range(batch_size):
            prompt_this_request = "".join([random.choice('abcdefghijklmnopqrstuvwxyz') for _ in range(prompt_len)])
            all_t.append(context_this_group + prompt_this_request)
            group_idx.append(i)
            share.append(context_this_group)


    return all_t, group_num, group_num * batch_size, share

def run_pipeline(input_tokens, input_prompts, pipe, max_tokens=100):
    sampling_params = SamplingParams(temperature=0.01, top_p=0.1,
                                     max_tokens=max_tokens)
    t1 = time.time()
    # Exactly one of token IDs or raw prompt strings should be provided.
    assert input_tokens is None or input_prompts is None
    output = pipe.generate(prompts=input_prompts, prompt_token_ids=input_tokens,
                           sampling_params=sampling_params)
    return output, time.time() - t1


def main():
    args = parse_args()
    model_name = args.model_path.split('/')[-2] if args.model_path.endswith('/') else args.model_path.split('/')[-1]
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if args.enable_chunked_prefill:
        max_num_batched_tokens = 4096
        pipe = LLM(model=args.model_path, enable_prefix_caching=True,
                   enable_chunked_prefill=args.enable_chunked_prefill,
                   max_num_batched_tokens=max_num_batched_tokens,
                   enable_prefix_sorting=args.enable_prefix_sorting)
    else:
        pipe = LLM(model=args.model_path, enable_prefix_caching=True)

    print(f'Model: {model_name}')
    print(f'Context length: {args.context_len + args.prompt_len}, Output length: {args.generate_len}')
    print(f'Time: {datetime.datetime.now()} vLLM performance test')
    print(f'Engine path: {args.model_path}')
    print(f'Chunked prefill: {args.enable_chunked_prefill}')


    group_num = args.request_num // args.share_degree
    if not args.use_original_str:
        input_tokens, group_num, prompt_num, share_tokens = prepare_token_ids(
            args.context_len, args.prompt_len, group_num,
            args.share_degree
        )
        input_prompt = None
    else:
        input_prompt, group_num, prompt_num, share_tokens = prepare_tokens(
            args.context_len, args.prompt_len, group_num,
            args.share_degree
        )
        input_tokens = None
    if args.shuffle:
        print('Shuffling input tokens')
        random.shuffle(input_tokens if input_tokens is not None else input_prompt)


    final_output, gen_time = run_pipeline(input_tokens, input_prompt, pipe, args.generate_len)
    print('\ngroup_num prompt_num   time   throughput')
    print(
        f'{group_num:9} {prompt_num:10} {gen_time:8.2f} {prompt_num / gen_time:10.2f} ')


if __name__ == "__main__":
    main()

Test Commands:

# Baseline with shuffle (random order)
VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER python test_script.py \
    --model-path /path/to/Llama-3.1-8B-Instruct \
    --enable-chunked-prefill \
    --request-num 3000 \
    --context-len 3000 \
    --prompt-len 50 \
    --shuffle

# With prefix sorting enabled
VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER python test_script.py \
    --model-path /path/to/Llama-3.1-8B-Instruct \
    --enable-chunked-prefill \
    --request-num 3000 \
    --context-len 3000 \
    --prompt-len 50 \
    --shuffle \
    --enable-prefix-sorting

Results:

Configuration         Throughput (requests/s)
With shuffle          3.93
With prefix sorting   9.21

The results show that:

  1. Random order (shuffle) reduces throughput by ~57% relative to the sorted run (see the quick check after this list).
  2. Prefix sorting restores throughput to the baseline (unshuffled) level.
  3. Datasets with more prefix-sharing potential could see even larger gains from sorting.
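
As a quick back-of-the-envelope check of the table above (assuming the throughput column is prompt_num / gen_time from the test script, i.e. requests per second):

# Sanity check on the reported numbers.
requests = 3000
t_shuffle = requests / 3.93   # ~763 s end-to-end with random order
t_sorted = requests / 9.21    # ~326 s end-to-end with prefix sorting
reduction = 1 - 3.93 / 9.21   # ~0.57, the ~57% drop cited in item 1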

This is the first part of the BatchLLM optimization, focusing on request sorting only. Support for more complex prefix sharing patterns will be addressed in a separate PR.

Important Notes:

This optimization is currently only recommended when chunked prefill is enabled. With the current FlashInfer Cascade implementation in the default mode, prefix clustering can actually lead to a ~20% performance degradation. To achieve optimal performance across all modes, please refer to our original BatchLLM implementation in PR #12641.

The current PR serves as an initial step towards the full BatchLLM optimization, focusing on request sorting only. Support for more complex prefix-sharing patterns involving the scheduler, model runner, and kernels will be addressed in separate PRs.


Co-authored-by: xinji1 <xinji1@microsoft.com>
Co-authored-by: Fanghao Zhou <fanghaozhou@microsoft.com>
Co-authored-by: Zhen Zheng <zhengzhen@microsoft.com>
Signed-off-by: Taosong Fang <constfrost@foxmail.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the frontend label Feb 24, 2025

mergify bot commented Feb 24, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fangtaosong.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 24, 2025
@fangtaosong
Author

cc @WoosukKwon @comaniac for the next step.

@mergify mergify bot removed the needs-rebase label Feb 24, 2025

mergify bot commented Feb 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fangtaosong.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 25, 2025
@mergify mergify bot removed the needs-rebase label Feb 25, 2025
Modify some code comments.
@fangtaosong fangtaosong changed the title [Frontend] Add prefix sorting as a precursor to BatchLLM optimiza… [Frontend] Add prefix sorting as a precursor to BatchLLM optimization Feb 27, 2025
@OpenDarrenlu

@fangtaosong Hello, I noticed that you haven't maintained this PR recently. Maybe you are busy. I am interested in BatchLLM and would like to continue working on this PR. Is that ok?

@fangtaosong
Author

fangtaosong commented Mar 11, 2025

@fangtaosong Hello, I noticed that you haven't maintained this PR recently. Maybe you are busy. I am interested in BatchLLM and would like to continue working on this PR. Is that ok?

Thank you for your interest in continuing this PR and BatchLLM. Please go ahead and work on it; I'll be happy to see the progress.

You can find more information in PR #12641. I also highly recommend sharing your work plan (e.g., adding new modules or performing a rebase) publicly or privately.


mergify bot commented Mar 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fangtaosong.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 12, 2025

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale Over 90 days of inactivity label Jun 11, 2025