Skip to content

Commit

Permalink
Update unify_task.py
Browse files Browse the repository at this point in the history
  • Loading branch information
logicwong committed Jul 6, 2022
1 parent 1d6c444 commit f293f22
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions tasks/pretrain_tasks/unify_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,49 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
poisson_lambda=self.cfg.poisson_lambda,
replace_length=self.cfg.replace_length
)

def get_batch_iterator(
    self,
    dataset,
    max_tokens=None,            # accepted for interface compatibility; not used below
    max_sentences=None,         # batch size (sentences per batch); must not be None here
    max_positions=None,         # accepted for interface compatibility; not used below
    ignore_invalid_inputs=False,  # accepted for interface compatibility; not used below
    required_batch_size_multiple=1,  # accepted for interface compatibility; not used below
    seed=1,
    num_shards=1,               # used only for the batch-count computation, NOT for iterator sharding (see below)
    shard_id=0,                 # accepted for interface compatibility; not passed to the iterator (see below)
    num_workers=0,
    epoch=1,
    data_buffer_size=0,
    disable_iterator_cache=False,  # accepted for interface compatibility; not used below
):
    """Return an EpochBatchIterator over *dataset* with fixed-size batches.

    Overrides the base task's get_batch_iterator: batching is done purely by
    ``max_sentences`` (no token-based bucketing), and the returned iterator is
    built with ``num_shards=1, shard_id=0`` — i.e. it is NOT sharded here.
    NOTE(review): presumably the dataset handed to each worker is already
    per-shard (so re-sharding in the iterator would drop data); confirm
    against the data-loading code.

    Args:
        dataset: a FairseqDataset instance (asserted below).
        max_sentences: number of examples per mini-batch; this code assumes
            it is a positive int — TODO confirm callers never pass None.
        seed, num_workers, epoch, data_buffer_size: forwarded to
            EpochBatchIterator unchanged.
        num_shards: number of data-parallel workers; used only to compute the
            expected per-shard batch count for padding.
        All other parameters are accepted to keep the base-class signature
        but have no effect in this implementation.

    Returns:
        fairseq.data.iterators.EpochBatchIterator over the given dataset.
    """
    assert isinstance(dataset, FairseqDataset)

    # initialize the dataset with the correct starting epoch
    dataset.set_epoch(epoch)

    # create mini-batches with given size constraints:
    # consecutive index ranges [i, i+max_sentences) over the local dataset
    batch_sampler = [
        [j for j in range(i, min(i + max_sentences, len(dataset)))]
        for i in range(0, len(dataset), max_sentences)
    ]
    # Number of batches every shard SHOULD have, computed from the global row
    # count split evenly across shards. If this shard came up short (its local
    # slice rounded down), pad with one extra dummy batch (index 1) so all
    # shards step the same number of times — presumably to keep distributed
    # training in sync; verify the deficit can never exceed one batch.
    total_row_count = dataset.dataset.get_total_row_count()
    num_batches = math.ceil(math.ceil(total_row_count / num_shards) / max_sentences)
    if len(batch_sampler) < num_batches:
        batch_sampler.append([1])

    # return a reusable, sharded iterator
    epoch_iter = iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        # num_shards/shard_id are deliberately pinned to 1/0 rather than the
        # method arguments: the batch_sampler above was built from the local
        # (already-sharded) dataset, so no further sharding is applied here.
        batch_sampler=batch_sampler,
        seed=seed,
        num_shards=1,
        shard_id=0,
        num_workers=num_workers,
        epoch=epoch,
        buffer_size=data_buffer_size
    )

    return epoch_iter

0 comments on commit f293f22

Please sign in to comment.