
Commit 1b6de83

[Benchmark] Support sample from HF datasets and image input for benchmark_serving (#8495)
1 parent cbdb252 commit 1b6de83

2 files changed: +177 −68 lines

benchmarks/backend_request_func.py

+5 −1
@@ -25,6 +25,7 @@ class RequestFuncInput:
     best_of: int = 1
     use_beam_search: bool = False
     logprobs: Optional[int] = None
+    multi_modal_content: Optional[dict] = None


 @dataclass
@@ -312,12 +313,15 @@ async def async_request_openai_chat_completions(

     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         assert not request_func_input.use_beam_search
+        content = [{"type": "text", "text": request_func_input.prompt}]
+        if request_func_input.multi_modal_content:
+            content.append(request_func_input.multi_modal_content)
         payload = {
             "model": request_func_input.model,
             "messages": [
                 {
                     "role": "user",
-                    "content": request_func_input.prompt,
+                    "content": content
                 },
             ],
             "temperature": 0.0,

benchmarks/benchmark_serving.py

+172 −67
@@ -24,18 +24,22 @@
 """
 import argparse
 import asyncio
+import base64
+import io
 import json
 import os
 import random
 import time
 import warnings
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
+from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple

 import numpy as np
 from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
                                   RequestFuncOutput)
+from datasets import load_dataset
+from PIL.Image import Image
 from tqdm.asyncio import tqdm
 from transformers import PreTrainedTokenizerBase

@@ -84,7 +88,7 @@ def sample_sharegpt_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
-) -> List[Tuple[str, int, int]]:
+) -> List[Tuple[str, int, int, None]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
     # Load the dataset.
@@ -119,7 +123,7 @@ def sample_sharegpt_requests(
         if prompt_len > 1024 or prompt_len + output_len > 2048:
             # Prune too long sequences.
             continue
-        filtered_dataset.append((prompt, prompt_len, output_len))
+        filtered_dataset.append((prompt, prompt_len, output_len, None))

     return filtered_dataset
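Across all four samplers, each request tuple now carries a trailing multi-modal content slot: the text-only samplers (sharegpt, sonnet, random) fill it with None, while sample_hf_requests below may populate it with an image_url dict.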

@@ -131,7 +135,7 @@ def sample_sonnet_requests(
     output_len: int,
     prefix_len: int,
     tokenizer: PreTrainedTokenizerBase,
-) -> List[Tuple[str, str, int, int]]:
+) -> List[Tuple[str, str, int, int, None]]:
     assert (
         input_len > prefix_len
     ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
@@ -189,7 +193,65 @@ def sample_sonnet_requests(
             message, add_generation_prompt=True, tokenize=False)
         prompt_len = len(tokenizer(prompt_formatted).input_ids)
         sampled_requests.append(
-            (prompt, prompt_formatted, prompt_len, output_len))
+            (prompt, prompt_formatted, prompt_len, output_len, None))
+
+    return sampled_requests
+
+
+def sample_hf_requests(
+    dataset_path: str,
+    dataset_subset: str,
+    dataset_split: str,
+    num_requests: int,
+    tokenizer: PreTrainedTokenizerBase,
+    fixed_output_len: Optional[int] = None,
+) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
+    dataset = load_dataset(dataset_path,
+                           name=dataset_subset,
+                           split=dataset_split,
+                           streaming=True)
+    assert "conversations" in dataset.features, (
+        "HF Dataset must have 'conversations' column.")
+    filtered_dataset = dataset.shuffle().filter(
+        lambda x: len(x["conversations"]) >= 2)
+    sampled_requests: List[Tuple[str, int, int, Dict[str,
+                                                     Collection[str]]]] = []
+    for data in filtered_dataset:
+        if len(sampled_requests) == num_requests:
+            break
+
+        # Tokenize the prompts and completions.
+        prompt = data["conversations"][0]["value"]
+        prompt_token_ids = tokenizer(prompt).input_ids
+        completion = data["conversations"][1]["value"]
+        completion_token_ids = tokenizer(completion).input_ids
+        prompt_len = len(prompt_token_ids)
+        output_len = len(completion_token_ids
+                         ) if fixed_output_len is None else fixed_output_len
+        if prompt_len < 4 or output_len < 4:
+            # Prune too short sequences.
+            continue
+        if prompt_len > 1024 or prompt_len + output_len > 2048:
+            # Prune too long sequences.
+            continue
+
+        if "image" in data and isinstance(data["image"], Image):
+            image: Image = data["image"]
+            image = image.convert("RGB")
+            image_data = io.BytesIO()
+            image.save(image_data, format='JPEG')
+            image_base64 = base64.b64encode(
+                image_data.getvalue()).decode("utf-8")
+            mm_content = {
+                "type": "image_url",
+                "image_url": {
+                    "url": f"data:image/jpeg;base64,{image_base64}"
+                },
+            }
+        else:
+            mm_content = None
+
+        sampled_requests.append((prompt, prompt_len, output_len, mm_content))

     return sampled_requests
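As a standalone illustration of the image-handling branch above, the following hedged sketch is self-contained and runnable: it substitutes an in-memory test image for a real HF dataset row, then mirrors the diff's JPEG-encode-and-wrap logic to produce the same data-URL structure:

import base64
import io

from PIL import Image

# A throwaway in-memory image standing in for data["image"].
image = Image.new("RGB", (64, 64), color=(200, 30, 30))

# Mirror the diff's logic: JPEG-encode, then wrap as an OpenAI-style
# image_url content part carrying a base64 data URL.
buf = io.BytesIO()
image.save(buf, format="JPEG")
image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
mm_content = {
    "type": "image_url",
    "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
}
print(mm_content["image_url"]["url"][:48], "...")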

@@ -223,8 +285,8 @@ def sample_random_requests(
             [(offsets[i] + i + j) % tokenizer.vocab_size
              for j in range(input_lens[i])])

-        input_requests.append(
-            (prompt, int(prefix_len + input_lens[i]), int(output_lens[i])))
+        input_requests.append((prompt, int(prefix_len + input_lens[i]),
+                               int(output_lens[i]), None))

     return input_requests

@@ -343,7 +405,12 @@ async def benchmark(
         raise ValueError(f"Unknown backend: {backend}")

     print("Starting initial single prompt test run...")
-    test_prompt, test_prompt_len, test_output_len = input_requests[0]
+    test_prompt, test_prompt_len, test_output_len, test_mm_content = (
+        input_requests[0])
+    if backend != "openai-chat" and test_mm_content is not None:
+        # multi-modal benchmark is only available on OpenAI Chat backend.
+        raise ValueError(
+            "Multi-modal content is only supported on 'openai-chat' backend.")
     test_input = RequestFuncInput(
         model=model_id,
         prompt=test_prompt,
@@ -353,6 +420,7 @@ async def benchmark(
         logprobs=logprobs,
         best_of=best_of,
         use_beam_search=use_beam_search,
+        multi_modal_content=test_mm_content,
     )
     test_output = await request_func(request_func_input=test_input)
     if not test_output.success:
@@ -373,6 +441,7 @@ async def benchmark(
         logprobs=logprobs,
         best_of=best_of,
         use_beam_search=use_beam_search,
+        multi_modal_content=test_mm_content,
     )
     profile_output = await request_func(request_func_input=profile_input)
     if profile_output.success:
@@ -385,7 +454,7 @@ async def benchmark(
     benchmark_start_time = time.perf_counter()
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
-        prompt, prompt_len, output_len = request
+        prompt, prompt_len, output_len, mm_content = request
         request_func_input = RequestFuncInput(
             model=model_id,
             prompt=prompt,
@@ -395,6 +464,7 @@ async def benchmark(
             logprobs=logprobs,
             best_of=best_of,
             use_beam_search=use_beam_search,
+            multi_modal_content=mm_content,
         )
         tasks.append(
             asyncio.create_task(
@@ -575,6 +645,16 @@ def main(args: argparse.Namespace):
                          for prompt, prompt_formatted, prompt_len,
                          output_len in input_requests]

+    elif args.dataset_name == "hf":
+        input_requests = sample_hf_requests(
+            dataset_path=args.dataset_path,
+            dataset_subset=args.hf_subset,
+            dataset_split=args.hf_split,
+            num_requests=args.num_prompts,
+            tokenizer=tokenizer,
+            fixed_output_len=args.hf_output_len,
+        )
+
     elif args.dataset_name == "random":
         input_requests = sample_random_requests(
             prefix_len=args.random_prefix_len,
@@ -685,13 +765,14 @@ def main(args: argparse.Namespace):
         "--dataset-name",
         type=str,
         default="sharegpt",
-        choices=["sharegpt", "sonnet", "random"],
+        choices=["sharegpt", "sonnet", "random", "hf"],
         help="Name of the dataset to benchmark on.",
     )
     parser.add_argument("--dataset-path",
                         type=str,
                         default=None,
-                        help="Path to the dataset.")
+                        help="Path to the sharegpt/sonnet dataset. "
+                        "Or the huggingface dataset ID if using HF dataset.")
     parser.add_argument(
         "--model",
         type=str,
@@ -718,26 +799,6 @@ def main(args: argparse.Namespace):
         default=1000,
         help="Number of prompts to process.",
     )
-    parser.add_argument(
-        "--sharegpt-output-len",
-        type=int,
-        default=None,
-        help="Output length for each request. Overrides the output length "
-        "from the ShareGPT dataset.")
-    parser.add_argument(
-        "--sonnet-input-len",
-        type=int,
-        default=550,
-        help=
-        "Number of input tokens per request, used only for sonnet dataset.",
-    )
-    parser.add_argument(
-        "--sonnet-output-len",
-        type=int,
-        default=150,
-        help=
-        "Number of output tokens per request, used only for sonnet dataset.",
-    )
     parser.add_argument(
         "--logprobs",
         type=int,
@@ -748,42 +809,6 @@ def main(args: argparse.Namespace):
         "logprob is returned for each token; or (2) if beam search "
         "is enabled 1 logprob per token is computed"),
     )
-    parser.add_argument(
-        "--sonnet-prefix-len",
-        type=int,
-        default=200,
-        help=
-        "Number of prefix tokens per request, used only for sonnet dataset.",
-    )
-    parser.add_argument(
-        "--random-input-len",
-        type=int,
-        default=1024,
-        help=
-        "Number of input tokens per request, used only for random sampling.",
-    )
-    parser.add_argument(
-        "--random-output-len",
-        type=int,
-        default=128,
-        help=
-        "Number of output tokens per request, used only for random sampling.",
-    )
-    parser.add_argument(
-        "--random-range-ratio",
-        type=float,
-        default=1.0,
-        help="Range of sampled ratio of input/output length, "
-        "used only for random sampling.",
-    )
-    parser.add_argument(
-        "--random-prefix-len",
-        type=int,
-        default=0,
-        help="Number of fixed prefix tokens before random "
-        " context. The length range of context in a random "
-        " request is [random-prefix-len, "
-        " random-prefix-len + random-prefix-len * random-range-ratio).")
     parser.add_argument(
         "--request-rate",
         type=float,
@@ -857,5 +882,85 @@ def main(args: argparse.Namespace):
         "Use \"--percentile-metrics\" to select metrics.",
     )

+    # group for dataset specific arguments
+    sonnet_group = parser.add_argument_group("sonnet dataset options")
+    sonnet_group.add_argument(
+        "--sonnet-input-len",
+        type=int,
+        default=550,
+        help=
+        "Number of input tokens per request, used only for sonnet dataset.",
+    )
+    sonnet_group.add_argument(
+        "--sonnet-output-len",
+        type=int,
+        default=150,
+        help=
+        "Number of output tokens per request, used only for sonnet dataset.",
+    )
+    sonnet_group.add_argument(
+        "--sonnet-prefix-len",
+        type=int,
+        default=200,
+        help=
+        "Number of prefix tokens per request, used only for sonnet dataset.",
+    )
+
+    sharegpt_group = parser.add_argument_group("sharegpt dataset options")
+    sharegpt_group.add_argument(
+        "--sharegpt-output-len",
+        type=int,
+        default=None,
+        help="Output length for each request. Overrides the output length "
+        "from the ShareGPT dataset.")
+
+    random_group = parser.add_argument_group("random dataset options")
+    random_group.add_argument(
+        "--random-input-len",
+        type=int,
+        default=1024,
+        help=
+        "Number of input tokens per request, used only for random sampling.",
+    )
+    random_group.add_argument(
+        "--random-output-len",
+        type=int,
+        default=128,
+        help=
+        "Number of output tokens per request, used only for random sampling.",
+    )
+    random_group.add_argument(
+        "--random-range-ratio",
+        type=float,
+        default=1.0,
+        help="Range of sampled ratio of input/output length, "
+        "used only for random sampling.",
+    )
+    random_group.add_argument(
+        "--random-prefix-len",
+        type=int,
+        default=0,
+        help="Number of fixed prefix tokens before random "
+        " context. The length range of context in a random "
+        " request is [random-prefix-len, "
+        " random-prefix-len + random-prefix-len * random-range-ratio).")
+
+    hf_group = parser.add_argument_group("hf dataset options")
+    hf_group.add_argument("--hf-subset",
+                          type=str,
+                          default=None,
+                          help="Subset of the HF dataset.")
+    hf_group.add_argument("--hf-split",
+                          type=str,
+                          default=None,
+                          help="Split of the HF dataset.")
+    hf_group.add_argument(
+        "--hf-output-len",
+        type=int,
+        default=None,
+        help="Output length for each request. Overrides the output lengths "
+        "from the sampled HF dataset.",
+    )
+
     args = parser.parse_args()
     main(args)
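Putting the new options together, an illustrative invocation against an OpenAI-compatible chat endpoint could look like the following. The model and dataset IDs are placeholders; only the flag names are taken from this commit:

python benchmarks/benchmark_serving.py \
    --backend openai-chat \
    --model <vision-model-id> \
    --dataset-name hf \
    --dataset-path <hf-dataset-id> \
    --hf-split train \
    --hf-output-len 128 \
    --num-prompts 100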
