
Commit 397f77d

JArnoldAMD authored and shreyankg committed
[Benchmarks] Make detokenization optional in benchmark scripts (vllm-project#11697)
Signed-off-by: Jeremy Arnold <Jeremy.Arnold@amd.com>
1 parent cd18ea4 commit 397f77d

4 files changed: +45 -7 lines changed
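The change threads a new --disable-detokenize flag through each benchmark script and into SamplingParams(detokenize=...), so that detokenization time can be excluded from the reported measurements. As a minimal sketch (not part of this commit, and assuming a vLLM build whose SamplingParams exposes the detokenize field), this is the effect of the toggled parameter; the model name is a placeholder:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model, for illustration only
params = SamplingParams(max_tokens=16, detokenize=False)

# Token IDs are still generated; the text field should remain empty because
# the detokenization step is skipped.
completion = llm.generate(["The capital of France is"], params)[0].outputs[0]
print(completion.token_ids)
print(repr(completion.text))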

benchmarks/benchmark_latency.py

Lines changed: 7 additions & 0 deletions
@@ -52,6 +52,7 @@ def main(args: argparse.Namespace):
         top_p=1.0,
         ignore_eos=True,
         max_tokens=args.output_len,
+        detokenize=not args.disable_detokenize,
     )
     print(sampling_params)
     dummy_prompt_token_ids = np.random.randint(10000,
@@ -173,6 +174,12 @@ def run_to_completion(profile_dir: Optional[str] = None):
         default=None,
         help="Path to save the latency results in JSON format.",
     )
+    parser.add_argument(
+        "--disable-detokenize",
+        action="store_true",
+        help=("Do not detokenize responses (i.e. do not include "
+              "detokenization time in the latency measurement)"),
+    )
 
     parser = EngineArgs.add_cli_args(parser)
     args = parser.parse_args()
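As a usage illustration (an assumed invocation pattern, not part of the diff), the new knob makes it easy to see how much of the end-to-end latency is spent detokenizing, by running the same workload twice; the model name below is a placeholder:

import time

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model, for illustration only
prompts = ["Hello, my name is"] * 8

for disable_detokenize in (False, True):
    params = SamplingParams(max_tokens=128,
                            ignore_eos=True,
                            detokenize=not disable_detokenize)
    start = time.perf_counter()
    llm.generate(prompts, params)
    elapsed = time.perf_counter() - start
    print(f"disable_detokenize={disable_detokenize}: {elapsed:.3f} s")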

benchmarks/benchmark_prefix_caching.py

Lines changed: 9 additions & 1 deletion
@@ -194,7 +194,9 @@ def main(args):
 
     llm = LLM(**dataclasses.asdict(engine_args))
 
-    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
+    sampling_params = SamplingParams(temperature=0,
+                                     max_tokens=args.output_len,
+                                     detokenize=not args.disable_detokenize)
 
     print("Testing filtered requests")
     prompts = repeat_and_sort_requests(filtered_requests,
@@ -243,6 +245,12 @@ def main(args):
         "subtract this length when filtering prompts. Only used "
         "when dataset-path is not provided.",
     )
+    parser.add_argument(
+        '--disable-detokenize',
+        action='store_true',
+        help=("Do not detokenize responses (i.e. do not include "
+              "detokenization time in the latency measurement)"),
+    )
 
     parser = EngineArgs.add_cli_args(parser)
     args = parser.parse_args()

benchmarks/benchmark_prioritization.py

Lines changed: 11 additions & 2 deletions
@@ -23,7 +23,7 @@ def sample_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int],
-) -> list[tuple[str, int, int]]:
+) -> list[tuple[str, int, int, int]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
 
@@ -71,6 +71,7 @@ def run_vllm(
     requests: list[tuple[str, int, int]],
     n: int,
     engine_args: EngineArgs,
+    disable_detokenize: bool = False,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(**dataclasses.asdict(engine_args))
@@ -95,6 +96,7 @@ def run_vllm(
                 top_p=1.0,
                 ignore_eos=True,
                 max_tokens=output_len,
+                detokenize=not disable_detokenize,
             ))
 
     start = time.perf_counter()
@@ -121,7 +123,8 @@ def main(args: argparse.Namespace):
 
     if args.backend == "vllm":
         elapsed_time = run_vllm(requests, args.n,
-                                EngineArgs.from_cli_args(args))
+                                EngineArgs.from_cli_args(args),
+                                args.disable_detokenize)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
     total_num_tokens = sum(prompt_len + output_len
@@ -174,6 +177,12 @@ def main(args: argparse.Namespace):
                         type=str,
                         default=None,
                         help='Path to save the throughput results in JSON format.')
+    parser.add_argument(
+        '--disable-detokenize',
+        action='store_true',
+        help=("Do not detokenize responses (i.e. do not include "
+              "detokenization time in the latency measurement)"),
+    )
 
     parser = EngineArgs.add_cli_args(parser)
     args = parser.parse_args()
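The same plumbing pattern appears in the throughput-style scripts: the flag is passed as a keyword argument into the run_* helper, forwarded to SamplingParams, and the elapsed time (now optionally excluding detokenization) is divided into the token count. A condensed sketch of that pattern, where the helper name and request tuples are hypothetical:

import time

from vllm import LLM, SamplingParams


def run_benchmark(llm: LLM,
                  requests: list[tuple[str, int, int]],
                  disable_detokenize: bool = False) -> float:
    # requests are (prompt, prompt_len, output_len) tuples, as in the scripts.
    prompts = [prompt for prompt, _, _ in requests]
    sampling_params = [
        SamplingParams(ignore_eos=True,
                       max_tokens=output_len,
                       detokenize=not disable_detokenize)
        for _, _, output_len in requests
    ]
    start = time.perf_counter()
    llm.generate(prompts, sampling_params)
    elapsed = time.perf_counter() - start

    total_tokens = sum(prompt_len + output_len
                       for _, prompt_len, output_len in requests)
    print(f"Throughput: {total_tokens / elapsed:.2f} tokens/s")
    return elapsed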

benchmarks/benchmark_throughput.py

Lines changed: 18 additions & 4 deletions
@@ -168,6 +168,7 @@ def run_vllm(
     requests: list[SampleRequest],
     n: int,
     engine_args: EngineArgs,
+    disable_detokenize: bool = False,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(**dataclasses.asdict(engine_args))
@@ -194,6 +195,7 @@ def run_vllm(
                 top_p=1.0,
                 ignore_eos=True,
                 max_tokens=request.expected_output_len,
+                detokenize=not disable_detokenize,
             ))
     lora_requests: Optional[list[LoRARequest]] = None
     if engine_args.enable_lora:
@@ -232,6 +234,7 @@ async def run_vllm_async(
     n: int,
     engine_args: AsyncEngineArgs,
     disable_frontend_multiprocessing: bool = False,
+    disable_detokenize: bool = False,
 ) -> float:
     from vllm import SamplingParams
 
@@ -262,6 +265,7 @@ async def run_vllm_async(
                 top_p=1.0,
                 ignore_eos=True,
                 max_tokens=request.expected_output_len,
+                detokenize=not disable_detokenize,
             ))
         lora_requests.append(request.lora_request)
 
@@ -288,6 +292,7 @@ def run_hf(
     n: int,
     max_batch_size: int,
     trust_remote_code: bool,
+    disable_detokenize: bool = False,
 ) -> float:
     llm = AutoModelForCausalLM.from_pretrained(
         model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
@@ -327,8 +332,9 @@ def run_hf(
                 use_cache=True,
                 max_new_tokens=max_output_len,
             )
-            # Include the decoding time.
-            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
+            if not disable_detokenize:
+                # Include the decoding time.
+                tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
             pbar.update(len(batch))
 
             # Clear the batch.
@@ -440,14 +446,17 @@ def main(args: argparse.Namespace):
                     args.n,
                     AsyncEngineArgs.from_cli_args(args),
                     args.disable_frontend_multiprocessing,
+                    args.disable_detokenize,
                 ))
         else:
             elapsed_time = run_vllm(requests, args.n,
-                                    EngineArgs.from_cli_args(args))
+                                    EngineArgs.from_cli_args(args),
+                                    args.disable_detokenize)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
-                              args.hf_max_batch_size, args.trust_remote_code)
+                              args.hf_max_batch_size, args.trust_remote_code,
+                              args.disable_detokenize)
     elif args.backend == "mii":
         elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                                args.output_len)
@@ -526,6 +535,11 @@ def main(args: argparse.Namespace):
                         action='store_true',
                         default=False,
                         help="Disable decoupled async engine frontend.")
+    parser.add_argument(
+        "--disable-detokenize",
+        action="store_true",
+        help=("Do not detokenize the response (i.e. do not include "
+              "detokenization time in the measurement)"))
     # LoRA
     parser.add_argument(
         "--lora-path",
