Skip to content

Commit

Permalink
Add shuffling / pooling to new pytext data batcher (facebookresearch#410
Browse files Browse the repository at this point in the history
)

Summary:
Pull Request resolved: facebookresearch#410

Create a new batcher that will sort, batch from, and shuffle from pools of examples (instead of streaming one batch at a time as the current batcher does).

Differential Revision: D14513256

fbshipit-source-id: 17644c669084ac92cecbdfcc25988970e06a1767
  • Loading branch information
Michael Wu authored and facebook-github-bot committed Mar 19, 2019
1 parent d72f977 commit 4c79e0a
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 19 deletions.
15 changes: 8 additions & 7 deletions pytext/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .bptt_lm_data_handler import BPTTLanguageModelDataHandler
from .compositional_data_handler import CompositionalDataHandler
from .contextual_intent_slot_data_handler import ContextualIntentSlotModelDataHandler
from .data import Batcher, Data
from .data import Batcher, Data, PoolingBatcher
from .data_handler import BatchIterator, CommonMetadata, DataHandler
from .disjoint_multitask_data_handler import DisjointMultitaskDataHandler
from .doc_classification_data_handler import DocClassificationDataHandler, RawData
Expand All @@ -19,19 +19,20 @@

__all__ = [
"Batcher",
"BatchIterator",
"BPTTLanguageModelDataHandler",
"CommonMetadata",
"CompositionalDataHandler",
"ContextualIntentSlotModelDataHandler",
"BatchIterator",
"CommonMetadata",
"Data",
"DataHandler",
"DisjointMultitaskDataHandler",
"DocClassificationDataHandler",
"JointModelDataHandler",
"LanguageModelDataHandler",
"PairClassificationDataHandler",
"SeqModelDataHandler",
"DocClassificationDataHandler",
"RawData",
"DisjointMultitaskDataHandler",
"PoolingBatcher",
"QueryDocumentPairwiseRankingDataHandler",
"RawData",
"SeqModelDataHandler",
]
84 changes: 73 additions & 11 deletions pytext/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import functools
import itertools
import math
import random
from typing import Dict, Iterable, Optional, Type

from pytext.common.constants import Stage
Expand All @@ -17,6 +19,7 @@ class Batcher(Component):
"""Batcher designed to batch rows of data, before padding."""

__COMPONENT_TYPE__ = ComponentType.BATCHER
__EXPANSIBLE__ = True

class Config(Component.Config):
#: Make batches of this size when possible. If there's not enough data,
Expand All @@ -40,23 +43,82 @@ def __init__(
self.train_batch_size = train_batch_size
self.eval_batch_size = eval_batch_size
self.test_batch_size = test_batch_size
self._batch_sizes = {
Stage.TRAIN: self.train_batch_size,
Stage.TEST: self.eval_batch_size,
Stage.EVAL: self.test_batch_size,
}

def batchify(
self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN
):
"""Group rows by batch_size. Assume iterable of dicts, yield dict of lists.
The last batch will be of length len(iterable) % batch_size."""
batch_size = {
Stage.TRAIN: self.train_batch_size,
Stage.TEST: self.eval_batch_size,
Stage.EVAL: self.test_batch_size,
}[stage]
iterators = [iter(iterable)] * batch_size
for batch in itertools.zip_longest(*iterators):
res = [ex for ex in batch if ex is not None]
batch_size = self._batch_sizes[stage]
for batch in self._group_iter(iterable, batch_size, sort_key):
yield zip_dicts(batch)

def _group_iter(self, iterable: Iterable[RawExample], group_size, sort_key=None):
iterators = [iter(iterable)] * group_size
for group in itertools.zip_longest(*iterators):
group = [ex for ex in group if ex is not None]
if sort_key:
group.sort(key=sort_key, reverse=True)
yield group


class PoolingBatcher(Batcher):
"""
Batcher that looks at pools of data, and sorts, batches, and shuffles them, before
padding.
"""

class Config(Batcher.Config):
#: Number of batches in a pool, to load at one time.
pool_num_batches: int = 10000

@classmethod
def from_config(cls, config: Config):
return cls(
config.train_batch_size,
config.eval_batch_size,
config.test_batch_size,
config.pool_num_batches,
)

def __init__(
self,
train_batch_size=Config.train_batch_size,
eval_batch_size=Config.eval_batch_size,
test_batch_size=Config.test_batch_size,
pool_num_batches=Config.pool_num_batches,
):
super().__init__(train_batch_size, eval_batch_size, test_batch_size)
self.pool_num_batches = pool_num_batches or 1

def batchify(
self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN
):
"""
From an iterable of dicts, yield dicts of lists, by
1. Load pool of batch_size * pool_num_batches examples.
2. Sort rows, if necessary.
3. Form batches with batch_size examples each.
4. Shuffle batches and yield all batches.
"""
batch_size = self._batch_sizes[stage]
pool_size = batch_size * self.pool_num_batches

for pool in self._group_iter(iterable, pool_size, sort_key):
batch_indices = list(range(math.ceil(len(pool) / batch_size)))
if sort_key:
res.sort(reverse=True, key=sort_key)
yield zip_dicts(res)
random.shuffle(batch_indices)
else:
random.shuffle(pool)
for batch_index in batch_indices:
batch = pool[batch_size * batch_index : batch_size * (batch_index + 1)]
yield zip_dicts(batch)


def numberize_rows(tensorizers, rows):
Expand Down Expand Up @@ -134,7 +196,7 @@ class Config(Component.Config):
#: will not provide any data.
source: DataSource.Config = DataSource.Config()
#: How training examples are split into batches for the optimizer.
batcher: Batcher.Config = Batcher.Config()
batcher: Batcher.Config = PoolingBatcher.Config()
sort_key: Optional[str] = None

@classmethod
Expand Down
15 changes: 14 additions & 1 deletion pytext/data/test/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unittest

from pytext.common.constants import Stage
from pytext.data import Batcher, Data
from pytext.data import Batcher, Data, PoolingBatcher
from pytext.data.sources.data_source import SafeFileWrapper
from pytext.data.sources.tsv import TSVDataSource
from pytext.data.tensorizers import LabelTensorizer, WordTensorizer
Expand Down Expand Up @@ -117,3 +117,16 @@ def test_batcher(self):
self.assertEqual(len(batches), 4)
self.assertEqual(batches[1]["a"], [3, 4, 5])
self.assertEqual(batches[3]["b"], [19])

def test_pooling_batcher(self):
data = [{"a": i, "b": 10 + i, "c": 20 + i} for i in range(10)]
batcher = PoolingBatcher(train_batch_size=3, pool_num_batches=2)
batches = list(batcher.batchify(data, sort_key=lambda x: x["a"]))
self.assertEqual(len(batches), 4)
for batch in batches[:2]:
self.assertGreater(batch["a"][0], batch["a"][-1])
for a in batch["a"]:
self.assertLess(a, 6)
for batch in batches[2:]:
for a in batch["a"]:
self.assertGreaterEqual(a, 6)

0 comments on commit 4c79e0a

Please sign in to comment.