
[Misc] Add coding benchmark for speculative decoding #15303


Merged: 10 commits into vllm-project:main on Mar 28, 2025

Conversation

@CXIAAAAA (Contributor) commented Mar 21, 2025

Add likaixin/InstructCoder as a dataset for the speculative decoding throughput benchmark.

To run the InstructCoder benchmark:

VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_USE_V1=1 python3 benchmarks/benchmark_throughput.py --dataset-name=hf --dataset-path=likaixin/InstructCoder --model <your hf model> --input-len 1000 --output-len 100 --num-prompts 2048 --async-engine

To run the random benchmark:

VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_USE_V1=1 python3 benchmarks/benchmark_throughput.py --dataset-name=random --model <your hf model> --input-len 1000 --output-len 100 --num-prompts 2048 --async-engine

baseline:
Throughput: 37.29 requests/s, 12226.80 total tokens/s, 7457.64 output tokens/s

ngram proposer: --speculative-model "[ngram]" --ngram_prompt_lookup_min 2 --ngram-prompt-lookup-max 5 --num_speculative_tokens 5
Throughput: 35.28 requests/s, 11569.61 total tokens/s, 7056.79 output tokens/s
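
For context, here is a minimal sketch of what a prompt-lookup n-gram proposer does with these flags; this is an illustration, not vLLM's actual implementation. The trailing n-gram of the generated text (n from ngram_prompt_lookup_max down to ngram_prompt_lookup_min) is matched against earlier tokens in the context, and on a hit the tokens that previously followed that n-gram are proposed as up to num_speculative_tokens draft tokens.

```python
# Hedged sketch of prompt-lookup n-gram drafting; not vLLM's actual code.
from typing import Optional


def propose_ngram_draft(
    token_ids: list[int],
    ngram_min: int = 2,
    ngram_max: int = 5,
    num_speculative_tokens: int = 5,
) -> Optional[list[int]]:
    """Propose draft tokens by matching the trailing n-gram earlier in the context."""
    for n in range(ngram_max, ngram_min - 1, -1):
        if len(token_ids) < n + 1:
            continue
        suffix = token_ids[-n:]
        # Search backwards from the most recent possible match position.
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == suffix:
                draft = token_ids[start + n:start + n + num_speculative_tokens]
                if draft:
                    return draft
    return None  # No repeated n-gram: fall back to normal decoding this step.


# The bigram (3, 4) occurred earlier, so the tokens that followed it are drafted.
print(propose_ngram_draft([1, 2, 3, 4, 5, 6, 7, 3, 4], num_speculative_tokens=3))  # [5, 6, 7]
```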

benchmark_serving:
server: VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct (--speculative-model "[ngram]" --ngram_prompt_lookup_min 2 --ngram-prompt-lookup-max 5 --num_speculative_tokens 5)
client: python3 benchmarks/benchmark_serving.py --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name hf --dataset-path likaixin/InstructCoder --num-prompts 2048

baseline: (benchmark_serving results screenshot)

ngram proposer: (benchmark_serving results screenshot)


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which executes a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@youran-qi left a comment

Thank you and LGTM!

For your reference, here are statistics on the input/output lengths (number of tokens) in this dataset. Not sure whether you want to change DEFAULT_OUTPUT_LEN accordingly:

|                     | avg | min | max  |
|---------------------|-----|-----|------|
| instruction + input | 151 | 15  | 837  |
| output              | 179 | 9   | 1317 |
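
If anyone wants to reproduce these numbers, a rough sketch follows. It assumes the dataset's train split exposes instruction, input, and output fields (as suggested by the table above) and uses the Llama-3 tokenizer from the benchmark commands; exact figures may differ with a different tokenizer or split.

```python
# Rough sketch for reproducing the length statistics above; field names and
# split are assumptions based on the table and the benchmark commands.
from datasets import load_dataset
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
ds = load_dataset("likaixin/InstructCoder", split="train")


def stats(lengths: list[int]) -> tuple[float, int, int]:
    return sum(lengths) / len(lengths), min(lengths), max(lengths)


prompt_lens = [len(tok(ex["instruction"] + ex["input"]).input_ids) for ex in ds]
output_lens = [len(tok(ex["output"]).input_ids) for ex in ds]

print("instruction + input (avg, min, max):", stats(prompt_lens))
print("output (avg, min, max):", stats(output_lens))
```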

@LiuXiaoxuanPKU (Collaborator) left a comment

LGTM, @ywang96 please also take a look in case I miss anything!

Also, I'm just wondering if it's possible to share some simple benchmark results on the InstructCoder dataset. Really appreciate it!

@CXIAAAAA (Contributor, Author) commented Mar 25, 2025

@LiuXiaoxuanPKU I updated the benchmark, but it looks like the throughput regressed a bit. I also shared my commands; do you think we should benchmark at a low batch size, or does this look good to you?

@CXIAAAAA requested a review from LiuXiaoxuanPKU on March 25, 2025 06:08
@ywang96 (Member) left a comment

LGTM! I left some nits.

Can you also share an example benchmark command with benchmark_serving.py and its result? Thanks!

@CXIAAAAA (Contributor, Author)

@ywang96 Posted the command and the results in the PR description.

@CXIAAAAA (Contributor, Author)

@JenZhao updated the code

@CXIAAAAA requested a review from JenZhao on March 27, 2025 21:35
@ywang96 (Member) left a comment

LGTM! Thanks for the contribution!

Can you fix the pre-commit errors so we can merge it?

@JenZhao (Contributor) left a comment

thank you!

@ywang96 added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Mar 27, 2025
CXIAAAAA added 10 commits March 28, 2025 01:11
Signed-off-by: CXIAAAAA <cxia0209@gmail.com> (on each of the 10 commits)
@CXIAAAAA force-pushed the cxia_add_coder_benchmark branch from 93aa6d3 to a6e95c7 on March 28, 2025 01:11
@CXIAAAAA (Contributor, Author)

@ywang96 Looks like it is ready to merge.

@DarkLight1337 merged commit e7f720e into vllm-project:main on Mar 28, 2025
15 of 16 checks passed
@lihuahua123 (Contributor)

The ngram result seems worse, possibly because you did not use --speculative-disable-by-batch-size?

Alex4210987 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Apr 5, 2025
Signed-off-by: CXIAAAAA <cxia0209@gmail.com>
Signed-off-by: xinyuxiao <xinyuxiao2024@gmail.com>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: CXIAAAAA <cxia0209@gmail.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
```python
        dataset_class = VisionArenaDataset
    elif args.dataset_path == "likaixin/InstructCoder":
        dataset_class = InstructCoderDataset
        args.hf_split = "train"
```
A Contributor commented on this diff:

Curious: why was this hardcoded to train rather than parameterized by args.hf_split as done in the other cases? If we are evaluating a model, shouldn't it be tested on a non-train split?

@CXIAAAAA (Contributor, Author) left a comment

It is hardcoded for now; sure, we could add it to args.hf_split.

As for the train/eval split, everyone's method of splitting might be slightly different, so this is just an example.
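
If it helps, here is a hypothetical sketch of the suggested parameterization (argument names mirror the snippet above; this is not the code in this PR): keep train as the default for InstructCoder but let an explicit --hf-split override it.

```python
# Hypothetical sketch: respect an explicit --hf-split instead of hardcoding "train".
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, default=None)
parser.add_argument("--hf-split", type=str, default=None)
args = parser.parse_args()

if args.dataset_path == "likaixin/InstructCoder":
    # "train" stays the default, but a user-supplied split wins.
    args.hf_split = args.hf_split or "train"
```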

@CXIAAAAA (Contributor, Author)

> The ngram result seems worse, possibly because you did not use --speculative-disable-by-batch-size?

The default batch size is too large; normally we measure at a low batch size, e.g. 1 or 32.
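
For example, an online run at effectively batch size 1 could look like the command below, assuming your benchmark_serving.py build supports --max-concurrency; otherwise a low --request-rate has a similar effect:

python3 benchmarks/benchmark_serving.py --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name hf --dataset-path likaixin/InstructCoder --num-prompts 256 --max-concurrency 1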

lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: CXIAAAAA <cxia0209@gmail.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
@ekagra-ranjan (Contributor)

> but it looks like the throughput regressed a bit. I also shared my commands; do you think we should benchmark at a low batch size, or does this look good to you?

Yup, the high batch size in this benchmark was the reason it regressed. This PR has the latest speedup: #18971
