From 3e1d0ad54fd2bff098575a2b5ab9ff8d76185a9f Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 18 Mar 2024 14:13:08 +0800 Subject: [PATCH] Allow initialize dataloader without specifying 'sampler' (#809) --- ppsci/data/__init__.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/ppsci/data/__init__.py b/ppsci/data/__init__.py index fe18e0a40..55288c26a 100644 --- a/ppsci/data/__init__.py +++ b/ppsci/data/__init__.py @@ -82,16 +82,21 @@ def build_dataloader(_dataset, cfg): sampler_cfg["batch_size"] = cfg["batch_size"] batch_sampler = getattr(io, batch_sampler_cls)(_dataset, **sampler_cfg) else: - if cfg["batch_size"] != 1: - raise ValueError( - f"`batch_size` should be 1 when sampler config is None, but got {cfg['batch_size']}." + batch_sampler_cls = "BatchSampler" + if world_size > 1: + batch_sampler_cls = "DistributedBatchSampler" + logger.warning( + f"Automatically use 'DistributedBatchSampler' instead of " + f"'BatchSampler' when world_size({world_size}) > 1." ) - logger.warning( - "`batch_size` is set to 1 as neither sampler config nor batch_size is set." - ) - batch_sampler = io.BatchSampler( + batch_sampler = getattr(io, batch_sampler_cls)( _dataset, batch_size=cfg["batch_size"], + shuffle=False, + drop_last=False, + ) + logger.message( + "'shuffle' and 'drop_last' are both set to False in default as sampler config is not specified." ) # build collate_fn if specified