[Feature] Update benchmark_throughput.py to support image input #9851
benchmarks/benchmark_throughput.py
@@ -4,27 +4,66 @@
 import json
 import random
 import time
-from typing import List, Optional, Tuple
+from typing import List, Optional

 import torch
 import uvloop
+from PIL import Image
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)

 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
+from vllm.inputs import TextPrompt
+from vllm.multimodal import MultiModalDataDict
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser, merge_async_iterators


-def sample_requests(
-    dataset_path: str,
-    num_requests: int,
-    tokenizer: PreTrainedTokenizerBase,
-    fixed_output_len: Optional[int],
-) -> List[Tuple[str, int, int]]:
+@dataclasses.dataclass
+class SampleRequest:
+    """A class representing a single inference request for benchmarking.
+    Attributes:
+        prompt: The input text prompt for the model.
+        multi_modal_data: Optional dictionary containing multi-modal data (e.g.
+            images).
+        prompt_len: The length of the prompt in tokens.
+        expected_output_len: The expected length of the output in tokens.
+    """
+    prompt: str
+    prompt_len: int
+    expected_output_len: int
+    multi_modal_data: Optional[MultiModalDataDict] = None
+
+
+def _get_prompt_for_image_model(question: str, *, model: str) -> str:
+    """Prepend and append special tokens around the question to form a prompt.
+    Args:
+        question: The input question text to wrap with special tokens
+        model: The name of the model being used, to determine which special tokens to add
Check failure on line 47 in benchmarks/benchmark_throughput.py — GitHub Actions / ruff: E501 (line too long)
+
+    Returns:
+        The formatted prompt string with appropriate special tokens for the model
Check failure on line 50 in benchmarks/benchmark_throughput.py — GitHub Actions / ruff: E501 (line too long)
+
+    Raises:
+        ValueError: If an unsupported model name is provided
+    """
+    model = model.lower()
+    if "pixtral" in model:
+        return f"<s>[INST]{question}\n[IMG][/INST]"
+    raise ValueError(f"Unsupported model {model}")
Review comment on lines +55 to +58: This is fine for now, but as a follow-up we could consider leveraging the HF tokenizer's chat template.
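As a rough sketch of that follow-up (not part of this diff): the model-specific prompt construction could be delegated to the tokenizer's chat template where one exists. The helper name below is hypothetical, and whether the template inserts the image placeholder itself depends on the model.

# Hypothetical alternative to _get_prompt_for_image_model, assuming the model's
# tokenizer ships a chat template that renders its own special tokens.
from transformers import AutoTokenizer


def get_prompt_via_chat_template(question: str, *, model: str) -> str:
    tokenizer = AutoTokenizer.from_pretrained(model)
    messages = [{"role": "user", "content": question}]
    # tokenize=False returns the rendered prompt string instead of token ids;
    # add_generation_prompt=True appends the assistant-turn header.
    return tokenizer.apply_chat_template(messages,
                                         tokenize=False,
                                         add_generation_prompt=True)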
+
+
+def sample_requests(tokenizer: PreTrainedTokenizerBase,
+                    args: argparse.Namespace) -> List[SampleRequest]:
+    dataset_path: str = args.dataset
+    num_requests: int = args.num_prompts
+    fixed_output_len: Optional[int] = args.output_len
+    model: str = args.model
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")

@@ -33,58 +72,76 @@
     dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-    # Only keep the first two turns of each conversation.
-    dataset = [(data["conversations"][0]["value"],
-                data["conversations"][1]["value"]) for data in dataset]

     # Shuffle the dataset.
     random.shuffle(dataset)

     # Filter out sequences that are too long or too short
-    filtered_dataset: List[Tuple[str, int, int]] = []
-    for i in range(len(dataset)):
+    filtered_dataset: List[SampleRequest] = []
+    for data in dataset:
         if len(filtered_dataset) == num_requests:
             break

+        # Only keep the first two turns of each conversation.
+        prompt = data["conversations"][0]["value"]
+        completion = data["conversations"][1]["value"]
+
+        multi_modal_data: Optional[MultiModalDataDict] = None
+        if "image" in data:
+            multi_modal_data = multi_modal_data or {}
+            image_path = data["image"]
+            assert isinstance(image_path,
+                              str), "Only support single image input"
Review comment on lines +92 to +93: We could make a TODO to support multi-image inputs.
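To make that TODO concrete, one possible shape for multi-image support, sketched under the assumption that the dataset's "image" field may hold either a single path or a list of paths and that the target model accepts a list of images in multi_modal_data:

# Sketch only: load one or many images for a dataset entry; the helper and the
# list-valued "image" field are assumptions, not part of this diff.
from typing import Any, Dict, List, Optional, Union

from PIL import Image


def load_entry_images(
        data: Dict[str, Any]
) -> Optional[Union[Image.Image, List[Image.Image]]]:
    paths = data.get("image")
    if paths is None:
        return None
    if isinstance(paths, str):  # single-image entry, as supported today
        paths = [paths]
    images = [Image.open(p).convert("RGB") for p in paths]
    return images if len(images) > 1 else images[0]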
+            try:
+                multi_modal_data["image"] = Image.open(image_path).convert(
+                    "RGB")
+            except FileNotFoundError:
+                # Ignore datapoint where asset is missing
+                continue
+            prompt = _get_prompt_for_image_model(question=prompt, model=model)
+
         # Tokenize the prompts and completions.
-        prompt = dataset[i][0]
         prompt_token_ids = tokenizer(prompt).input_ids
-        completion = dataset[i][1]
         completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
         output_len = len(completion_token_ids
                          ) if fixed_output_len is None else fixed_output_len
         if prompt_len < 4 or output_len < 4:
             # Prune too short sequences.
             continue
         if prompt_len > 1024 or prompt_len + output_len > 2048:
             # Prune too long sequences.
             continue
-        filtered_dataset.append((prompt, prompt_len, output_len))
+        filtered_dataset.append(
+            SampleRequest(prompt=prompt,
+                          prompt_len=prompt_len,
+                          expected_output_len=output_len,
+                          multi_modal_data=multi_modal_data))

     return filtered_dataset

 def run_vllm(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     n: int,
     engine_args: EngineArgs,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(**dataclasses.asdict(engine_args))

     # Add the requests to the engine.
-    prompts: List[str] = []
+    prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
-    for prompt, _, output_len in requests:
-        prompts.append(prompt)
+    for request in requests:
+        prompts.append(
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
                 temperature=1.0,
                 top_p=1.0,
                 ignore_eos=True,
-                max_tokens=output_len,
+                max_tokens=request.expected_output_len,
             ))

     use_beam_search = False
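For context, a minimal standalone sketch of the request shape this loop now submits, outside the benchmark harness; the model name, image path, and engine flags below are placeholders, not taken from this PR:

# Assumes a Pixtral-style checkpoint and a local example.jpg; both are made up.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Pixtral-12B-2409", tokenizer_mode="mistral")
image = Image.open("example.jpg").convert("RGB")
outputs = llm.generate(
    {
        "prompt": "<s>[INST]Describe the image.\n[IMG][/INST]",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=1.0, top_p=1.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)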
@@ -94,11 +151,11 @@
         llm.generate(prompts, sampling_params, use_tqdm=True)
         end = time.perf_counter()
     else:
-        prompts = [prompt for prompt, _, _ in requests]
+        prompts = [request.prompt for request in requests]
         # output_len should be the same for all requests.
         output_len = requests[0][2]
-        for prompt, input_len, _output_len in requests:
-            assert _output_len == output_len
+        for request in requests:
+            assert request.expected_output_len == output_len
         start = time.perf_counter()
         llm.beam_search(
             prompts,
@@ -112,7 +169,7 @@


 async def run_vllm_async(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     n: int,
     engine_args: AsyncEngineArgs,
     disable_frontend_multiprocessing: bool = False,
@@ -123,17 +180,19 @@
             engine_args, disable_frontend_multiprocessing) as llm:

         # Add the requests to the engine.
-        prompts: List[str] = []
+        prompts: List[TextPrompt] = []
         sampling_params: List[SamplingParams] = []
-        for prompt, _, output_len in requests:
-            prompts.append(prompt)
+        for request in requests:
+            prompts.append(
+                TextPrompt(prompt=request.prompt,
+                           multi_modal_data=request.multi_modal_data))
             sampling_params.append(
                 SamplingParams(
                     n=n,
                     temperature=1.0,
                     top_p=1.0,
                     ignore_eos=True,
-                    max_tokens=output_len,
+                    max_tokens=request.expected_output_len,
                 ))

         generators = []
@@ -149,7 +208,7 @@


 def run_hf(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     model: str,
     tokenizer: PreTrainedTokenizerBase,
     n: int,
@@ -207,14 +266,14 @@


 def run_mii(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     model: str,
     tensor_parallel_size: int,
     output_len: int,
 ) -> float:
     from mii import client, serve
     llm = serve(model, tensor_parallel=tensor_parallel_size)
-    prompts = [prompt for prompt, _, _ in requests]
+    prompts = [request.prompt for request in requests]

     start = time.perf_counter()
     llm.generate(prompts, max_new_tokens=output_len)
@@ -246,9 +305,10 @@
         requests = [(prompt, args.input_len, args.output_len)
                     for _ in range(args.num_prompts)]
     else:
-        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
-                                   args.output_len)
+        requests = sample_requests(tokenizer, args)

+    is_multi_model = any(
+        request.multi_modal_data is not None for request in requests)
     if args.backend == "vllm":
         if args.async_engine:
             elapsed_time = uvloop.run(
@@ -270,9 +330,14 @@
                               args.output_len)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
-    total_num_tokens = sum(prompt_len + output_len
-                           for _, prompt_len, output_len in requests)
-    total_output_tokens = sum(output_len for _, _, output_len in requests)
+    total_num_tokens = sum(request.prompt_len + request.expected_output_len
+                           for request in requests)
+    total_output_tokens = sum(
+        request.expected_output_len for request in requests)
+    if is_multi_model:
+        print("\033[91mWARNING\033[0m: Multi-modal request detected. "
+              "The following metrics is not accurate.")
+    # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
           f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
@@ -299,7 +364,9 @@
parser.add_argument("--dataset", | ||||||||||
type=str, | ||||||||||
default=None, | ||||||||||
help="Path to the dataset.") | ||||||||||
help="Path to the dataset. The dataset is expected to " | ||||||||||
"be a json in form of List[Dict[..., conversations: " | ||||||||||
"List[Dict[..., value: <prompt_or_response>]]]]") | ||||||||||
parser.add_argument("--input-len", | ||||||||||
type=int, | ||||||||||
default=None, | ||||||||||
|
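For reference, a dataset file matching the help text above could look like the entry below, written out via Python for illustration. Only the "conversations", "value", and optional "image" keys are read by sample_requests; the "from" labels and file names are assumptions.

# Illustrative only: one ShareGPT-style entry with an optional local image path.
import json

example_dataset = [{
    "image": "images/example.jpg",  # optional; omit for text-only entries
    "conversations": [
        {"from": "human", "value": "What is shown in the picture?"},
        {"from": "gpt", "value": "A cat is sitting on a sofa."},
    ],
}]

with open("benchmark_dataset.json", "w") as f:
    json.dump(example_dataset, f, indent=2)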
Review comment: Please fix the formatting by running bash format.sh locally (you will probably also need to manually resolve the line-too-long errors in the comments).