diff --git a/pytext/data/__init__.py b/pytext/data/__init__.py
index 6ba8bdb52..95827e757 100644
--- a/pytext/data/__init__.py
+++ b/pytext/data/__init__.py
@@ -4,7 +4,7 @@
 from .bptt_lm_data_handler import BPTTLanguageModelDataHandler
 from .compositional_data_handler import CompositionalDataHandler
 from .contextual_intent_slot_data_handler import ContextualIntentSlotModelDataHandler
-from .data import Batcher, Data
+from .data import Batcher, Data, PoolingBatcher
 from .data_handler import BatchIterator, CommonMetadata, DataHandler
 from .disjoint_multitask_data_handler import DisjointMultitaskDataHandler
 from .doc_classification_data_handler import DocClassificationDataHandler, RawData
@@ -19,19 +19,20 @@
 __all__ = [
     "Batcher",
+    "BatchIterator",
     "BPTTLanguageModelDataHandler",
+    "CommonMetadata",
     "CompositionalDataHandler",
     "ContextualIntentSlotModelDataHandler",
-    "BatchIterator",
-    "CommonMetadata",
     "Data",
     "DataHandler",
+    "DisjointMultitaskDataHandler",
+    "DocClassificationDataHandler",
     "JointModelDataHandler",
     "LanguageModelDataHandler",
     "PairClassificationDataHandler",
-    "SeqModelDataHandler",
-    "DocClassificationDataHandler",
-    "RawData",
-    "DisjointMultitaskDataHandler",
+    "PoolingBatcher",
     "QueryDocumentPairwiseRankingDataHandler",
+    "RawData",
+    "SeqModelDataHandler",
 ]
diff --git a/pytext/data/data.py b/pytext/data/data.py
index a479dc6d8..2f3df683e 100644
--- a/pytext/data/data.py
+++ b/pytext/data/data.py
@@ -3,6 +3,8 @@
 import functools
 import itertools
+import math
+import random
 from typing import Dict, Iterable, Optional, Type
 
 from pytext.common.constants import Stage
 
@@ -17,6 +19,7 @@ class Batcher(Component):
     """Batcher designed to batch rows of data, before padding."""
 
     __COMPONENT_TYPE__ = ComponentType.BATCHER
+    __EXPANSIBLE__ = True
 
     class Config(Component.Config):
         #: Make batches of this size when possible. If there's not enough data,
@@ -40,23 +43,83 @@ def __init__(
         self.train_batch_size = train_batch_size
         self.eval_batch_size = eval_batch_size
         self.test_batch_size = test_batch_size
+        # Per-stage batch sizes (fixes previous TEST/EVAL swap).
+        self._batch_sizes = {
+            Stage.TRAIN: self.train_batch_size,
+            Stage.TEST: self.test_batch_size,
+            Stage.EVAL: self.eval_batch_size,
+        }
 
     def batchify(
         self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN
     ):
         """Group rows by batch_size. Assume iterable of dicts, yield dict of lists.
         The last batch will be of length len(iterable) % batch_size."""
-        batch_size = {
-            Stage.TRAIN: self.train_batch_size,
-            Stage.TEST: self.eval_batch_size,
-            Stage.EVAL: self.test_batch_size,
-        }[stage]
-        iterators = [iter(iterable)] * batch_size
-        for batch in itertools.zip_longest(*iterators):
-            res = [ex for ex in batch if ex is not None]
+        batch_size = self._batch_sizes[stage]
+        for batch in self._group_iter(iterable, batch_size, sort_key):
+            yield zip_dicts(batch)
+
+    def _group_iter(self, iterable: Iterable[RawExample], group_size, sort_key=None):
+        iterators = [iter(iterable)] * group_size
+        for group in itertools.zip_longest(*iterators):
+            group = [ex for ex in group if ex is not None]
+            if sort_key:
+                group.sort(key=sort_key, reverse=True)
+            yield group
+
+
+class PoolingBatcher(Batcher):
+    """
+    Batcher that looks at pools of data, and sorts, batches, and shuffles them, before
+    padding.
+    """
+
+    class Config(Batcher.Config):
+        #: Number of batches in a pool, to load at one time.
+        pool_num_batches: int = 10000
+
+    @classmethod
+    def from_config(cls, config: Config):
+        return cls(
+            config.train_batch_size,
+            config.eval_batch_size,
+            config.test_batch_size,
+            config.pool_num_batches,
+        )
+
+    def __init__(
+        self,
+        train_batch_size=Config.train_batch_size,
+        eval_batch_size=Config.eval_batch_size,
+        test_batch_size=Config.test_batch_size,
+        pool_num_batches=Config.pool_num_batches,
+    ):
+        super().__init__(train_batch_size, eval_batch_size, test_batch_size)
+        self.pool_num_batches = pool_num_batches or 1
+
+    def batchify(
+        self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN
+    ):
+        """
+        From an iterable of dicts, yield dicts of lists, by
+
+        1. Load pool of batch_size * pool_num_batches examples.
+        2. Sort rows, if necessary.
+        3. Form batches with batch_size examples each.
+        4. Shuffle batches and yield all batches.
+        """
+        batch_size = self._batch_sizes[stage]
+        pool_size = batch_size * self.pool_num_batches
+
+        for pool in self._group_iter(iterable, pool_size, sort_key):
+            batch_indices = list(range(math.ceil(len(pool) / batch_size)))
             if sort_key:
-                res.sort(reverse=True, key=sort_key)
-            yield zip_dicts(res)
+                random.shuffle(batch_indices)
+            else:
+                random.shuffle(pool)
+            for batch_index in batch_indices:
+                batch = pool[batch_size * batch_index : batch_size * (batch_index + 1)]
+                yield zip_dicts(batch)
 
 
 def numberize_rows(tensorizers, rows):
@@ -134,7 +196,7 @@ class Config(Component.Config):
         #: will not provide any data.
         source: DataSource.Config = DataSource.Config()
         #: How training examples are split into batches for the optimizer.
-        batcher: Batcher.Config = Batcher.Config()
+        batcher: Batcher.Config = PoolingBatcher.Config()
         sort_key: Optional[str] = None
 
     @classmethod
diff --git a/pytext/data/test/data_test.py b/pytext/data/test/data_test.py
index c6afd8b42..89c606c13 100644
--- a/pytext/data/test/data_test.py
+++ b/pytext/data/test/data_test.py
@@ -4,7 +4,7 @@
 import unittest
 
 from pytext.common.constants import Stage
-from pytext.data import Batcher, Data
+from pytext.data import Batcher, Data, PoolingBatcher
 from pytext.data.sources.data_source import SafeFileWrapper
 from pytext.data.sources.tsv import TSVDataSource
 from pytext.data.tensorizers import LabelTensorizer, WordTensorizer
@@ -117,3 +117,19 @@ def test_batcher(self):
         self.assertEqual(len(batches), 4)
         self.assertEqual(batches[1]["a"], [3, 4, 5])
         self.assertEqual(batches[3]["b"], [19])
+
+    def test_pooling_batcher(self):
+        data = [{"a": i, "b": 10 + i, "c": 20 + i} for i in range(10)]
+        batcher = PoolingBatcher(train_batch_size=3, pool_num_batches=2)
+        batches = list(batcher.batchify(data, sort_key=lambda x: x["a"]))
+
+        self.assertEqual(len(batches), 4)
+        a_vals = {a for batch in batches for a in batch["a"]}
+        self.assertSetEqual(a_vals, set(range(10)))
+        for batch in batches[:2]:
+            self.assertGreater(batch["a"][0], batch["a"][-1])
+            for a in batch["a"]:
+                self.assertLess(a, 6)
+        for batch in batches[2:]:
+            for a in batch["a"]:
+                self.assertGreaterEqual(a, 6)