import asyncio
+import codecs
import json
from pathlib import Path
from typing import get_args

from guidellm.backend import BackendType
from guidellm.benchmark import ProfileType, benchmark_generative_text
from guidellm.config import print_config
+from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset
from guidellm.scheduler import StrategyType

STRATEGY_PROFILE_CHOICES = set(
@@ -280,6 +282,20 @@ def benchmark(
    )


+def decode_escaped_str(_ctx, _param, value):
+    """
+    Click passes escape sequences through literally: --pad-char "\n" arrives
+    as the two-character string "\\n". This callback decodes such escape
+    sequences so the intended characters (e.g. a real newline) are used.
+    """
+    if value is None:
+        return None
+    try:
+        return codecs.decode(value, "unicode_escape")
+    except Exception as e:
+        raise click.BadParameter(f"Could not decode escape sequences: {e}") from e
+
+
@cli.command(
    help=(
        "Print out the available configuration settings that can be set "
@@ -290,5 +306,139 @@ def config():
    print_config()


+@cli.group(help="General preprocessing tools and utilities.")
+def preprocess():
+    pass
+
+
+@preprocess.command(
+    help=(
+        "Convert a dataset to have specific prompt and output token sizes.\n"
+        "DATA: Path to the input dataset or dataset ID.\n"
+        "OUTPUT_PATH: Path to save the converted dataset, including file suffix."
+    )
+)
+@click.argument(
+    "data",
+    type=str,
+    required=True,
+)
+@click.argument(
+    "output_path",
+    type=click.Path(file_okay=True, dir_okay=False, writable=True, resolve_path=True),
+    required=True,
+)
+@click.option(
+    "--processor",
+    type=str,
+    required=True,
+    help=(
+        "The processor or tokenizer to use to calculate token counts for statistics "
+        "and synthetic data generation."
+    ),
+)
+@click.option(
+    "--processor-args",
+    default=None,
+    callback=parse_json,
+    help=(
+        "A JSON string containing any arguments to pass to the processor constructor "
+        "as a dict with **kwargs."
+    ),
+)
+@click.option(
+    "--data-args",
+    callback=parse_json,
+    help=(
+        "A JSON string containing any arguments to pass to the dataset creation "
+        "as a dict with **kwargs."
+    ),
+)
+@click.option(
+    "--short-prompt-strategy",
+    type=click.Choice([s.value for s in ShortPromptStrategy]),
+    default=ShortPromptStrategy.IGNORE.value,
+    show_default=True,
+    help="Strategy to handle prompts shorter than the target length.",
+)
+@click.option(
+    "--pad-char",
+    type=str,
+    default="",
+    callback=decode_escaped_str,
+    help="The token to pad short prompts with when using the 'pad' strategy.",
+)
+@click.option(
+    "--concat-delimiter",
+    type=str,
+    default="",
+    help=(
+        "The delimiter to use when concatenating prompts that are too short."
+        " Used when strategy is 'concatenate'."
+    ),
+)
+@click.option(
+    "--prompt-tokens",
+    type=str,
+    default=None,
+    help="Prompt tokens config (JSON, YAML file, or key=value string).",
+)
+@click.option(
+    "--output-tokens",
+    type=str,
+    default=None,
+    help="Output tokens config (JSON, YAML file, or key=value string).",
+)
+@click.option(
+    "--push-to-hub",
+    is_flag=True,
+    help="Set this flag to push the converted dataset to the Hugging Face Hub.",
+)
+@click.option(
+    "--hub-dataset-id",
+    type=str,
+    default=None,
+    help="The Hugging Face Hub dataset ID to push to. "
+    "Required if --push-to-hub is used.",
+)
+@click.option(
+    "--random-seed",
+    type=int,
+    default=42,
+    show_default=True,
+    help="Random seed for prompt token sampling and output token sampling.",
+)
+def dataset(
+    data,
+    output_path,
+    processor,
+    processor_args,
+    data_args,
+    short_prompt_strategy,
+    pad_char,
+    concat_delimiter,
+    prompt_tokens,
+    output_tokens,
+    push_to_hub,
+    hub_dataset_id,
+    random_seed,
+):
+    process_dataset(
+        data=data,
+        output_path=output_path,
+        processor=processor,
+        prompt_tokens=prompt_tokens,
+        output_tokens=output_tokens,
+        processor_args=processor_args,
+        data_args=data_args,
+        short_prompt_strategy=short_prompt_strategy,
+        pad_char=pad_char,
+        concat_delimiter=concat_delimiter,
+        push_to_hub=push_to_hub,
+        hub_dataset_id=hub_dataset_id,
+        random_seed=random_seed,
+    )
+
+
if __name__ == "__main__":
    cli()
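A hedged usage sketch of the new subcommand via Click's test runner (not part of the diff): the import path for cli is an assumption about where this module lives, and the dataset, output path, and tokenizer values are placeholders, so a real run needs an actual dataset and processor.

from click.testing import CliRunner

from guidellm.__main__ import cli  # assumed import path for the CLI defined above

runner = CliRunner()

# Show the generated help for the new `preprocess dataset` subcommand.
print(runner.invoke(cli, ["preprocess", "dataset", "--help"]).output)

# Placeholder invocation: pad short prompts with a real newline, decoded from
# the literal "\n" by decode_escaped_str.
result = runner.invoke(
    cli,
    [
        "preprocess",
        "dataset",
        "data.jsonl",        # placeholder input dataset or dataset ID
        "converted.jsonl",   # placeholder output path (include the file suffix)
        "--processor", "gpt2",             # placeholder tokenizer name
        "--short-prompt-strategy", "pad",  # assumes 'pad' is a ShortPromptStrategy value
        "--pad-char", "\\n",
    ],
)
print(result.exit_code, result.output)

Passing "\\n" in the argument list mirrors what a shell delivers for --pad-char "\n", which is exactly the case the decode_escaped_str callback handles.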