Skip to content

Commit 8882440

Browse files
ekagra-ranjan amitm02
authored and committed
[Misc][Benchmark] Add support for CustomDataset (vllm-project#18511)
Signed-off-by: amit <amit.man@gmail.com>
1 parent 63f4c59 commit 8882440

File tree

5 files changed

+264
-8
lines changed

5 files changed

+264
-8
lines changed

benchmarks/README.md

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ become available.
6464
<td style="text-align: center;">✅</td>
6565
<td><code>lmms-lab/LLaVA-OneVision-Data</code>, <code>Aeala/ShareGPT_Vicuna_unfiltered</code></td>
6666
</tr>
67+
<tr>
68+
<td><strong>Custom</strong></td>
69+
<td style="text-align: center;">✅</td>
70+
<td style="text-align: center;">✅</td>
71+
<td>Local file: <code>data.jsonl</code></td>
72+
</tr>
6773
</tbody>
6874
</table>
6975

@@ -124,6 +130,38 @@ P99 ITL (ms): 8.39
124130
==================================================
125131
```
126132

133+
### Custom Dataset
134+
If the dataset you want to benchmark is not yet supported in vLLM, you can still benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format, and each entry needs to have a "prompt" field, e.g., `data.jsonl`:
135+
136+
```
137+
{"prompt": "What is the capital of India?"}
138+
{"prompt": "What is the capital of Iran?"}
139+
{"prompt": "What is the capital of China?"}
140+
```
141+
142+
```bash
143+
# start server
144+
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --port 9001 --disable-log-requests
145+
```
146+
147+
```bash
148+
# run benchmarking script
149+
python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detailed \
150+
--backend vllm \
151+
--model meta-llama/Llama-3.1-8B-Instruct \
152+
--endpoint /v1/completions \
153+
--dataset-name custom \
154+
--dataset-path <path-to-your-data-jsonl> \
155+
--custom-skip-chat-template \
156+
--num-prompts 80 \
157+
--max-concurrency 1 \
158+
--temperature=0.3 \
159+
--top-p=0.75 \
160+
--result-dir "./log/"
161+
```
162+
163+
If your data already has the chat template applied, you can skip applying it again by passing `--custom-skip-chat-template`.
164+
127165
### VisionArena Benchmark for Vision Language Models
128166

129167
```bash
@@ -203,6 +241,16 @@ python3 vllm/benchmarks/benchmark_serving.py \
203241
--seed 42
204242
```
205243

244+
**`philschmid/mt-bench`**
245+
246+
``` bash
247+
python3 vllm/benchmarks/benchmark_serving.py \
248+
--model Qwen/QwQ-32B \
249+
--dataset-name hf \
250+
--dataset-path philschmid/mt-bench \
251+
--num-prompts 80
252+
```
253+
206254
### Running With Sampling Parameters
207255

208256
When using OpenAI-compatible backends such as `vllm`, optional sampling

benchmarks/benchmark_dataset.py

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,6 @@
99
- BurstGPT
1010
- HuggingFace
1111
- VisionArena
12-
13-
TODO: Implement CustomDataset to parse a JSON file and convert its contents into
14-
SampleRequest instances, similar to the approach used in ShareGPT.
1512
"""
1613

1714
import base64
@@ -442,6 +439,97 @@ def sample(
442439
return samples
443440

444441

442+
# -----------------------------------------------------------------------------
443+
# Custom Dataset Implementation
444+
# -----------------------------------------------------------------------------
445+
446+
447+
class CustomDataset(BenchmarkDataset):
    """
    Implements the Custom dataset. Loads prompts from a user-supplied JSONL
    file and turns each one into a single-turn sample request.

    Every line of the file is a JSON object whose "prompt" field holds one
    standalone prompt. E.g.,
    ```
    {"prompt": "What is the capital of India?"}
    {"prompt": "What is the capital of Iran?"}
    {"prompt": "What is the capital of China?"}
    ```
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """
        Read the JSONL file at ``self.dataset_path`` into ``self.data``.

        ``self.data`` ends up as a shuffled list of dictionaries, one per
        line of the file, e.g. [{"prompt": "What is the capital of India?"}].
        This is the standardized format that sample() relies on.

        Raises:
            ValueError: If ``dataset_path`` is not set, if the file has no
                'prompt' column, or if any entry has a missing/null prompt.
            NotImplementedError: If ``dataset_path`` does not end in ".jsonl".
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        # self.data will be a list of dictionaries. Any future file format
        # handled here has to be converted into this same shape.
        self.data = []

        if not self.dataset_path.endswith(".jsonl"):
            raise NotImplementedError(
                "Only JSONL format is supported for CustomDataset."
            )

        jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True)

        # check if the JSONL file has a 'prompt' column
        if "prompt" not in jsonl_data.columns:
            raise ValueError("JSONL file must contain a 'prompt' column.")

        # The column exists as soon as a single line carries "prompt"; lines
        # that omit it (or set it to null) show up as NaN/None. Reject those
        # here instead of letting them fail later inside the tokenizer.
        if jsonl_data["prompt"].isna().any():
            raise ValueError(
                "Every entry in the JSONL file must contain a non-null 'prompt'."
            )

        # One dictionary per row of the DataFrame.
        self.data = jsonl_data.to_dict(orient="records")

        # Seed the module-level RNG so the shuffle order is reproducible.
        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        lora_path: Optional[str] = None,
        max_loras: Optional[int] = None,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        skip_chat_template: bool = False,
        **kwargs,
    ) -> list:
        """
        Build up to ``num_requests`` SampleRequest objects from ``self.data``.

        Args:
            tokenizer: Used to apply the chat template (unless skipped) and
                to count prompt tokens.
            num_requests: Number of requests wanted.
            lora_path: Accepted for interface compatibility; not used here.
            max_loras: Accepted for interface compatibility; not used here.
            output_len: Stored as ``expected_output_len`` on every request.
            enable_multimodal_chat: Accepted for interface compatibility;
                not used here.
            skip_chat_template: When True, prompts are sent as-is (use this
                if the data already has the chat template applied).
            **kwargs: Ignored.

        Returns:
            The list of sampled requests.
        """
        sampled_requests = []
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            prompt = item["prompt"]

            # Wrap the raw prompt as a single user turn and render it to
            # text, so prompt_len below reflects what is actually sent.
            if not skip_chat_template:
                prompt = tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    add_generation_prompt=True,
                    tokenize=False,
                )

            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                )
            )
        # NOTE(review): presumably tops the list up to num_requests when the
        # file has fewer prompts than requested -- confirm in the base class.
        self.maybe_oversample_requests(sampled_requests, num_requests)

        return sampled_requests
531+
532+
445533
# -----------------------------------------------------------------------------
446534
# Sonnet Dataset Implementation
447535
# -----------------------------------------------------------------------------

benchmarks/benchmark_serving.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
ASRDataset,
6161
BurstGPTDataset,
6262
ConversationDataset,
63+
CustomDataset,
6364
HuggingFaceDataset,
6465
InstructCoderDataset,
6566
MTBenchDataset,
@@ -627,7 +628,16 @@ def main(args: argparse.Namespace):
627628
"'--dataset-path' if required."
628629
)
629630

630-
if args.dataset_name == "sonnet":
631+
if args.dataset_name == "custom":
632+
dataset = CustomDataset(dataset_path=args.dataset_path)
633+
input_requests = dataset.sample(
634+
num_requests=args.num_prompts,
635+
tokenizer=tokenizer,
636+
output_len=args.custom_output_len,
637+
skip_chat_template=args.custom_skip_chat_template,
638+
)
639+
640+
elif args.dataset_name == "sonnet":
631641
dataset = SonnetDataset(dataset_path=args.dataset_path)
632642
# For the "sonnet" dataset, formatting depends on the backend.
633643
if args.backend == "openai-chat":
@@ -838,6 +848,8 @@ def main(args: argparse.Namespace):
838848
]:
839849
if field in result_json:
840850
del result_json[field]
851+
if field in benchmark_result:
852+
del benchmark_result[field]
841853

842854
# Save to file
843855
base_model_id = model_id.split("/")[-1]
@@ -850,6 +862,7 @@ def main(args: argparse.Namespace):
850862
if args.result_filename:
851863
file_name = args.result_filename
852864
if args.result_dir:
865+
os.makedirs(args.result_dir, exist_ok=True)
853866
file_name = os.path.join(args.result_dir, file_name)
854867
with open(
855868
file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
@@ -890,7 +903,7 @@ def main(args: argparse.Namespace):
890903
"--dataset-name",
891904
type=str,
892905
default="sharegpt",
893-
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
906+
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
894907
help="Name of the dataset to benchmark on.",
895908
)
896909
parser.add_argument(
@@ -1060,6 +1073,19 @@ def main(args: argparse.Namespace):
10601073
)
10611074

10621075
# group for dataset specific arguments
1076+
custom_group = parser.add_argument_group("custom dataset options")
1077+
custom_group.add_argument(
1078+
"--custom-output-len",
1079+
type=int,
1080+
default=256,
1081+
help="Number of output tokens per request, used only for custom dataset.",
1082+
)
1083+
custom_group.add_argument(
1084+
"--custom-skip-chat-template",
1085+
action="store_true",
1086+
help="Skip applying chat template to prompt, used only for custom dataset.",
1087+
)
1088+
10631089
sonnet_group = parser.add_argument_group("sonnet dataset options")
10641090
sonnet_group.add_argument(
10651091
"--sonnet-input-len",

vllm/benchmarks/datasets.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,6 @@
99
- BurstGPT
1010
- HuggingFace
1111
- VisionArena
12-
13-
TODO: Implement CustomDataset to parse a JSON file and convert its contents into
14-
SampleRequest instances, similar to the approach used in ShareGPT.
1512
"""
1613
import base64
1714
import io
@@ -26,6 +23,7 @@
2623
from typing import Any, Callable, Optional, Union
2724

2825
import numpy as np
26+
import pandas as pd
2927
from PIL import Image
3028
from transformers import PreTrainedTokenizerBase
3129

@@ -443,6 +441,99 @@ def sample(
443441
return samples
444442

445443

444+
# -----------------------------------------------------------------------------
445+
# Custom Dataset Implementation
446+
# -----------------------------------------------------------------------------
447+
448+
449+
class CustomDataset(BenchmarkDataset):
    """
    Implements the Custom dataset. Loads prompts from a user-supplied JSONL
    file and turns each one into a single-turn sample request.

    Every line of the file is a JSON object whose "prompt" field holds one
    standalone prompt. E.g.,
    ```
    {"prompt": "What is the capital of India?"}
    {"prompt": "What is the capital of Iran?"}
    {"prompt": "What is the capital of China?"}
    ```
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """
        Read the JSONL file at ``self.dataset_path`` into ``self.data``.

        ``self.data`` ends up as a shuffled list of dictionaries, one per
        line of the file, e.g. [{"prompt": "What is the capital of India?"}].
        This is the standardized format that sample() relies on.

        Raises:
            ValueError: If ``dataset_path`` is not set, if the file has no
                'prompt' column, or if any entry has a missing/null prompt.
            NotImplementedError: If ``dataset_path`` does not end in ".jsonl".
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        # self.data will be a list of dictionaries. Any future file format
        # handled here has to be converted into this same shape.
        self.data = []

        if not self.dataset_path.endswith(".jsonl"):
            raise NotImplementedError(
                "Only JSONL format is supported for CustomDataset.")

        jsonl_data = pd.read_json(path_or_buf=self.dataset_path,
                                  lines=True)

        # check if the JSONL file has a 'prompt' column
        if "prompt" not in jsonl_data.columns:
            raise ValueError("JSONL file must contain a 'prompt' column.")

        # The column exists as soon as a single line carries "prompt"; lines
        # that omit it (or set it to null) show up as NaN/None. Reject those
        # here instead of letting them fail later inside the tokenizer.
        if jsonl_data["prompt"].isna().any():
            raise ValueError("Every entry in the JSONL file must contain a "
                             "non-null 'prompt'.")

        # One dictionary per row of the DataFrame.
        self.data = jsonl_data.to_dict(orient="records")

        # Seed the module-level RNG so the shuffle order is reproducible.
        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        lora_path: Optional[str] = None,
        max_loras: Optional[int] = None,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        skip_chat_template: bool = False,
        **kwargs,
    ) -> list:
        """
        Build up to ``num_requests`` SampleRequest objects from ``self.data``.

        Args:
            tokenizer: Used to apply the chat template (unless skipped) and
                to count prompt tokens.
            num_requests: Number of requests wanted.
            lora_path: Accepted for interface compatibility; not used here.
            max_loras: Accepted for interface compatibility; not used here.
            output_len: Stored as ``expected_output_len`` on every request.
            enable_multimodal_chat: Accepted for interface compatibility;
                not used here.
            skip_chat_template: When True, prompts are sent as-is (use this
                if the data already has the chat template applied).
            **kwargs: Ignored.

        Returns:
            The list of sampled requests.
        """
        sampled_requests = []
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            prompt = item["prompt"]

            # Wrap the raw prompt as a single user turn and render it to
            # text, so prompt_len below reflects what is actually sent.
            if not skip_chat_template:
                prompt = tokenizer.apply_chat_template(
                    [{
                        "role": "user",
                        "content": prompt
                    }],
                    add_generation_prompt=True,
                    tokenize=False,
                )

            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                ))
        # NOTE(review): presumably tops the list up to num_requests when the
        # file has fewer prompts than requested -- confirm in the base class.
        self.maybe_oversample_requests(sampled_requests, num_requests)

        return sampled_requests
535+
536+
446537
# -----------------------------------------------------------------------------
447538
# Sonnet Dataset Implementation
448539
# -----------------------------------------------------------------------------

vllm/benchmarks/serve.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,8 @@ def main(args: argparse.Namespace):
11101110
]:
11111111
if field in result_json:
11121112
del result_json[field]
1113+
if field in benchmark_result:
1114+
del benchmark_result[field]
11131115

11141116
# Save to file
11151117
base_model_id = model_id.split("/")[-1]
@@ -1120,6 +1122,7 @@ def main(args: argparse.Namespace):
11201122
if args.result_filename:
11211123
file_name = args.result_filename
11221124
if args.result_dir:
1125+
os.makedirs(args.result_dir, exist_ok=True)
11231126
file_name = os.path.join(args.result_dir, file_name)
11241127
with open(file_name,
11251128
mode="a+" if args.append_result else "w",

0 commit comments

Comments
 (0)