Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

customizable transform statistics #2059

Merged
merged 11 commits into from
May 7, 2021
2 changes: 1 addition & 1 deletion onmt/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def train(opt):
opt, fields, transforms_cls, stride=nb_gpu, offset=device_id)
producer = mp.Process(target=batch_producer,
args=(train_iter, queues[device_id],
semaphore, opt,),
semaphore, opt, device_id),
daemon=True)
producers.append(producer)
producers[device_id].start()
Expand Down
15 changes: 10 additions & 5 deletions onmt/inputters/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, name, src, tgt, align=None):
self.tgt = tgt
self.align = align

def load(self, offset=0, stride=1):
def load(self, offset=0, stride=1, log_level="warning"):
"""
Load file and iterate by lines.
`offset` and `stride` allow to iterate only on every
Expand All @@ -123,7 +123,10 @@ def load(self, offset=0, stride=1):
with exfile_open(self.src, mode='rb') as fs,\
exfile_open(self.tgt, mode='rb') as ft,\
exfile_open(self.align, mode='rb') as fa:
logger.info(f"Loading {repr(self)}...")
if log_level == "error":
logger.info(f"Loading {str(self)}...")
elif log_level == "warning":
logger.info(f"Loading {self.id}...")
francoishernandez marked this conversation as resolved.
Show resolved Hide resolved
for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
if (i % stride) == offset:
sline = sline.decode('utf-8')
Expand All @@ -136,7 +139,7 @@ def load(self, offset=0, stride=1):
example['align'] = align.decode('utf-8')
yield example

def __repr__(self):
def __str__(self):
    """Readable identifier of this corpus: ClassName(src, tgt, align=...)."""
    return '{}({}, {}, align={})'.format(
        type(self).__name__, self.src, self.tgt, self.align)
Expand Down Expand Up @@ -208,7 +211,7 @@ def _transform(self, stream):
yield item
report_msg = self.transform.stats()
if report_msg != '':
logger.info("Transform statistics for {}:\n{}".format(
logger.info("* Transform statistics for {}:\n{}\n".format(
self.cid, report_msg))

def _add_index(self, stream):
Expand All @@ -229,7 +232,9 @@ def _add_index(self, stream):

def _iter_corpus(self):
corpus_stream = self.corpus.load(
stride=self.stride, offset=self.offset)
stride=self.stride, offset=self.offset,
log_level=self.skip_empty_level
Zenglinxiao marked this conversation as resolved.
Show resolved Hide resolved
)
tokenized_corpus = self._tokenize(corpus_stream)
transformed_corpus = self._transform(tokenized_corpus)
indexed_corpus = self._add_index(transformed_corpus)
Expand Down
8 changes: 5 additions & 3 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def _add_logging_opts(parser, is_train=True):
action=StoreLoggingLevelAction,
choices=StoreLoggingLevelAction.CHOICES,
default="0")
group.add('--verbose', '-verbose', action="store_true",
help='Print data loading and statistics for all process'
'(default only log the first process shard)' if is_train
else 'Print scores and predictions for each sentence')

if is_train:
group.add('--report_every', '-report_every', type=int, default=50,
Expand All @@ -44,8 +48,6 @@ def _add_logging_opts(parser, is_train=True):
"This is also the name of the run.")
else:
# Options only during inference
group.add('--verbose', '-verbose', action="store_true",
help='Print scores and predictions for each sentence')
group.add('--attn_debug', '-attn_debug', action="store_true",
help='Print best attn for each word')
group.add('--align_debug', '-align_debug', action="store_true",
Expand Down Expand Up @@ -75,7 +77,7 @@ def _add_dynamic_corpus_opts(parser, build_vocab_only=False):
help="Security level when encounter empty examples."
"silent: silently ignore/skip empty example;"
"warning: warning when ignore/skip empty example;"
"error: raise error & stop excution when encouter empty.)")
"error: raise error & stop execution when encouter empty.")
group.add("-transforms", "--transforms", default=[], nargs="+",
choices=AVAILABLE_TRANSFORMS.keys(),
help="Default transform pipeline to apply to data. "
Expand Down
48 changes: 47 additions & 1 deletion onmt/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import yaml
import math
from argparse import Namespace
from onmt.transforms import get_transforms_cls, get_specials, make_transforms
from onmt.transforms import (
get_transforms_cls,
get_specials,
make_transforms,
TransformPipe,
)
from onmt.transforms.bart import BARTNoising


Expand Down Expand Up @@ -51,6 +56,47 @@ def test_transform_specials(self):
self.assertEqual(specials, specials_expected)


def test_transform_pipe(self):
    """End-to-end check of TransformPipe: build, apply, and statistics lifecycle."""
    # 1. Init first transform in the pipe
    prefix_cls = get_transforms_cls(["prefix"])["prefix"]
    corpora = yaml.safe_load("""
        trainset:
            path_src: data/src-train.txt
            path_tgt: data/tgt-train.txt
            transforms: [prefix, filtertoolong]
            weight: 1
            src_prefix: "⦅_pf_src⦆"
            tgt_prefix: "⦅_pf_tgt⦆"
    """)
    opt = Namespace(data=corpora, seed=-1)
    prefix_transform = prefix_cls(opt)
    prefix_transform.warm_up()
    # 2. Init second transform in the pipe
    # src/tgt_seq_length=4 so the 5-token prefixed example below is rejected
    filter_cls = get_transforms_cls(["filtertoolong"])["filtertoolong"]
    opt = Namespace(src_seq_length=4, tgt_seq_length=4)
    filter_transform = filter_cls(opt)
    # 3. Sequentially combine them into a transform pipe
    transform_pipe = TransformPipe.build_from(
        [prefix_transform, filter_transform]
    )
    ex = {
        "src": ["Hello", ",", "world", "."],
        "tgt": ["Bonjour", "le", "monde", "."],
    }
    # 4. apply transform pipe to the example
    # (deepcopy so the original `ex` is untouched by in-place transforms)
    ex_after = transform_pipe.apply(
        copy.deepcopy(ex), corpus_name="trainset"
    )
    # 5. example after the pipe exceeds the length limit, thus filtered
    self.assertIsNone(ex_after)
    # 6. Transform statistics registered (here for filtertoolong)
    self.assertTrue(len(transform_pipe.statistics.observables) > 0)
    msg = transform_pipe.statistics.report()
    self.assertIsNotNone(msg)
    # 7. after report, statistics become empty as a fresh start
    self.assertTrue(len(transform_pipe.statistics.observables) == 0)


class TestMiscTransform(unittest.TestCase):
def test_prefix(self):
prefix_cls = get_transforms_cls(["prefix"])["prefix"]
Expand Down
15 changes: 13 additions & 2 deletions onmt/transforms/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
from onmt.utils.logging import logger
from onmt.transforms import register_transform
from .transform import Transform
from .transform import Transform, ObservableStats


class FilterTooLongStats(ObservableStats):
    """Running statistics for FilterTooLongTransform.

    Counts how many examples were dropped for exceeding length limits.
    """

    __slots__ = ["filtered"]

    def __init__(self):
        # Each instance represents exactly one filtered example;
        # totals are built by merging instances via ``update``.
        self.filtered = 1

    def update(self, other: "FilterTooLongStats"):
        """Accumulate the filtered-example count from ``other``."""
        self.filtered += other.filtered


@register_transform(name='filtertoolong')
Expand Down Expand Up @@ -28,7 +39,7 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
if (len(example['src']) > self.src_seq_length or
len(example['tgt']) > self.tgt_seq_length):
if stats is not None:
stats.filter_too_long()
stats.update(FilterTooLongStats())
return None
else:
return example
Expand Down
50 changes: 46 additions & 4 deletions onmt/transforms/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from onmt.constants import DefaultTokens
from onmt.transforms import register_transform
from .transform import Transform
from .transform import Transform, ObservableStats


class HammingDistanceSampling(object):
Expand Down Expand Up @@ -44,6 +44,20 @@ def _set_seed(self, seed):
random.seed(seed)


class SwitchOutStats(ObservableStats):
    """Running statistics for counting tokens being switched out."""

    __slots__ = ["changed", "total"]

    def __init__(self, changed: int, total: int):
        # `changed`: tokens replaced by sampled substitutes;
        # `total`: tokens examined in the sentence.
        self.changed = changed
        self.total = total

    def update(self, other: "SwitchOutStats"):
        """Accumulate counts from another ``SwitchOutStats``."""
        self.changed += other.changed
        self.total += other.total


@register_transform(name='switchout')
class SwitchOutTransform(HammingDistanceSamplingTransform):
"""
Expand Down Expand Up @@ -81,7 +95,7 @@ def _switchout(self, tokens, vocab, stats=None):
for i in chosen_indices:
tokens[i] = self._sample_replace(vocab, reject=tokens[i])
if stats is not None:
stats.switchout(n_switchout=n_chosen, n_total=len(tokens))
stats.update(SwitchOutStats(n_chosen, len(tokens)))
return tokens

def apply(self, example, is_train=False, stats=None, **kwargs):
Expand All @@ -98,6 +112,20 @@ def _repr_args(self):
return '{}={}'.format('switchout_temperature', self.temperature)


class TokenDropStats(ObservableStats):
    """Running statistics for counting tokens being dropped."""

    __slots__ = ["dropped", "total"]

    def __init__(self, dropped: int, total: int):
        # `dropped`: tokens removed from the sentence;
        # `total`: tokens examined in the sentence.
        self.dropped = dropped
        self.total = total

    def update(self, other: "TokenDropStats"):
        """Accumulate counts from another ``TokenDropStats``."""
        self.dropped += other.dropped
        self.total += other.total


@register_transform(name='tokendrop')
class TokenDropTransform(HammingDistanceSamplingTransform):
"""Random drop tokens from sentence."""
Expand Down Expand Up @@ -126,7 +154,7 @@ def _token_drop(self, tokens, stats=None):
out = [tok for (i, tok) in enumerate(tokens)
if i not in chosen_indices]
if stats is not None:
stats.token_drop(n_dropped=n_chosen, n_total=n_items)
stats.update(TokenDropStats(n_chosen, n_items))
return out

def apply(self, example, is_train=False, stats=None, **kwargs):
Expand All @@ -141,6 +169,20 @@ def _repr_args(self):
return '{}={}'.format('tokendrop_temperature', self.temperature)


class TokenMaskStats(ObservableStats):
    """Running statistics for counting tokens being masked."""

    __slots__ = ["masked", "total"]

    def __init__(self, masked: int, total: int):
        # `masked`: tokens replaced by the mask token;
        # `total`: tokens examined in the sentence.
        self.masked = masked
        self.total = total

    def update(self, other: "TokenMaskStats"):
        """Accumulate counts from another ``TokenMaskStats``."""
        self.masked += other.masked
        self.total += other.total


@register_transform(name='tokenmask')
class TokenMaskTransform(HammingDistanceSamplingTransform):
"""Random mask tokens from src sentence."""
Expand Down Expand Up @@ -175,7 +217,7 @@ def _token_mask(self, tokens, stats=None):
for i in chosen_indices:
tokens[i] = self.MASK_TOK
if stats is not None:
stats.token_mask(n_masked=n_chosen, n_total=len(tokens))
stats.update(TokenDropStats(n_chosen, len(tokens)))
return tokens

def apply(self, example, is_train=False, stats=None, **kwargs):
Expand Down
27 changes: 23 additions & 4 deletions onmt/transforms/tokenize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Transforms relate to tokenization/subword."""
from onmt.utils.logging import logger
from onmt.transforms import register_transform
from .transform import Transform
from .transform import Transform, ObservableStats


class TokenizerTransform(Transform):
Expand Down Expand Up @@ -107,6 +107,25 @@ def _repr_args(self):
return ', '.join([f'{kw}={arg}' for kw, arg in kwargs.items()])


class SubwordStats(ObservableStats):
    """Running statistics for counting tokens before/after subword transform."""

    __slots__ = ["subwords", "words"]

    def __init__(self, subwords: int, words: int):
        # `words`: token count before subword segmentation;
        # `subwords`: token count after segmentation.
        self.subwords = subwords
        self.words = words

    def update(self, other: "SubwordStats"):
        """Accumulate counts from another ``SubwordStats``."""
        self.subwords += other.subwords
        self.words += other.words

    def __str__(self) -> str:
        # e.g. "SubwordStats: 10 -> 14 tokens"
        # NOTE(review): ``name()`` is presumably provided by ObservableStats —
        # confirm in .transform.
        return "{}: {} -> {} tokens".format(
            self.name(), self.words, self.subwords
        )


@register_transform(name='sentencepiece')
class SentencePieceTransform(TokenizerTransform):
"""SentencePiece subword transform class."""
Expand Down Expand Up @@ -173,7 +192,7 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
if stats is not None:
n_words = len(example['src']) + len(example['tgt'])
n_subwords = len(src_out) + len(tgt_out)
stats.subword(n_subwords, n_words)
stats.update(SubwordStats(n_subwords, n_words))
example['src'], example['tgt'] = src_out, tgt_out
return example

Expand Down Expand Up @@ -246,7 +265,7 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
if stats is not None:
n_words = len(example['src']) + len(example['tgt'])
n_subwords = len(src_out) + len(tgt_out)
stats.subword(n_subwords, n_words)
stats.update(SubwordStats(n_subwords, n_words))
example['src'], example['tgt'] = src_out, tgt_out
return example

Expand Down Expand Up @@ -398,7 +417,7 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
if stats is not None:
n_words = len(example['src']) + len(example['tgt'])
n_subwords = len(src_out) + len(tgt_out)
stats.subword(n_subwords, n_words)
stats.update(SubwordStats(n_subwords, n_words))
example['src'], example['tgt'] = src_out, tgt_out
return example

Expand Down
Loading