diff --git a/wenet/dataset/datapipes.py b/wenet/dataset/datapipes.py
index 46ca96d53..c50e538d3 100644
--- a/wenet/dataset/datapipes.py
+++ b/wenet/dataset/datapipes.py
@@ -452,6 +452,9 @@ def __init__(self,
         dp = IterableWrapperIterDataPipe(filenames)
         # 0 shard many jsonl files
         dp = dp.shuffle().repeat(cycle).shard(partition)
         # 1 read one json file
         self.dp = TextLineDataPipe(dp)
+        # 2 optional shuffle, applied once self.dp exists so it is not overwritten
+        if shuffle:
+            self.dp = self.dp.shuffle(buffer_size=shuffle_size)
         self.dp = self.dp.prefetch(prefetch)