
[Frontend] Add prefix sorting as a precursor to BatchLLM optimiza… #13740


Closed

Conversation


@fangtaosong fangtaosong commented Feb 24, 2025

This PR implements request sorting to maximize prefix reuse when both chunked prefill and prefix caching are enabled. This serves as the first step towards the full BatchLLM optimization proposed in our RFC.

Motivation:

Currently, vLLM performs implicit (just-in-time) shared-prefix identification and metadata collection, and then applies cascade attention when all requests share a single common prefix, as described in PR #11635. However, as suggested by WoosukKwon, this approach does not fully exploit shared prefixes in offline scenarios where many requests have different shared prefixes.

In offline settings, all requests are available before inference begins, making implicit prefix identification suboptimal. By explicitly sorting requests based on their shared prefixes, we can better maximize prefix reuse, improve KV-cache management, and significantly enhance throughput for batched requests.

Changes:

  • Add an --enable-prefix-sorting flag to control prefix sorting
  • Implement prefix sorting with Python's built-in sort() function in the LLM API layer (see the sketch after this list)
  • Enable sorting only when both --enable-chunked-prefill and --enable-prefix-caching are enabled
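
To illustrate the second bullet, here is a minimal sketch of what sorting by shared prefix could look like. This is not the actual vLLM code; the helper name sort_requests_by_prefix and the standalone example are hypothetical.

# Minimal sketch (not the vLLM implementation): Python's built-in sort()
# compares lists element-wise (lexicographically), so sorting requests by
# their token IDs places prompts that share a common prefix next to each
# other, which keeps their KV-cache blocks warm between consecutive prefills.
from typing import List


def sort_requests_by_prefix(prompt_token_ids: List[List[int]]) -> List[int]:
    """Return request indices reordered so that shared prefixes are adjacent."""
    indices = list(range(len(prompt_token_ids)))
    indices.sort(key=lambda i: prompt_token_ids[i])
    return indices


if __name__ == "__main__":
    prompts = [
        [1, 2, 3, 9],  # shares prefix [1, 2, 3] with the third prompt
        [7, 8, 4],
        [1, 2, 3, 5],  # shares prefix [1, 2, 3] with the first prompt
    ]
    print(sort_requests_by_prefix(prompts))  # -> [2, 0, 1]

Sorting whole requests up front (rather than detecting shared prefixes on the fly) is only meaningful offline, where all prompts are known before inference begins.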

Performance improvement:

Test setup:

  • Model: Llama-3.1-8B
  • 3000 requests, each with a 3000-token shared context and a 50-token unique prompt; with the script's default --share-degree of 8, this yields 375 groups of 8 requests sharing the same context
  • 100 generated tokens per request (the script's default --generate-len)
  • VLLM_USE_V1=1 with the FLASHINFER attention backend

Test Script:

# SPDX-License-Identifier: Apache-2.0

import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import time
import datetime
import argparse
import random


def parse_args():
    parser = argparse.ArgumentParser(description='vLLM performance test')
    parser.add_argument('--model-path', type=str, default='/path/to/Llama-3.1-8B-Instruct')
    parser.add_argument('--request-num', type=int, default=800)
    parser.add_argument('--context-len', type=int, default=2000)
    parser.add_argument('--prompt-len', type=int, default=200)
    parser.add_argument('--generate-len', type=int, default=100)
    parser.add_argument('--enable-chunked-prefill', action='store_true')
    parser.add_argument('--share-degree', type=int, default=8)
    parser.add_argument('--shuffle', action='store_true')
    parser.add_argument('--use-original-str', action='store_true')
    parser.add_argument('--enable-prefix-sorting', action='store_true',
                        help='enable_prefix_sorting')
    return parser.parse_args()


def prepare_token_ids(context_len, prompt_len, group_num, batch_size):
    share, all_t, group_idx = [], [], []
    for i in range(group_num):
        context_this_group = torch.randint(1, 20000, (context_len,))
        share.append(context_this_group.tolist())
        for _ in range(batch_size):
            prompt_this_request = torch.randint(1, 20000, (prompt_len,))
            all_t.append(torch.concat((context_this_group[0:context_len],
                                       prompt_this_request[0:prompt_len]), 0).tolist())
            group_idx.append(i)

    return all_t, group_num, group_num * batch_size, share


def prepare_tokens(context_len, prompt_len, group_num, batch_size):
    share, all_t, group_idx = [], [], []
    for i in range(group_num):
        # Random lowercase string as the shared context for this group.
        context_this_group = "".join([random.choice('abcdefghijklmnopqrstuvwxyz') for _ in range(context_len)])
        for _ in range(batch_size):
            prompt_this_request = "".join([random.choice('abcdefghijklmnopqrstuvwxyz') for _ in range(prompt_len)])
            all_t.append(context_this_group + prompt_this_request)
            group_idx.append(i)
            share.append(context_this_group)

    return all_t, group_num, group_num * batch_size, share

def run_pipeline(input_tokens, input_prompts, pipe, max_tokens=100):
    sampling_params = SamplingParams(temperature=0.01, top_p=0.1,
                                     max_tokens=max_tokens)
    t1 = time.time()
    # Exactly one of input_tokens / input_prompts is expected to be provided.
    assert input_tokens is None or input_prompts is None
    output = pipe.generate(prompts=input_prompts,
                           prompt_token_ids=input_tokens,
                           sampling_params=sampling_params)
    return output, time.time() - t1


def main():
    args = parse_args()
    model_name = args.model_path.split('/')[-2] if args.model_path.endswith('/') else args.model_path.split('/')[-1]
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if args.enable_chunked_prefill:
        max_num_batched_tokens = 4096
        pipe = LLM(model=args.model_path, enable_prefix_caching=True,
                   enable_chunked_prefill=args.enable_chunked_prefill,
                   max_num_batched_tokens=max_num_batched_tokens,
                   enable_prefix_sorting=args.enable_prefix_sorting)
    else:
        pipe = LLM(model=args.model_path, enable_prefix_caching=True)

    print(f'Model: {model_name}')
    print(f'Context length: {args.context_len + args.prompt_len}, Output length: {args.generate_len}')
    print(f'Time: {datetime.datetime.now()} vLLM performance test')
    print(f'Engine path: {args.model_path}')
    print(f'Chunked prefill: {args.enable_chunked_prefill}')


    group_num = args.request_num // args.share_degree
    if not args.use_original_str:
        input_tokens, group_num, prompt_num, share_tokens = prepare_token_ids(
            args.context_len, args.prompt_len, group_num,
            args.share_degree
        )
        input_prompt = None
    else:
        input_prompt, group_num, prompt_num, share_tokens = prepare_tokens(
            args.context_len, args.prompt_len, group_num,
            args.share_degree
        )
        input_tokens = None
    if args.shuffle:
        print('Shuffling input tokens')
        random.shuffle(input_tokens if input_tokens is not None else input_prompt)


    final_output, gen_time = run_pipeline(input_tokens, input_prompt, pipe, args.generate_len)
    print('\ngroup_num prompt_num   time   throughput')
    print(
        f'{group_num:9} {prompt_num:10} {gen_time:8.2f} {prompt_num / gen_time:10.2f} ')


if __name__ == "__main__":
    main()

Test Commands:

# Baseline with shuffle (random order)
VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER python test_script.py \
    --model-path /path/to/Llama-3.1-8B-Instruct \
    --enable-chunked-prefill \
    --request-num 3000 \
    --context-len 3000 \
    --prompt-len 50 \
    --shuffle

# With prefix sorting enabled
VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER python test_script.py \
    --model-path /path/to/Llama-3.1-8B-Instruct \
    --enable-chunked-prefill \
    --request-num 3000 \
    --context-len 3000 \
    --prompt-len 50 \
    --shuffle \
    --enable-prefix-sorting

Results:

Configuration              Throughput (requests/s)
Shuffle (random order)     3.93
Shuffle + prefix sorting   9.21

The results show that:

  1. Shuffling the requests into a random order reduces throughput by ~57% relative to the prefix-sorted run (3.93 vs. 9.21 requests/s; see the quick check after this list).
  2. Prefix sorting restores throughput to the baseline level of the original, grouped request order.
  3. For datasets with more prefix-sharing potential, sorting could provide even larger improvements.
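
As a quick check of the ~57% figure in item 1, the relative drop of the shuffled run versus the prefix-sorted run works out as follows (plain arithmetic on the table above, not an extra benchmark):

# Relative throughput drop of the shuffled run vs. the prefix-sorted run.
shuffled = 3.93      # requests/s, shuffle only
sorted_run = 9.21    # requests/s, shuffle + prefix sorting
print(f"{(sorted_run - shuffled) / sorted_run:.1%}")  # -> 57.3%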

This is the first part of the BatchLLM optimization, focusing on request sorting only. Support for more complex prefix sharing patterns will be addressed in a separate PR.

Important Notes:

This optimization is currently recommended only when chunked prefill is enabled: with the current FlashInfer Cascade implementation in the default mode, prefix clustering can actually lead to a ~20% performance degradation. To achieve optimal performance across all modes, please refer to our original BatchLLM implementation in PR #12641.

The current PR serves as an initial step towards the full BatchLLM optimization, focusing on request sorting only. Support for more complex prefix-sharing patterns in the Scheduler, ModelRunner, and kernels will be addressed in separate PRs.

[Frontend] Add prefix sorting as a precursor to BatchLLM optimization

Co-authored-by: xinji1 <xinji1@microsoft.com>
Co-authored-by: Fanghao Zhou <fanghaozhou@microsoft.com>
Co-authored-by: Zhen Zheng <zhengzhen@microsoft.com>
Signed-off-by: Taosong Fang <constfrost@foxmail.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the frontend label Feb 24, 2025
@fangtaosong (Author) commented:

Please take a look at the RFC #12080 for more details.

cc @WoosukKwon @comaniac for the next step.

@fangtaosong fangtaosong changed the title [Optimization] Add prefix sorting as a precursor to BatchLLM optimiza… [Frontend] Add prefix sorting as a precursor to BatchLLM optimiza… Feb 24, 2025
@fangtaosong fangtaosong deleted the batchllm-pr1 branch February 24, 2025 04:58