[Feature] Update benchmark_throughput.py to support image input #9851

Open · wants to merge 9 commits into base: main
145 changes: 106 additions & 39 deletions benchmarks/benchmark_throughput.py
@@ -4,27 +4,66 @@
import json
import random
import time
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
import uvloop
from PIL import Image
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
@dataclasses.dataclass
class SampleRequest:
"""A class representing a single inference request for benchmarking.
Attributes:
prompt: The input text prompt for the model.
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
"""
prompt: str
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[MultiModalDataDict] = None

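For illustration, a single multi-modal request could be constructed as below. This is an editorial sketch, not part of the diff; the file name and token counts are placeholders.

from PIL import Image

# Sketch only: one multi-modal benchmark request. "example.jpg" and the
# token counts are illustrative placeholders.
request = SampleRequest(
    prompt="<s>[INST]What is in this image?\n[IMG][/INST]",
    prompt_len=16,
    expected_output_len=128,
    multi_modal_data={"image": Image.open("example.jpg").convert("RGB")},
)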

def _get_prompt_for_image_model(question: str, *, model: str) -> str:
"""Prepend and append special tokens around the question to form a prompt.
Args:
question: The input question text to wrap with special tokens
model: The name of the model being used, to determine which special tokens to add

Check failure on line 47 in benchmarks/benchmark_throughput.py — GitHub Actions / ruff (3.8, 3.10, 3.11, 3.12): Ruff (E501) Line too long (89 > 80)
Collaborator: Please fix the format by running bash format.sh locally (you will probably also need to manually resolve the line-too-long errors in comments).

Returns:
The formatted prompt string with appropriate special tokens for the model

Check failure on line 50 in benchmarks/benchmark_throughput.py — GitHub Actions / ruff (3.8, 3.10, 3.11, 3.12): Ruff (E501) Line too long (81 > 80)
Raises:
ValueError: If an unsupported model name is provided
"""
model = model.lower()
if "pixtral" in model:
return f"<s>[INST]{question}\n[IMG][/INST]"
raise ValueError(f"Unsupported model {model}")
Collaborator (on lines +55 to +58): This is fine for now, but as a follow-up we could consider leveraging the HF tokenizer's chat template.



def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]:
dataset_path: str = args.dataset
num_requests: int = args.num_prompts
fixed_output_len: Optional[int] = args.output_len
model: str = args.model
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

@@ -33,58 +72,76 @@
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]

# Shuffle the dataset.
random.shuffle(dataset)

# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
for i in range(len(dataset)):
filtered_dataset: List[SampleRequest] = []
for data in dataset:
if len(filtered_dataset) == num_requests:
break

# Only keep the first two turns of each conversation.
prompt = data["conversations"][0]["value"]
completion = data["conversations"][1]["value"]

multi_modal_data: Optional[MultiModalDataDict] = None
if "image" in data:
multi_modal_data = multi_modal_data or {}
image_path = data["image"]
assert isinstance(image_path,
str), "Only a single image input is supported"
Collaborator (on lines +92 to +93): We could add a TODO to support multi-image inputs.
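A minimal sketch of what that follow-up could look like, assuming the dataset's "image" field may also hold a list of paths (vLLM accepts a list of images per prompt when the engine is configured to allow it):

# Sketch only: accept either a single path or a list of paths.
image_paths = data["image"]
if isinstance(image_paths, str):
    image_paths = [image_paths]
images = [Image.open(p).convert("RGB") for p in image_paths]
multi_modal_data["image"] = images[0] if len(images) == 1 else images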

try:
multi_modal_data["image"] = Image.open(image_path).convert(
"RGB")
except FileNotFoundError:
# Ignore datapoint where asset is missing
continue
prompt = _get_prompt_for_image_model(question=prompt, model=model)

# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
filtered_dataset.append(
SampleRequest(prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=multi_modal_data))

return filtered_dataset


def run_vllm(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
n: int,
engine_args: EngineArgs,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))

# Add the requests to the engine.
prompts: List[str] = []
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
for request in requests:
prompts.append(
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
max_tokens=request.expected_output_len,
))

use_beam_search = False
@@ -94,11 +151,11 @@
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
prompts = [prompt for prompt, _, _ in requests]
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
for prompt, input_len, _output_len in requests:
assert _output_len == output_len
for request in requests:
assert request.expected_output_len == output_len
start = time.perf_counter()
llm.beam_search(
prompts,
@@ -112,7 +169,7 @@


async def run_vllm_async(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
n: int,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
@@ -123,17 +180,19 @@
engine_args, disable_frontend_multiprocessing) as llm:

# Add the requests to the engine.
prompts: List[str] = []
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
for request in requests:
prompts.append(
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
max_tokens=request.expected_output_len,
))

generators = []
@@ -149,7 +208,7 @@


def run_hf(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
@@ -207,14 +266,14 @@


def run_mii(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
model: str,
tensor_parallel_size: int,
output_len: int,
) -> float:
from mii import client, serve
llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [prompt for prompt, _, _ in requests]
prompts = [request.prompt for request in requests]

start = time.perf_counter()
llm.generate(prompts, max_new_tokens=output_len)
@@ -246,9 +305,10 @@
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
args.output_len)
requests = sample_requests(tokenizer, args)

is_multi_model = any(
request.multi_modal_data is not None for request in requests)

Collaborator (suggested change): rename is_multi_model to is_multi_modal.
if args.backend == "vllm":
if args.async_engine:
elapsed_time = uvloop.run(
Expand All @@ -270,9 +330,14 @@
args.output_len)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
total_output_tokens = sum(output_len for _, _, output_len in requests)
total_num_tokens = sum(request.prompt_len + request.expected_output_len
for request in requests)
total_output_tokens = sum(
request.expected_output_len for request in requests)
if is_multi_model:
print("\033[91mWARNING\033[0m: Multi-modal request detected. "
"The following metrics is not accurate.")
Collaborator (suggested change on lines +338 to +339): extend the warning to explain why the metrics are off, e.g.:
print("\033[91mWARNING\033[0m: Multi-modal request detected. "
"The following metrics are not accurate because image tokens "
"are not counted. See vllm-project/vllm/issues/9778 for details.")

# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
@@ -299,7 +364,9 @@
parser.add_argument("--dataset",
type=str,
default=None,
help="Path to the dataset.")
help="Path to the dataset. The dataset is expected to "
"be a json in form of List[Dict[..., conversations: "
"List[Dict[..., value: <prompt_or_response>]]]]")
parser.add_argument("--input-len",
type=int,
default=None,