def get_batch_iterator(
    self,
    dataset,
    max_tokens=None,
    max_sentences=None,
    max_positions=None,
    ignore_invalid_inputs=False,
    required_batch_size_multiple=1,
    seed=1,
    num_shards=1,
    shard_id=0,
    num_workers=0,
    epoch=1,
    data_buffer_size=0,
    disable_iterator_cache=False,
):
    """Return an ``iterators.EpochBatchIterator`` over *dataset*.

    Batches are contiguous, fixed-size groups of ``max_sentences``
    example indices. ``max_tokens``, ``max_positions``,
    ``ignore_invalid_inputs``, ``required_batch_size_multiple`` and
    ``disable_iterator_cache`` are accepted for interface compatibility
    with the base task but are not used by this implementation.

    Args:
        dataset (FairseqDataset): dataset to batch over.
        max_sentences (int): required batch size (examples per batch).
        seed (int): RNG seed forwarded to the epoch iterator.
        num_shards (int): total number of data shards; used only to
            derive the per-shard batch count for padding (the iterator
            itself is created unsharded, see below).
        shard_id (int): accepted but unused (see note below).
        num_workers (int): dataloader worker count.
        epoch (int): starting epoch, forwarded to the dataset/iterator.
        data_buffer_size (int): prefetch buffer size.

    Returns:
        iterators.EpochBatchIterator: a reusable epoch iterator.
    """
    assert isinstance(dataset, FairseqDataset)
    # max_sentences is the only size constraint supported here and is
    # used as the batching step below; fail loudly instead of raising an
    # opaque TypeError from range().
    assert max_sentences is not None, "max_sentences (batch size) is required"

    # Initialize the dataset with the correct starting epoch.
    dataset.set_epoch(epoch)

    # Create contiguous mini-batches of indices, each of size
    # max_sentences (the final batch may be smaller).
    batch_sampler = [
        list(range(start, min(start + max_sentences, len(dataset))))
        for start in range(0, len(dataset), max_sentences)
    ]

    # All shards must produce the same number of batches per epoch,
    # otherwise distributed workers fall out of step. Derive the target
    # batch count from the full (unsharded) row count and pad this
    # shard with dummy batches until it matches. The original code
    # appended at most ONE pad batch; a loop covers shortfalls of more
    # than one batch as well.
    total_row_count = dataset.dataset.get_total_row_count()
    num_batches = math.ceil(math.ceil(total_row_count / num_shards) / max_sentences)
    while len(batch_sampler) < num_batches:
        # NOTE(review): dummy index 1 assumes the dataset holds at
        # least 2 examples — confirm; [0] would be valid for any
        # non-empty dataset.
        batch_sampler.append([1])

    # The iterator is deliberately created unsharded (num_shards=1,
    # shard_id=0): presumably each worker's dataset already contains
    # only that worker's rows, given the per-shard row-count math
    # above — confirm against the dataset implementation.
    epoch_iter = iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
        seed=seed,
        num_shards=1,
        shard_id=0,
        num_workers=num_workers,
        epoch=epoch,
        buffer_size=data_buffer_size,
    )

    return epoch_iter