24
24
"""
25
25
import argparse
26
26
import asyncio
27
+ import base64
28
+ import io
27
29
import json
28
30
import os
29
31
import random
30
32
import time
31
33
import warnings
32
34
from dataclasses import dataclass
33
35
from datetime import datetime
34
- from typing import Any , AsyncGenerator , Dict , List , Optional , Tuple
36
+ from typing import Any , AsyncGenerator , Collection , Dict , List , Optional , Tuple
35
37
36
38
import numpy as np
37
39
from backend_request_func import (ASYNC_REQUEST_FUNCS , RequestFuncInput ,
38
40
RequestFuncOutput )
41
+ from datasets import load_dataset
42
+ from PIL .Image import Image
39
43
from tqdm .asyncio import tqdm
40
44
from transformers import PreTrainedTokenizerBase
41
45
@@ -84,7 +88,7 @@ def sample_sharegpt_requests(
84
88
num_requests : int ,
85
89
tokenizer : PreTrainedTokenizerBase ,
86
90
fixed_output_len : Optional [int ] = None ,
87
- ) -> List [Tuple [str , int , int ]]:
91
+ ) -> List [Tuple [str , int , int , None ]]:
88
92
if fixed_output_len is not None and fixed_output_len < 4 :
89
93
raise ValueError ("output_len too small" )
90
94
# Load the dataset.
@@ -119,7 +123,7 @@ def sample_sharegpt_requests(
119
123
if prompt_len > 1024 or prompt_len + output_len > 2048 :
120
124
# Prune too long sequences.
121
125
continue
122
- filtered_dataset .append ((prompt , prompt_len , output_len ))
126
+ filtered_dataset .append ((prompt , prompt_len , output_len , None ))
123
127
124
128
return filtered_dataset
125
129
@@ -131,7 +135,7 @@ def sample_sonnet_requests(
131
135
output_len : int ,
132
136
prefix_len : int ,
133
137
tokenizer : PreTrainedTokenizerBase ,
134
- ) -> List [Tuple [str , str , int , int ]]:
138
+ ) -> List [Tuple [str , str , int , int , None ]]:
135
139
assert (
136
140
input_len > prefix_len
137
141
), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
@@ -189,7 +193,65 @@ def sample_sonnet_requests(
189
193
message , add_generation_prompt = True , tokenize = False )
190
194
prompt_len = len (tokenizer (prompt_formatted ).input_ids )
191
195
sampled_requests .append (
192
- (prompt , prompt_formatted , prompt_len , output_len ))
196
+ (prompt , prompt_formatted , prompt_len , output_len , None ))
197
+
198
+ return sampled_requests
199
+
200
+
201
+ def sample_hf_requests (
202
+ dataset_path : str ,
203
+ dataset_subset : str ,
204
+ dataset_split : str ,
205
+ num_requests : int ,
206
+ tokenizer : PreTrainedTokenizerBase ,
207
+ fixed_output_len : Optional [int ] = None ,
208
+ ) -> List [Tuple [str , str , int , Optional [Dict [str , Collection [str ]]]]]:
209
+ dataset = load_dataset (dataset_path ,
210
+ name = dataset_subset ,
211
+ split = dataset_split ,
212
+ streaming = True )
213
+ assert "conversations" in dataset .features , (
214
+ "HF Dataset must have 'conversations' column." )
215
+ filtered_dataset = dataset .shuffle ().filter (
216
+ lambda x : len (x ["conversations" ]) >= 2 )
217
+ sampled_requests : List [Tuple [str , int , int , Dict [str ,
218
+ Collection [str ]]]] = []
219
+ for data in filtered_dataset :
220
+ if len (sampled_requests ) == num_requests :
221
+ break
222
+
223
+ # Tokenize the prompts and completions.
224
+ prompt = data ["conversations" ][0 ]["value" ]
225
+ prompt_token_ids = tokenizer (prompt ).input_ids
226
+ completion = data ["conversations" ][1 ]["value" ]
227
+ completion_token_ids = tokenizer (completion ).input_ids
228
+ prompt_len = len (prompt_token_ids )
229
+ output_len = len (completion_token_ids
230
+ ) if fixed_output_len is None else fixed_output_len
231
+ if prompt_len < 4 or output_len < 4 :
232
+ # Prune too short sequences.
233
+ continue
234
+ if prompt_len > 1024 or prompt_len + output_len > 2048 :
235
+ # Prune too long sequences.
236
+ continue
237
+
238
+ if "image" in data and isinstance (data ["image" ], Image ):
239
+ image : Image = data ["image" ]
240
+ image = image .convert ("RGB" )
241
+ image_data = io .BytesIO ()
242
+ image .save (image_data , format = 'JPEG' )
243
+ image_base64 = base64 .b64encode (
244
+ image_data .getvalue ()).decode ("utf-8" )
245
+ mm_content = {
246
+ "type" : "image_url" ,
247
+ "image_url" : {
248
+ "url" : f"data:image/jpeg;base64,{ image_base64 } "
249
+ },
250
+ }
251
+ else :
252
+ mm_content = None
253
+
254
+ sampled_requests .append ((prompt , prompt_len , output_len , mm_content ))
193
255
194
256
return sampled_requests
195
257
@@ -223,8 +285,8 @@ def sample_random_requests(
223
285
[(offsets [i ] + i + j ) % tokenizer .vocab_size
224
286
for j in range (input_lens [i ])])
225
287
226
- input_requests .append (
227
- ( prompt , int ( prefix_len + input_lens [ i ]), int (output_lens [i ])))
288
+ input_requests .append (( prompt , int ( prefix_len + input_lens [ i ]),
289
+ int (output_lens [i ]), None ))
228
290
229
291
return input_requests
230
292
@@ -343,7 +405,12 @@ async def benchmark(
343
405
raise ValueError (f"Unknown backend: { backend } " )
344
406
345
407
print ("Starting initial single prompt test run..." )
346
- test_prompt , test_prompt_len , test_output_len = input_requests [0 ]
408
+ test_prompt , test_prompt_len , test_output_len , test_mm_content = (
409
+ input_requests [0 ])
410
+ if backend != "openai-chat" and test_mm_content is not None :
411
+ # multi-modal benchmark is only available on OpenAI Chat backend.
412
+ raise ValueError (
413
+ "Multi-modal content is only supported on 'openai-chat' backend." )
347
414
test_input = RequestFuncInput (
348
415
model = model_id ,
349
416
prompt = test_prompt ,
@@ -353,6 +420,7 @@ async def benchmark(
353
420
logprobs = logprobs ,
354
421
best_of = best_of ,
355
422
use_beam_search = use_beam_search ,
423
+ multi_modal_content = test_mm_content ,
356
424
)
357
425
test_output = await request_func (request_func_input = test_input )
358
426
if not test_output .success :
@@ -373,6 +441,7 @@ async def benchmark(
373
441
logprobs = logprobs ,
374
442
best_of = best_of ,
375
443
use_beam_search = use_beam_search ,
444
+ multi_modal_content = test_mm_content ,
376
445
)
377
446
profile_output = await request_func (request_func_input = profile_input )
378
447
if profile_output .success :
@@ -385,7 +454,7 @@ async def benchmark(
385
454
benchmark_start_time = time .perf_counter ()
386
455
tasks : List [asyncio .Task ] = []
387
456
async for request in get_request (input_requests , request_rate ):
388
- prompt , prompt_len , output_len = request
457
+ prompt , prompt_len , output_len , mm_content = request
389
458
request_func_input = RequestFuncInput (
390
459
model = model_id ,
391
460
prompt = prompt ,
@@ -395,6 +464,7 @@ async def benchmark(
395
464
logprobs = logprobs ,
396
465
best_of = best_of ,
397
466
use_beam_search = use_beam_search ,
467
+ multi_modal_content = mm_content ,
398
468
)
399
469
tasks .append (
400
470
asyncio .create_task (
@@ -575,6 +645,16 @@ def main(args: argparse.Namespace):
575
645
for prompt , prompt_formatted , prompt_len ,
576
646
output_len in input_requests ]
577
647
648
+ elif args .dataset_name == "hf" :
649
+ input_requests = sample_hf_requests (
650
+ dataset_path = args .dataset_path ,
651
+ dataset_subset = args .hf_subset ,
652
+ dataset_split = args .hf_split ,
653
+ num_requests = args .num_prompts ,
654
+ tokenizer = tokenizer ,
655
+ fixed_output_len = args .hf_output_len ,
656
+ )
657
+
578
658
elif args .dataset_name == "random" :
579
659
input_requests = sample_random_requests (
580
660
prefix_len = args .random_prefix_len ,
@@ -685,13 +765,14 @@ def main(args: argparse.Namespace):
685
765
"--dataset-name" ,
686
766
type = str ,
687
767
default = "sharegpt" ,
688
- choices = ["sharegpt" , "sonnet" , "random" ],
768
+ choices = ["sharegpt" , "sonnet" , "random" , "hf" ],
689
769
help = "Name of the dataset to benchmark on." ,
690
770
)
691
771
parser .add_argument ("--dataset-path" ,
692
772
type = str ,
693
773
default = None ,
694
- help = "Path to the dataset." )
774
+ help = "Path to the sharegpt/sonnet dataset. "
775
+ "Or the huggingface dataset ID if using HF dataset." )
695
776
parser .add_argument (
696
777
"--model" ,
697
778
type = str ,
@@ -718,26 +799,6 @@ def main(args: argparse.Namespace):
718
799
default = 1000 ,
719
800
help = "Number of prompts to process." ,
720
801
)
721
- parser .add_argument (
722
- "--sharegpt-output-len" ,
723
- type = int ,
724
- default = None ,
725
- help = "Output length for each request. Overrides the output length "
726
- "from the ShareGPT dataset." )
727
- parser .add_argument (
728
- "--sonnet-input-len" ,
729
- type = int ,
730
- default = 550 ,
731
- help =
732
- "Number of input tokens per request, used only for sonnet dataset." ,
733
- )
734
- parser .add_argument (
735
- "--sonnet-output-len" ,
736
- type = int ,
737
- default = 150 ,
738
- help =
739
- "Number of output tokens per request, used only for sonnet dataset." ,
740
- )
741
802
parser .add_argument (
742
803
"--logprobs" ,
743
804
type = int ,
@@ -748,42 +809,6 @@ def main(args: argparse.Namespace):
748
809
"logprob is returned for each token; or (2) if beam search "
749
810
"is enabled 1 logprob per token is computed" ),
750
811
)
751
- parser .add_argument (
752
- "--sonnet-prefix-len" ,
753
- type = int ,
754
- default = 200 ,
755
- help =
756
- "Number of prefix tokens per request, used only for sonnet dataset." ,
757
- )
758
- parser .add_argument (
759
- "--random-input-len" ,
760
- type = int ,
761
- default = 1024 ,
762
- help =
763
- "Number of input tokens per request, used only for random sampling." ,
764
- )
765
- parser .add_argument (
766
- "--random-output-len" ,
767
- type = int ,
768
- default = 128 ,
769
- help =
770
- "Number of output tokens per request, used only for random sampling." ,
771
- )
772
- parser .add_argument (
773
- "--random-range-ratio" ,
774
- type = float ,
775
- default = 1.0 ,
776
- help = "Range of sampled ratio of input/output length, "
777
- "used only for random sampling." ,
778
- )
779
- parser .add_argument (
780
- "--random-prefix-len" ,
781
- type = int ,
782
- default = 0 ,
783
- help = "Number of fixed prefix tokens before random "
784
- " context. The length range of context in a random "
785
- " request is [random-prefix-len, "
786
- " random-prefix-len + random-prefix-len * random-range-ratio)." )
787
812
parser .add_argument (
788
813
"--request-rate" ,
789
814
type = float ,
@@ -857,5 +882,85 @@ def main(args: argparse.Namespace):
857
882
"Use \" --percentile-metrics\" to select metrics." ,
858
883
)
859
884
885
+ # group for dataset specific arguments
886
+ sonnet_group = parser .add_argument_group ("sonnet dataset options" )
887
+ sonnet_group .add_argument (
888
+ "--sonnet-input-len" ,
889
+ type = int ,
890
+ default = 550 ,
891
+ help =
892
+ "Number of input tokens per request, used only for sonnet dataset." ,
893
+ )
894
+ sonnet_group .add_argument (
895
+ "--sonnet-output-len" ,
896
+ type = int ,
897
+ default = 150 ,
898
+ help =
899
+ "Number of output tokens per request, used only for sonnet dataset." ,
900
+ )
901
+ sonnet_group .add_argument (
902
+ "--sonnet-prefix-len" ,
903
+ type = int ,
904
+ default = 200 ,
905
+ help =
906
+ "Number of prefix tokens per request, used only for sonnet dataset." ,
907
+ )
908
+
909
+ sharegpt_group = parser .add_argument_group ("sharegpt dataset options" )
910
+ sharegpt_group .add_argument (
911
+ "--sharegpt-output-len" ,
912
+ type = int ,
913
+ default = None ,
914
+ help = "Output length for each request. Overrides the output length "
915
+ "from the ShareGPT dataset." )
916
+
917
+ random_group = parser .add_argument_group ("random dataset options" )
918
+ random_group .add_argument (
919
+ "--random-input-len" ,
920
+ type = int ,
921
+ default = 1024 ,
922
+ help =
923
+ "Number of input tokens per request, used only for random sampling." ,
924
+ )
925
+ random_group .add_argument (
926
+ "--random-output-len" ,
927
+ type = int ,
928
+ default = 128 ,
929
+ help =
930
+ "Number of output tokens per request, used only for random sampling." ,
931
+ )
932
+ random_group .add_argument (
933
+ "--random-range-ratio" ,
934
+ type = float ,
935
+ default = 1.0 ,
936
+ help = "Range of sampled ratio of input/output length, "
937
+ "used only for random sampling." ,
938
+ )
939
+ random_group .add_argument (
940
+ "--random-prefix-len" ,
941
+ type = int ,
942
+ default = 0 ,
943
+ help = "Number of fixed prefix tokens before random "
944
+ " context. The length range of context in a random "
945
+ " request is [random-prefix-len, "
946
+ " random-prefix-len + random-prefix-len * random-range-ratio)." )
947
+
948
+ hf_group = parser .add_argument_group ("hf dataset options" )
949
+ hf_group .add_argument ("--hf-subset" ,
950
+ type = str ,
951
+ default = None ,
952
+ help = "Subset of the HF dataset." )
953
+ hf_group .add_argument ("--hf-split" ,
954
+ type = str ,
955
+ default = None ,
956
+ help = "Split of the HF dataset." )
957
+ hf_group .add_argument (
958
+ "--hf-output-len" ,
959
+ type = int ,
960
+ default = None ,
961
+ help = "Output length for each request. Overrides the output lengths "
962
+ "from the sampled HF dataset." ,
963
+ )
964
+
860
965
args = parser .parse_args ()
861
966
main (args )
0 commit comments