Run v1 benchmark and integrate with PyTorch OSS benchmark database #13068

Merged · 19 commits · Feb 17, 2025
Changes from 1 commit
Run v1 benchmark
Signed-off-by: Huy Do <huydhn@gmail.com>
huydhn committed Feb 11, 2025
commit ce79bc5e514d4e3c143dc35a675fdcaa3f7c760e
@@ -345,6 +345,15 @@ main() {
check_gpus
check_hf_token

# Set to v1 to run the v1 benchmark
VLLM_VERSION=$1
if [[ "${VLLM_VERSION:-v0}" == "v1" ]]; then
export VLLM_USE_V1=1
fi

# Set to 0 to run the benchmark script locally without uploading to Buildkite
UPLOAD_TO_BUILDKITE=$2

# dependencies
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
(which jq) || (apt-get update && apt-get -y install jq)
@@ -371,7 +380,9 @@ main() {
pip install tabulate pandas
python3 $QUICK_BENCHMARK_ROOT/scripts/convert-results-json-to-markdown.py

upload_to_buildkite
if [[ "${UPLOAD_TO_BUILDKITE:-1}" == "1" ]]; then
upload_to_buildkite
fi
}

main "$@"
2 changes: 1 addition & 1 deletion .buildkite/nightly-benchmarks/tests/latency-tests.json
@@ -29,4 +29,4 @@
"num-iters": 15
}
}
]
]
147 changes: 98 additions & 49 deletions benchmarks/benchmark_latency.py
@@ -1,11 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
"""Benchmark the latency of processing a single batch of requests."""

import argparse
import dataclasses
import json
import os
import time
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Dict, Any

import numpy as np
import torch
@@ -18,6 +20,38 @@
from vllm.utils import FlexibleArgumentParser


def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: Dict[str, Any]
) -> None:
    # Record schema:
    # https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
    record = {
        "benchmark": {
            "name": "vLLM benchmark",
            "extra_info": {
                # argparse.Namespace is not JSON-serializable; store a plain dict
                "args": vars(args),
            },
        },
        "model": {
            "name": args.model,
        },
        "metric": {
            "name": "latency",
            "benchmark_values": results.get("latencies", []),
            "extra_info": {
                "avg_latency": results.get("avg_latency", 0),
                "percentiles": results.get("percentiles", {}),
            },
        },
    }

    # Only write the extra file when explicitly requested via the environment.
    if os.environ.get("SAVE_IN_PYTORCH_BENCHMARK_FORMAT", False):
        output_file = (
            f"{os.path.splitext(args.output_json)[0]}_pytorch_format.json"
        )
        with open(output_file, "w") as f:
            json.dump(record, f)


def main(args: argparse.Namespace):
print(args)

@@ -35,36 +69,39 @@ def main(args: argparse.Namespace):
max_tokens=args.output_len,
)
print(sampling_params)
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_prompts: List[PromptType] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
dummy_prompt_token_ids = np.random.randint(
10000, size=(args.batch_size, args.input_len)
)
dummy_prompts: List[PromptType] = [
{"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
]

def llm_generate():
if not args.use_beam_search:
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
llm.generate(
dummy_prompts, sampling_params=sampling_params, use_tqdm=False
)
else:
llm.beam_search(
dummy_prompts,
BeamSearchParams(
beam_width=args.n,
max_tokens=args.output_len,
ignore_eos=True,
))
),
)

def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir:
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir)
),
) as p:
llm_generate()
print(p.key_averages().table(sort_by="self_cuda_time_total"))
else:
@@ -81,9 +118,11 @@ def run_to_completion(profile_dir: Optional[str] = None):
if args.profile:
profile_dir = args.profile_result_dir
if not profile_dir:
profile_dir = Path(
"."
) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
profile_dir = (
Path(".")
/ "vllm_benchmark_result"
/ f"latency_result_{time.time()}"
)
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir)
return
@@ -95,9 +134,9 @@ def run_to_completion(profile_dir: Optional[str] = None):
latencies = np.array(latencies)
percentages = [10, 25, 50, 75, 90, 99]
percentiles = np.percentile(latencies, percentages)
print(f'Avg latency: {np.mean(latencies)} seconds')
print(f"Avg latency: {np.mean(latencies)} seconds")
for percentage, percentile in zip(percentages, percentiles):
print(f'{percentage}% percentile latency: {percentile} seconds')
print(f"{percentage}% percentile latency: {percentile} seconds")

# Output JSON results if specified
if args.output_json:
@@ -108,43 +147,53 @@ def run_to_completion(profile_dir: Optional[str] = None):
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)


if __name__ == '__main__':
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--n',
type=int,
default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters-warmup',
type=int,
default=10,
help='Number of iterations to run for warmup.')
parser.add_argument('--num-iters',
type=int,
default=30,
help='Number of iterations to run.')
description="Benchmark the latency of processing a single batch of "
"requests till completion."
)
parser.add_argument("--input-len", type=int, default=32)
parser.add_argument("--output-len", type=int, default=128)
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument(
"--n",
type=int,
default=1,
help="Number of generated sequences per prompt.",
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
'--profile',
action='store_true',
help='profile the generation process of a single batch')
"--num-iters-warmup",
type=int,
default=10,
help="Number of iterations to run for warmup.",
)
parser.add_argument(
"--num-iters", type=int, default=30, help="Number of iterations to run."
)
parser.add_argument(
'--profile-result-dir',
"--profile",
action="store_true",
help="profile the generation process of a single batch",
)
parser.add_argument(
"--profile-result-dir",
type=str,
default=None,
help=('path to save the pytorch profiler output. Can be visualized '
'with ui.perfetto.dev or Tensorboard.'))
help=(
"path to save the pytorch profiler output. Can be visualized "
"with ui.perfetto.dev or Tensorboard."
),
)
parser.add_argument(
'--output-json',
"--output-json",
type=str,
default=None,
help='Path to save the latency results in JSON format.')
help="Path to save the latency results in JSON format.",
)

parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
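
Putting the pieces together, a hedged usage sketch for the latency benchmark: the model name is only an example, and SAVE_IN_PYTORCH_BENCHMARK_FORMAT is the opt-in switch added in this commit.

# Run the latency benchmark on the v1 engine, write the usual JSON results,
# and additionally emit the PyTorch OSS benchmark database record; with the
# output path below it lands next to the results as results_pytorch_format.json.
VLLM_USE_V1=1 SAVE_IN_PYTORCH_BENCHMARK_FORMAT=1 \
python3 benchmarks/benchmark_latency.py \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --input-len 32 --output-len 128 --batch-size 8 \
    --num-iters-warmup 10 --num-iters 30 \
    --output-json results.json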