diff --git a/superbench/benchmarks/model_benchmarks/model_base.py b/superbench/benchmarks/model_benchmarks/model_base.py
index 133ee76f4..a51c05850 100644
--- a/superbench/benchmarks/model_benchmarks/model_base.py
+++ b/superbench/benchmarks/model_benchmarks/model_base.py
@@ -78,6 +78,13 @@ def add_parser_arguments(self):
             required=False,
             help='The number of batch size.',
         )
+        self._parser.add_argument(
+            '--num_workers',
+            type=int,
+            default=8,
+            required=False,
+            help='Number of subprocesses to use for data loading.',
+        )
         self._parser.add_argument(
             '--precision',
             type=Precision,
diff --git a/superbench/benchmarks/model_benchmarks/pytorch_base.py b/superbench/benchmarks/model_benchmarks/pytorch_base.py
index ce1cca93b..f0cb52319 100644
--- a/superbench/benchmarks/model_benchmarks/pytorch_base.py
+++ b/superbench/benchmarks/model_benchmarks/pytorch_base.py
@@ -181,7 +181,7 @@ def _init_dataloader(self):
             dataset=self._dataset,
             batch_size=self._args.batch_size,
             shuffle=False,
-            num_workers=8,
+            num_workers=self._args.num_workers,
             sampler=train_sampler,
             drop_last=True,
             pin_memory=self._args.pin_memory
diff --git a/tests/benchmarks/model_benchmarks/test_model_base.py b/tests/benchmarks/model_benchmarks/test_model_base.py
index 926088aea..deba3a438 100644
--- a/tests/benchmarks/model_benchmarks/test_model_base.py
+++ b/tests/benchmarks/model_benchmarks/test_model_base.py
@@ -167,6 +167,7 @@ def test_arguments_related_interfaces():
   --no_gpu              Disable GPU training.
   --num_steps int       The number of test step.
   --num_warmup int      The number of warmup step.
+  --num_workers int     Number of subprocesses to use for data loading.
   --pin_memory          Enable option to pin memory in data loader.
   --precision Precision [Precision ...]
                         Model precision. E.g. fp8_hybrid fp8_e4m3 fp8_e5m2
@@ -206,6 +207,7 @@ def test_preprocess():
   --no_gpu              Disable GPU training.
   --num_steps int       The number of test step.
   --num_warmup int      The number of warmup step.
+  --num_workers int     Number of subprocesses to use for data loading.
   --pin_memory          Enable option to pin memory in data loader.
   --precision Precision [Precision ...]
                         Model precision. E.g. fp8_hybrid fp8_e4m3 fp8_e5m2