Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Add shuffling to new pytext data #410

Closed
wants to merge 1 commit into the base branch from the contributor's branch
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions pytext/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .bptt_lm_data_handler import BPTTLanguageModelDataHandler
from .compositional_data_handler import CompositionalDataHandler
from .contextual_intent_slot_data_handler import ContextualIntentSlotModelDataHandler
from .data import Batcher, Data
from .data import Batcher, Data, PoolingBatcher
from .data_handler import BatchIterator, CommonMetadata, DataHandler
from .disjoint_multitask_data_handler import DisjointMultitaskDataHandler
from .doc_classification_data_handler import DocClassificationDataHandler, RawData
Expand All @@ -19,19 +19,20 @@

# Public API of pytext.data, kept alphabetically sorted.
# (The scraped diff had interleaved old/new lines, duplicating several
# entries; this is the deduplicated, sorted list including PoolingBatcher.)
__all__ = [
    "Batcher",
    "BatchIterator",
    "BPTTLanguageModelDataHandler",
    "CommonMetadata",
    "CompositionalDataHandler",
    "ContextualIntentSlotModelDataHandler",
    "Data",
    "DataHandler",
    "DisjointMultitaskDataHandler",
    "DocClassificationDataHandler",
    "JointModelDataHandler",
    "LanguageModelDataHandler",
    "PairClassificationDataHandler",
    "PoolingBatcher",
    "QueryDocumentPairwiseRankingDataHandler",
    "RawData",
    "SeqModelDataHandler",
]
84 changes: 73 additions & 11 deletions pytext/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import functools
import itertools
import math
import random
from typing import Dict, Iterable, Optional, Type

from pytext.common.constants import Stage
Expand All @@ -17,6 +19,7 @@ class Batcher(Component):
"""Batcher designed to batch rows of data, before padding."""

__COMPONENT_TYPE__ = ComponentType.BATCHER
__EXPANSIBLE__ = True

class Config(Component.Config):
#: Make batches of this size when possible. If there's not enough data,
Expand All @@ -40,23 +43,82 @@ def __init__(
self.train_batch_size = train_batch_size
self.eval_batch_size = eval_batch_size
self.test_batch_size = test_batch_size
self._batch_sizes = {
Stage.TRAIN: self.train_batch_size,
Stage.TEST: self.eval_batch_size,
Stage.EVAL: self.test_batch_size,
}

def batchify(
self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN
):
"""Group rows by batch_size. Assume iterable of dicts, yield dict of lists.
The last batch will be of length len(iterable) % batch_size."""
batch_size = {
Stage.TRAIN: self.train_batch_size,
Stage.TEST: self.eval_batch_size,
Stage.EVAL: self.test_batch_size,
}[stage]
iterators = [iter(iterable)] * batch_size
for batch in itertools.zip_longest(*iterators):
res = [ex for ex in batch if ex is not None]
batch_size = self._batch_sizes[stage]
for batch in self._group_iter(iterable, batch_size, sort_key):
yield zip_dicts(batch)

def _group_iter(self, iterable: Iterable[RawExample], group_size, sort_key=None):
iterators = [iter(iterable)] * group_size
for group in itertools.zip_longest(*iterators):
group = [ex for ex in group if ex is not None]
if sort_key:
group.sort(key=sort_key, reverse=True)
yield group


class PoolingBatcher(Batcher):
    """
    Batcher that loads a pool of examples at a time, then sorts, batches, and
    shuffles within the pool, before padding.
    """

    class Config(Batcher.Config):
        #: Number of batches in a pool, to load at one time.
        pool_num_batches: int = 10000

    @classmethod
    def from_config(cls, config: Config):
        """Alternate constructor from a Config instance."""
        return cls(
            config.train_batch_size,
            config.eval_batch_size,
            config.test_batch_size,
            config.pool_num_batches,
        )

    def __init__(
        self,
        train_batch_size=Config.train_batch_size,
        eval_batch_size=Config.eval_batch_size,
        test_batch_size=Config.test_batch_size,
        pool_num_batches=Config.pool_num_batches,
    ):
        super().__init__(train_batch_size, eval_batch_size, test_batch_size)
        # Guard against 0/None so the pool always spans at least one batch.
        self.pool_num_batches = pool_num_batches or 1

    def batchify(
        self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN
    ):
        """
        From an iterable of dicts, yield dicts of lists, by

        1. Loading a pool of batch_size * pool_num_batches examples.
        2. Sorting rows within the pool, if sort_key is given.
        3. Forming batches with batch_size examples each.
        4. Yielding all batches of the pool in shuffled order.
        """
        # NOTE(review): removed stale diff residue here that referenced an
        # undefined `res` (leftover lines from the pre-PR implementation).
        batch_size = self._batch_sizes[stage]
        pool_size = batch_size * self.pool_num_batches

        for pool in self._group_iter(iterable, pool_size, sort_key):
            batch_indices = list(range(math.ceil(len(pool) / batch_size)))
            if sort_key:
                # Pool is already sorted; keep batches internally sorted but
                # serve the batches themselves in random order.
                random.shuffle(batch_indices)
            else:
                # No ordering to preserve: shuffle the examples directly.
                random.shuffle(pool)
            for batch_index in batch_indices:
                batch = pool[batch_size * batch_index : batch_size * (batch_index + 1)]
                yield zip_dicts(batch)


def numberize_rows(tensorizers, rows):
Expand Down Expand Up @@ -134,7 +196,7 @@ class Config(Component.Config):
#: will not provide any data.
source: DataSource.Config = DataSource.Config()
#: How training examples are split into batches for the optimizer.
batcher: Batcher.Config = Batcher.Config()
batcher: Batcher.Config = PoolingBatcher.Config()
sort_key: Optional[str] = None

@classmethod
Expand Down
18 changes: 17 additions & 1 deletion pytext/data/test/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unittest

from pytext.common.constants import Stage
from pytext.data import Batcher, Data
from pytext.data import Batcher, Data, PoolingBatcher
from pytext.data.sources.data_source import SafeFileWrapper
from pytext.data.sources.tsv import TSVDataSource
from pytext.data.tensorizers import LabelTensorizer, WordTensorizer
Expand Down Expand Up @@ -117,3 +117,19 @@ def test_batcher(self):
self.assertEqual(len(batches), 4)
self.assertEqual(batches[1]["a"], [3, 4, 5])
self.assertEqual(batches[3]["b"], [19])

def test_pooling_batcher(self):
data = [{"a": i, "b": 10 + i, "c": 20 + i} for i in range(10)]
batcher = PoolingBatcher(train_batch_size=3, pool_num_batches=2)
batches = list(batcher.batchify(data, sort_key=lambda x: x["a"]))

self.assertEqual(len(batches), 4)
a_vals = {a for batch in batches for a in batch["a"]}
self.assertSetEqual(a_vals, set(range(10)))
for batch in batches[:2]:
self.assertGreater(batch["a"][0], batch["a"][-1])
for a in batch["a"]:
self.assertLess(a, 6)
for batch in batches[2:]:
for a in batch["a"]:
self.assertGreaterEqual(a, 6)