Skip to content

Commit

Permalink
Allow initializing the dataloader without specifying 'sampler' (#809)
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate authored Mar 18, 2024
1 parent 14430f1 commit 3e1d0ad
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions ppsci/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,21 @@ def build_dataloader(_dataset, cfg):
sampler_cfg["batch_size"] = cfg["batch_size"]
batch_sampler = getattr(io, batch_sampler_cls)(_dataset, **sampler_cfg)
else:
if cfg["batch_size"] != 1:
raise ValueError(
f"`batch_size` should be 1 when sampler config is None, but got {cfg['batch_size']}."
batch_sampler_cls = "BatchSampler"
if world_size > 1:
batch_sampler_cls = "DistributedBatchSampler"
logger.warning(
f"Automatically use 'DistributedBatchSampler' instead of "
f"'BatchSampler' when world_size({world_size}) > 1."
)
logger.warning(
"`batch_size` is set to 1 as neither sampler config nor batch_size is set."
)
batch_sampler = io.BatchSampler(
batch_sampler = getattr(io, batch_sampler_cls)(
_dataset,
batch_size=cfg["batch_size"],
shuffle=False,
drop_last=False,
)
logger.message(
"'shuffle' and 'drop_last' are both set to False in default as sampler config is not specified."
)

# build collate_fn if specified
Expand Down

0 comments on commit 3e1d0ad

Please sign in to comment.