From 5a50d960b6e20527a006b176419b188d79a85a85 Mon Sep 17 00:00:00 2001 From: Matt Post Date: Fri, 28 Sep 2018 03:26:51 -0400 Subject: [PATCH] Scoring (#538) This implements scoring of translations given source, by fully reusing the training computation graph, per @bricksdont's original suggestion. --- CHANGELOG.md | 2 + docs/modules.rst | 7 + setup.py | 1 + sockeye/arguments.py | 90 ++++++++--- sockeye/constants.py | 16 ++ sockeye/data_io.py | 182 ++++++++++++++-------- sockeye/inference.py | 3 +- sockeye/output_handler.py | 56 ++++++- sockeye/score.py | 169 ++++++++++++++++++++ sockeye/scoring.py | 266 ++++++++++++++++++++++++++++++++ sockeye/train.py | 5 +- sockeye/utils.py | 17 ++ test/common.py | 117 +++++++++++++- test/unit/test_data_io.py | 2 +- tutorials/README.md | 1 + tutorials/constraints/README.md | 2 + tutorials/scoring.md | 38 +++++ typechecked-files | 4 +- 18 files changed, 885 insertions(+), 93 deletions(-) create mode 100644 sockeye/score.py create mode 100644 sockeye/scoring.py create mode 100644 tutorials/scoring.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c3c9eeaa..5a3c12fc2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. ## [1.18.57] +### Added +- Added `sockeye.score` CLI for quickly scoring existing translations ([documentation](tutorials/scoring.md)). ### Fixed - Entry-point clean-up after the contrib/ rename diff --git a/docs/modules.rst b/docs/modules.rst index ec193bc34..adb330f01 100644 --- a/docs/modules.rst +++ b/docs/modules.rst @@ -176,6 +176,13 @@ sockeye.rnn_attention module :members: :show-inheritance: +sockeye.score module +-------------------- + +.. automodule:: sockeye.score + :members: + :show-inheritance: + sockeye.train module -------------------- diff --git a/setup.py b/setup.py index 3d5a412ab..ffa2a7b7c 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,7 @@ def get_requirements(filename): 'sockeye-lexicon = sockeye.lexicon:main', 'sockeye-init-embed = sockeye.init_embedding:main', 'sockeye-prepare-data = sockeye.prepare_data:main', + 'sockeye-score = sockeye.score:main', 'sockeye-train = sockeye.train:main', 'sockeye-translate = sockeye.translate:main', 'sockeye-vocab = sockeye.vocab:main', diff --git a/sockeye/arguments.py b/sockeye/arguments.py index c457d7837..7cede0567 100644 --- a/sockeye/arguments.py +++ b/sockeye/arguments.py @@ -760,33 +760,37 @@ def add_model_parameters(params): "(and all convolutional weight matrices for CNN decoders). Default: %(default)s.") +def add_batch_args(params, default_batch_size=4096): + params.add_argument('--batch-size', '-b', + type=int_greater_or_equal(1), + default=default_batch_size, + help='Mini-batch size. Note that depending on the batch-type this either refers to ' + 'words or sentences.' + 'Sentence: each batch contains X sentences, number of words varies. ' + 'Word: each batch contains (approximately) X words, number of sentences varies. ' + 'Default: %(default)s.') + params.add_argument("--batch-type", + type=str, + default=C.BATCH_TYPE_WORD, + choices=[C.BATCH_TYPE_SENTENCE, C.BATCH_TYPE_WORD], + help="Sentence: each batch contains X sentences, number of words varies." + "Word: each batch contains (approximately) X target words, " + "number of sentences varies. Default: %(default)s.") + + def add_training_args(params): train_params = params.add_argument_group("Training parameters") + add_batch_args(train_params) + train_params.add_argument('--decoder-only', action='store_true', help='Pre-train a decoder. This is currently for RNN decoders only. ' 'Default: %(default)s.') - - train_params.add_argument('--batch-size', '-b', - type=int_greater_or_equal(1), - default=4096, - help='Mini-batch size. Note that depending on the batch-type this either refers to ' - 'words or sentences.' - 'Sentence: each batch contains X sentences, number of words varies. ' - 'Word: each batch contains (approximately) X words, number of sentences varies. ' - 'Default: %(default)s.') - train_params.add_argument("--batch-type", - type=str, - default=C.BATCH_TYPE_WORD, - choices=[C.BATCH_TYPE_SENTENCE, C.BATCH_TYPE_WORD], - help="Sentence: each batch contains X sentences, number of words varies." - "Word: each batch contains (approximately) X target words, " - "number of sentences varies. Default: %(default)s.") - train_params.add_argument('--fill-up', type=str, - default='replicate', + default=C.FILL_UP_DEFAULT, + choices=C.FILL_UP_CHOICES, help=argparse.SUPPRESS) train_params.add_argument('--loss', @@ -1075,6 +1079,56 @@ def add_translate_cli_args(params): add_logging_args(params) +def add_score_cli_args(params): + add_training_data_args(params, required=False) + add_vocab_args(params) + add_device_args(params) + add_logging_args(params) + add_batch_args(params, default_batch_size=500) + + params = params.add_argument_group("Scoring parameters") + + params.add_argument("--model", "-m", required=True, + help="Model directory containing trained model.") + + params.add_argument('--max-seq-len', + type=multiple_values(num_values=2, greater_or_equal=1), + default=None, + help='Maximum sequence length in tokens.' + 'Use "x:x" to specify separate values for src&tgt. Default: Read from model.') + + params.add_argument('--length-penalty-alpha', + default=1.0, + type=float, + help='Alpha factor for the length penalty used in scoring: ' + '(beta + len(Y))**alpha/(beta + 1)**alpha. A value of 0.0 will therefore turn off ' + 'length normalization. Default: %(default)s') + + params.add_argument('--length-penalty-beta', + default=0.0, + type=float, + help='Beta factor for the length penalty used in scoring: ' + '(beta + len(Y))**alpha/(beta + 1)**alpha. Default: %(default)s') + + params.add_argument('--softmax-temperature', + type=float, + default=None, + help='Controls peakiness of model predictions. Values < 1.0 produce ' + 'peaked predictions, values > 1.0 produce smoothed distributions.') + + params.add_argument("--output", "-o", default=None, + help="File to write output to. Default: STDOUT.") + + params.add_argument('--output-type', + default=C.OUTPUT_HANDLER_SCORE, + choices=C.OUTPUT_HANDLERS_SCORING, + help='Output type. Default: %(default)s.') + + params.add_argument('--score-type', + choices=C.SCORING_TYPE_CHOICES, + default=C.SCORING_TYPE_DEFAULT, + help='Score type to output. Default: %(default)s') + def add_max_output_cli_args(params): params.add_argument('--max-output-length', type=int, diff --git a/sockeye/constants.py b/sockeye/constants.py index 0473bde73..93caa3dd9 100644 --- a/sockeye/constants.py +++ b/sockeye/constants.py @@ -318,11 +318,14 @@ OUTPUT_HANDLER_TRANSLATION_WITH_SCORE = "translation_with_score" OUTPUT_HANDLER_TRANSLATION_WITH_ALIGNMENTS = "translation_with_alignments" OUTPUT_HANDLER_TRANSLATION_WITH_ALIGNMENT_MATRIX = "translation_with_alignment_matrix" +OUTPUT_HANDLER_SCORE = "score" +OUTPUT_HANDLER_PAIR_WITH_SCORE = "pair_with_score" OUTPUT_HANDLER_BENCHMARK = "benchmark" OUTPUT_HANDLER_ALIGN_PLOT = "align_plot" OUTPUT_HANDLER_ALIGN_TEXT = "align_text" OUTPUT_HANDLER_BEAM_STORE = "beam_store" OUTPUT_HANDLERS = [OUTPUT_HANDLER_TRANSLATION, + OUTPUT_HANDLER_SCORE, OUTPUT_HANDLER_TRANSLATION_WITH_SCORE, OUTPUT_HANDLER_TRANSLATION_WITH_ALIGNMENTS, OUTPUT_HANDLER_TRANSLATION_WITH_ALIGNMENT_MATRIX, @@ -330,6 +333,8 @@ OUTPUT_HANDLER_ALIGN_PLOT, OUTPUT_HANDLER_ALIGN_TEXT, OUTPUT_HANDLER_BEAM_STORE] +OUTPUT_HANDLERS_SCORING = [OUTPUT_HANDLER_SCORE, + OUTPUT_HANDLER_PAIR_WITH_SCORE] # metrics ACCURACY = 'accuracy' @@ -394,7 +399,18 @@ PREPARED_DATA_VERSION_FILE = "data.version" PREPARED_DATA_VERSION = 2 +FILL_UP_REPLICATE = 'replicate' +FILL_UP_ZEROS = 'zeros' +FILL_UP_DEFAULT = FILL_UP_REPLICATE +FILL_UP_CHOICES = [FILL_UP_REPLICATE, FILL_UP_ZEROS] + # reranking RERANK_BLEU = "bleu" RERANK_CHRF = "chrf" RERANK_METRICS = [RERANK_BLEU, RERANK_CHRF] + +# scoring +SCORING_TYPE_NEGLOGPROB = 'neglogprob' +SCORING_TYPE_LOGPROB = 'logprob' +SCORING_TYPE_DEFAULT = SCORING_TYPE_NEGLOGPROB +SCORING_TYPE_CHOICES = [SCORING_TYPE_NEGLOGPROB, SCORING_TYPE_LOGPROB] diff --git a/sockeye/data_io.py b/sockeye/data_io.py index bb270ae7f..9a7e40249 100644 --- a/sockeye/data_io.py +++ b/sockeye/data_io.py @@ -23,7 +23,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from contextlib import ExitStack -from typing import Any, cast, Dict, Iterator, Iterable, List, Optional, Sequence, Sized, Tuple +from typing import Any, cast, Dict, Iterator, Iterable, List, Optional, Sequence, Sized, Tuple, Set import mxnet as mx import numpy as np @@ -675,9 +675,10 @@ def get_prepared_data_iters(prepared_data_dir: str, batch_size: int, batch_by_words: bool, batch_num_devices: int, - fill_up: str) -> Tuple['BaseParallelSampleIter', - 'BaseParallelSampleIter', - 'DataConfig', List[vocab.Vocab], vocab.Vocab]: + fill_up: str, + permute: bool = True) -> Tuple['BaseParallelSampleIter', + 'BaseParallelSampleIter', + 'DataConfig', List[vocab.Vocab], vocab.Vocab]: logger.info("===============================") logger.info("Creating training data iterator") logger.info("===============================") @@ -731,7 +732,8 @@ def get_prepared_data_iters(prepared_data_dir: str, batch_size, bucket_batch_sizes, fill_up, - num_factors=len(data_info.sources)) + num_factors=len(data_info.sources), + permute=permute) data_loader = RawParallelDatasetLoader(buckets=buckets, eos_id=target_vocab[C.EOS_SYMBOL], @@ -754,8 +756,8 @@ def get_prepared_data_iters(prepared_data_dir: str, def get_training_data_iters(sources: List[str], target: str, - validation_sources: List[str], - validation_target: str, + validation_sources: Optional[List[str]], + validation_target: Optional[str], source_vocabs: List[vocab.Vocab], target_vocab: vocab.Vocab, source_vocab_paths: List[Optional[str]], @@ -768,9 +770,10 @@ def get_training_data_iters(sources: List[str], max_seq_len_source: int, max_seq_len_target: int, bucketing: bool, - bucket_width: int) -> Tuple['BaseParallelSampleIter', - 'BaseParallelSampleIter', - 'DataConfig', 'DataInfo']: + bucket_width: int, + permute: bool = True) -> Tuple['BaseParallelSampleIter', + Optional['BaseParallelSampleIter'], + 'DataConfig', 'DataInfo']: """ Returns data iterators for training and validation data. @@ -786,7 +789,7 @@ def get_training_data_iters(sources: List[str], :param batch_size: Batch size. :param batch_by_words: Size batches by words rather than sentences. :param batch_num_devices: Number of devices batches will be parallelized across. - :param fill_up: Fill-up strategy for buckets. + :param fill_up: Fill-up policy for buckets. :param max_seq_len_source: Maximum source sequence length. :param max_seq_len_target: Maximum target sequence length. :param bucketing: Whether to use bucketing. @@ -806,7 +809,7 @@ def get_training_data_iters(sources: List[str], sources_sentences, target_sentences = create_sequence_readers(sources, target, source_vocabs, target_vocab) - # 2. pass: Get data statistics + # Pass 2: Get data statistics and determine the number of data points for each bucket. data_statistics = get_data_statistics(sources_sentences, target_sentences, buckets, length_statistics.length_ratio_mean, length_statistics.length_ratio_std, source_vocabs, target_vocab) @@ -819,6 +822,7 @@ def get_training_data_iters(sources: List[str], data_statistics.log(bucket_batch_sizes) + # Pass 3: Load the data into memory and return the iterator. data_loader = RawParallelDatasetLoader(buckets=buckets, eos_id=target_vocab[C.EOS_SYMBOL], pad_id=C.PAD_ID) @@ -843,19 +847,22 @@ def get_training_data_iters(sources: List[str], buckets=buckets, batch_size=batch_size, bucket_batch_sizes=bucket_batch_sizes, - num_factors=len(sources)) - - validation_iter = get_validation_data_iter(data_loader=data_loader, - validation_sources=validation_sources, - validation_target=validation_target, - buckets=buckets, - bucket_batch_sizes=bucket_batch_sizes, - source_vocabs=source_vocabs, - target_vocab=target_vocab, - max_seq_len_source=max_seq_len_source, - max_seq_len_target=max_seq_len_target, - batch_size=batch_size, - fill_up=fill_up) + num_factors=len(sources), + permute=permute) + + validation_iter = None + if validation_sources is not None and validation_target is not None: + validation_iter = get_validation_data_iter(data_loader=data_loader, + validation_sources=validation_sources, + validation_target=validation_target, + buckets=buckets, + bucket_batch_sizes=bucket_batch_sizes, + source_vocabs=source_vocabs, + target_vocab=target_vocab, + max_seq_len_source=max_seq_len_source, + max_seq_len_target=max_seq_len_target, + batch_size=batch_size, + fill_up=fill_up) return train_iter, validation_iter, config_data, data_info @@ -1022,6 +1029,22 @@ def ids2strids(ids: Iterable[int]) -> str: return " ".join(map(str, ids)) +def ids2tokens(token_ids: Iterable[int], + vocab_inv: Dict[int, str], + exclude_set: Set[int] = set()) -> Iterator[str]: + """ + Transforms a list of token IDs into a list of words, exluding any IDs in `exclude_set`. + + :param token_ids: The list of token IDs. + :param vocab_inv: The inverse vocabulary. + :param exclude_set: The list of token IDs to exclude. + :return: The list of words. +""" + + tokens = [vocab_inv[token] for token in token_ids] + return (tok for token_id, tok in zip(token_ids, tokens) if token_id not in exclude_set) + + class SequenceReader(Iterable): """ Reads sequence samples from path and (optionally) creates integer id sequences. @@ -1235,13 +1258,13 @@ def load(fname: str) -> 'ParallelDataSet': def fill_up(self, bucket_batch_sizes: List[BucketBatchSize], - fill_up: str, + policy: str, seed: int = 42) -> 'ParallelDataSet': """ - Returns a new dataset with buckets filled up using the specified fill-up strategy. + Returns a new dataset with buckets filled up using the specified fill-up policy. :param bucket_batch_sizes: Bucket batch sizes. - :param fill_up: Fill-up strategy. + :param policy: Fill-up policy. :param seed: The random seed used for sampling sentences to fill up. :return: New dataset with buckets filled up to the next multiple of batch size """ @@ -1259,26 +1282,44 @@ def fill_up(self, bucket_label = self.label[bucket_idx] num_samples = bucket_source.shape[0] + # Fill up the last batch by randomly sampling from the extant items. + # If we're using the 'zeros' policy, these are overwritten later below. if num_samples % bucket_batch_size != 0: - if fill_up == 'replicate': - rest = bucket_batch_size - num_samples % bucket_batch_size - logger.info("Replicating %d random samples from %d samples in bucket %s " - "to size it to multiple of %d", - rest, num_samples, bucket, bucket_batch_size) - random_indices_np = rs.randint(num_samples, size=rest) - random_indices = mx.nd.array(random_indices_np) - if isinstance(source[bucket_idx], np.ndarray): - source[bucket_idx] = np.concatenate((bucket_source, bucket_source.take(random_indices_np)), axis=0) - else: - source[bucket_idx] = mx.nd.concat(bucket_source, bucket_source.take(random_indices), dim=0) - target[bucket_idx] = mx.nd.concat(bucket_target, bucket_target.take(random_indices), dim=0) - label[bucket_idx] = mx.nd.concat(bucket_label, bucket_label.take(random_indices), dim=0) + if policy == C.FILL_UP_ZEROS: + logger.info("Filling bucket %s from size %d to %d with zeros", + bucket, num_samples, bucket_batch_size) + elif policy == C.FILL_UP_REPLICATE: + logger.info("Filling bucket %s from size %d to %d by sampling with replacement", + bucket, num_samples, bucket_batch_size) + else: + raise NotImplementedError('Unknown fill-up policy') + + rest = bucket_batch_size - num_samples % bucket_batch_size + desired_indices_np = rs.randint(num_samples, size=rest) + desired_indices = mx.nd.array(desired_indices_np) + + if isinstance(source[bucket_idx], np.ndarray): + source[bucket_idx] = np.concatenate((bucket_source, bucket_source.take(desired_indices_np)), axis=0) else: - raise NotImplementedError('Unknown fill-up strategy') + source[bucket_idx] = mx.nd.concat(bucket_source, bucket_source.take(desired_indices), dim=0) + target[bucket_idx] = mx.nd.concat(bucket_target, bucket_target.take(desired_indices), dim=0) + label[bucket_idx] = mx.nd.concat(bucket_label, bucket_label.take(desired_indices), dim=0) + + if policy == C.FILL_UP_ZEROS: + source[bucket_idx][num_samples:, :, :] = C.PAD_ID + target[bucket_idx][num_samples:, :] = C.PAD_ID + label[bucket_idx][num_samples:, :] = C.PAD_ID return ParallelDataSet(source, target, label) def permute(self, permutations: List[mx.nd.NDArray]) -> 'ParallelDataSet': + """ + Permutes the data within each bucket. The permutation is received as an argument, + allowing the data to be unpermuted (i.e., restored) later on. + + :param permutations: For each bucket, a permutation of the data within that bucket. + :return: A new, permuted ParallelDataSet. + """ assert len(self) == len(permutations) source = [] target = [] @@ -1332,6 +1373,8 @@ def get_batch_indices(data: ParallelDataSet, Returns a list of index tuples that index into the bucket and the start index inside a bucket given the batch size for a bucket. These indices are valid for the given dataset. + Put another way, this returns the starting points for all batches within the dataset, across all buckets. + :param data: Data to create indices for. :param bucket_batch_sizes: Bucket batch sizes. :return: List of 2d indices. @@ -1357,17 +1400,27 @@ class MetaBaseParallelSampleIter(ABC): class BaseParallelSampleIter(mx.io.DataIter): """ Base parallel sample iterator. + + :param buckets: The list of buckets. + :param bucket_batch_sizes: A list, parallel to `buckets`, containing the number of samples in each bucket. + :param source_data_name: The source data name. + :param target_data_name: The target data name. + :param label_name: The label name. + :param num_factors: The number of source factors. + :param permute: Randomly shuffle the parallel data. + :param dtype: The MXNet data type. """ __metaclass__ = MetaBaseParallelSampleIter def __init__(self, - buckets, - batch_size, - bucket_batch_sizes, - source_data_name, - target_data_name, - label_name, + buckets: List[Tuple[int, int]], + batch_size: int, + bucket_batch_sizes: List[BucketBatchSize], + source_data_name: str, + target_data_name: str, + label_name: str, num_factors: int = 1, + permute: bool = True, dtype='float32') -> None: super().__init__(batch_size=batch_size) @@ -1378,6 +1431,7 @@ def __init__(self, self.target_data_name = target_data_name self.label_name = label_name self.num_factors = num_factors + self.permute = permute self.dtype = dtype # "Staging area" that needs to fit any size batch we're using by total number of elements. @@ -1440,10 +1494,11 @@ def __init__(self, target_data_name=C.TARGET_NAME, label_name=C.TARGET_LABEL_NAME, num_factors: int = 1, + permute: bool = True, dtype='float32') -> None: super().__init__(buckets=buckets, batch_size=batch_size, bucket_batch_sizes=bucket_batch_sizes, source_data_name=source_data_name, target_data_name=target_data_name, - label_name=label_name, num_factors=num_factors, dtype=dtype) + label_name=label_name, num_factors=num_factors, permute=permute, dtype=dtype) assert len(shards_fnames) > 0 self.shards_fnames = list(shards_fnames) self.shard_index = -1 @@ -1455,7 +1510,7 @@ def _load_shard(self): shard_fname = self.shards_fnames[self.shard_index] logger.info("Loading shard %s.", shard_fname) dataset = ParallelDataSet.load(self.shards_fnames[self.shard_index]).fill_up(self.bucket_batch_sizes, - self.fill_up, + policy=self.fill_up, seed=self.shard_index) self.shard_iter = ParallelSampleIter(data=dataset, buckets=self.buckets, @@ -1463,7 +1518,8 @@ def _load_shard(self): bucket_batch_sizes=self.bucket_batch_sizes, source_data_name=self.source_data_name, target_data_name=self.target_data_name, - num_factors=self.num_factors) + num_factors=self.num_factors, + permute=self.permute) def reset(self): if len(self.shards_fnames) > 1: @@ -1531,18 +1587,21 @@ def __init__(self, target_data_name=C.TARGET_NAME, label_name=C.TARGET_LABEL_NAME, num_factors: int = 1, + permute: bool = True, dtype='float32') -> None: super().__init__(buckets=buckets, batch_size=batch_size, bucket_batch_sizes=bucket_batch_sizes, source_data_name=source_data_name, target_data_name=target_data_name, - label_name=label_name, num_factors=num_factors, dtype=dtype) + label_name=label_name, num_factors=num_factors, permute=permute, dtype=dtype) # create independent lists to be shuffled self.data = ParallelDataSet(list(data.source), list(data.target), list(data.label)) - # create index tuples (buck_idx, batch_start_pos) into buckets. These will be shuffled. + # create index tuples (buck_idx, batch_start_pos) into buckets. + # This is the list of all batches across all buckets in the dataset. These will be shuffled. self.batch_indices = get_batch_indices(self.data, bucket_batch_sizes) self.curr_batch_index = 0 + # Produces a permutation of the batches within each bucket, along with the permutation that inverts it. self.inverse_data_permutations = [mx.nd.arange(0, max(1, self.data.source[i].shape[0])) for i in range(len(self.data))] self.data_permutations = [mx.nd.arange(0, max(1, self.data.source[i].shape[0])) @@ -1555,15 +1614,16 @@ def reset(self): Resets and reshuffles the data. """ self.curr_batch_index = 0 - # shuffle batch start indices - random.shuffle(self.batch_indices) - - # restore - self.data = self.data.permute(self.inverse_data_permutations) + if self.permute: + # shuffle batch start indices + random.shuffle(self.batch_indices) - self.data_permutations, self.inverse_data_permutations = get_permutations(self.data.get_bucket_counts()) + # restore the data permutation + self.data = self.data.permute(self.inverse_data_permutations) - self.data = self.data.permute(self.data_permutations) + # permute the data within each batch + self.data_permutations, self.inverse_data_permutations = get_permutations(self.data.get_bucket_counts()) + self.data = self.data.permute(self.data_permutations) def iter_next(self) -> bool: """ @@ -1592,7 +1652,7 @@ def next(self) -> mx.io.DataBatch: provide_label = [mx.io.DataDesc(name=n, shape=x.shape, layout=C.BATCH_MAJOR) for n, x in zip(self.label_names, label)] - # TODO: num pad examples is not set here if fillup strategy would be padding + # TODO: num pad examples is not set here if fillup policy would be padding return mx.io.DataBatch(data, label, pad=0, index=None, bucket_key=self.buckets[i], provide_data=provide_data, provide_label=provide_label) diff --git a/sockeye/inference.py b/sockeye/inference.py index d3918655c..5fd81f8c7 100644 --- a/sockeye/inference.py +++ b/sockeye/inference.py @@ -1320,9 +1320,8 @@ def _make_result(self, attention_matrix = translation.attention_matrix target_tokens = [self.vocab_target_inv[target_id] for target_id in target_ids] + target_string = C.TOKEN_SEPARATOR.join(data_io.ids2tokens(target_ids, self.vocab_target_inv, self.strip_ids)) - target_string = C.TOKEN_SEPARATOR.join( - tok for target_id, tok in zip(target_ids, target_tokens) if target_id not in self.strip_ids) attention_matrix = attention_matrix[:, :len(trans_input.tokens)] return TranslatorOutput(sentence_id=trans_input.sentence_id, diff --git a/sockeye/output_handler.py b/sockeye/output_handler.py index 4b877b28d..330e02d21 100644 --- a/sockeye/output_handler.py +++ b/sockeye/output_handler.py @@ -23,8 +23,8 @@ def get_output_handler(output_type: str, - output_fname: Optional[str], - sure_align_threshold: float) -> 'OutputHandler': + output_fname: Optional[str] = None, + sure_align_threshold: float = 1.0) -> 'OutputHandler': """ :param output_type: Type of output handler. @@ -36,6 +36,10 @@ def get_output_handler(output_type: str, output_stream = sys.stdout if output_fname is None else data_io.smart_open(output_fname, mode='w') if output_type == C.OUTPUT_HANDLER_TRANSLATION: return StringOutputHandler(output_stream) + elif output_type == C.OUTPUT_HANDLER_SCORE: + return ScoreOutputHandler(output_stream) + elif output_type == C.OUTPUT_HANDLER_PAIR_WITH_SCORE: + return PairWithScoreOutputHandler(output_stream) elif output_type == C.OUTPUT_HANDLER_TRANSLATION_WITH_SCORE: return StringWithScoreOutputHandler(output_stream) elif output_type == C.OUTPUT_HANDLER_TRANSLATION_WITH_ALIGNMENTS: @@ -119,6 +123,54 @@ def handle(self, self.stream.flush() +class ScoreOutputHandler(OutputHandler): + """ + Output handler to write translation score to a stream. + + :param stream: Stream to write translations to (e.g., sys.stdout). + """ + + def __init__(self, stream): + self.stream = stream + + def handle(self, + t_input: inference.TranslatorInput, + t_output: inference.TranslatorOutput, + t_walltime: float = 0.): + """ + :param t_input: Translator input. + :param t_output: Translator output. + :param t_walltime: Total walltime for translation. + """ + self.stream.write("{:.3f}\n".format(t_output.score)) + self.stream.flush() + + +class PairWithScoreOutputHandler(OutputHandler): + """ + Output handler to write translation score along with sentence input and output (tab-delimited). + + :param stream: Stream to write translations to (e.g., sys.stdout). + """ + + def __init__(self, stream): + self.stream = stream + + def handle(self, + t_input: inference.TranslatorInput, + t_output: inference.TranslatorOutput, + t_walltime: float = 0.): + """ + :param t_input: Translator input. + :param t_output: Translator output. + :param t_walltime: Total walltime for translation. + """ + self.stream.write("{:.3f}\t{}\t{}\n".format(t_output.score, + C.TOKEN_SEPARATOR.join(t_input.tokens), + t_output.translation)) + self.stream.flush() + + class StringWithAlignmentsOutputHandler(StringOutputHandler): """ Output handler to write translations and alignments to a stream. Translation and alignment string diff --git a/sockeye/score.py b/sockeye/score.py new file mode 100644 index 000000000..451c4838e --- /dev/null +++ b/sockeye/score.py @@ -0,0 +1,169 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Simple Training CLI. +""" +import argparse +import os +import sys +from contextlib import ExitStack +from typing import Any, cast, Optional, Dict, List, Tuple + +import mxnet as mx + +from . import arguments +from . import constants as C +from . import data_io +from . import inference +from . import model +from . import scoring +from . import train +from . import utils +from . import vocab +from .log import setup_main_logger +from .output_handler import get_output_handler, OutputHandler +from .utils import check_condition, log_basic_info + +# Temporary logger, the real one (logging to a file probably, will be created in the main function) +logger = setup_main_logger(__name__, file_logging=False, console=True) + + +def main(): + params = arguments.ConfigArgumentParser(description='Score data with an existing model.') + arguments.add_score_cli_args(params) + args = params.parse_args() + score(args) + + +def get_data_iters_and_vocabs(args: argparse.Namespace, + model_folder: Optional[str]) -> Tuple['data_io.BaseParallelSampleIter', + 'data_io.DataConfig', + List[vocab.Vocab], vocab.Vocab, model.ModelConfig]: + """ + Loads the data iterators and vocabularies. + + :param args: Arguments as returned by argparse. + :param max_seq_len_source: Source maximum sequence length. + :param max_seq_len_target: Target maximum sequence length. + :param shared_vocab: Whether to create a shared vocabulary. + :param resume_training: Whether to resume training. + :param model_folder: Output folder. + :return: The data iterators (train, validation, config_data) as well as the source and target vocabularies, and data_info if not using prepared data. + """ + + model_config = model.SockeyeModel.load_config(os.path.join(args.model, C.CONFIG_NAME)) + + if args.max_seq_len is None: + max_seq_len_source = model_config.config_data.max_seq_len_source + max_seq_len_target = model_config.config_data.max_seq_len_target + else: + max_seq_len_source, max_seq_len_target = args.max_seq_len + + + batch_num_devices = 1 if args.use_cpu else sum(-di if di < 0 else 1 for di in args.device_ids) + batch_by_words = args.batch_type == C.BATCH_TYPE_WORD + + # Load the existing vocabs created when starting the training run. + source_vocabs = vocab.load_source_vocabs(model_folder) + target_vocab = vocab.load_target_vocab(model_folder) + + sources = [args.source] + args.source_factors + sources = [str(os.path.abspath(source)) for source in sources] + + train_iter, _, config_data, data_info = data_io.get_training_data_iters( + sources=sources, + target=os.path.abspath(args.target), + validation_sources=None, + validation_target=None, + source_vocabs=source_vocabs, + target_vocab=target_vocab, + source_vocab_paths=None, + target_vocab_path=None, + shared_vocab=False, + batch_size=args.batch_size, + batch_by_words=batch_by_words, + batch_num_devices=batch_num_devices, + fill_up=C.FILL_UP_ZEROS, + permute=False, + max_seq_len_source=max_seq_len_source, + max_seq_len_target=max_seq_len_target, + bucketing=False, + bucket_width=args.bucket_width) + + return train_iter, config_data, source_vocabs, target_vocab, model_config + + +def score(args: argparse.Namespace): + global logger + logger = setup_main_logger(__name__, file_logging=False) + + utils.log_basic_info(args) + + with ExitStack() as exit_stack: + context = utils.determine_context(device_ids=args.device_ids, + use_cpu=args.use_cpu, + disable_device_locking=args.disable_device_locking, + lock_dir=args.lock_dir, + exit_stack=exit_stack) + if args.batch_type == C.BATCH_TYPE_SENTENCE: + check_condition(args.batch_size % len(context) == 0, "When using multiple devices the batch size must be " + "divisible by the number of devices. Choose a batch " + "size that is a multiple of %d." % len(context)) + logger.info("Scoring Device(s): %s", ", ".join(str(c) for c in context)) + + # This call has a number of different parameters compared to training which reflect our need to get scores + # one-for-one and in the same order as the input data. + # To enable code reuse, we stuff the `args` parameter with some values. + # Bucketing and permuting need to be turned off in order to preserve the ordering of sentences. + # The 'zeros' fill_up strategy fills underfilled buckets with zeros which can then be used to find the last item. + # Finally, 'resume_training' needs to be set to True because it causes the model to be loaded instead of initialized. + args.no_bucketing = True + args.fill_up = 'zeros' + args.bucket_width = 10 + score_iter, config_data, source_vocabs, target_vocab, model_config = get_data_iters_and_vocabs( + args=args, + model_folder=args.model) + + scoring_model = scoring.ScoringModel(config=model_config, + model_dir=args.model, + context=context, + provide_data=score_iter.provide_data, + provide_label=score_iter.provide_label, + default_bucket_key=score_iter.default_bucket_key, + score_type=args.score_type, + bucketing=False, + length_penalty=inference.LengthPenalty(alpha=args.length_penalty_alpha, + beta=args.length_penalty_beta), + softmax_temperature=args.softmax_temperature) + + scorer = scoring.Scorer(scoring_model, source_vocabs, target_vocab) + + scorer.score(score_iter=score_iter, + score_type=args.score_type, + output_handler=get_output_handler(output_type=args.output_type, + output_fname=args.output)) + + if config_data.data_statistics.num_discarded != 0: + num_discarded = config_data.data_statistics.num_discarded + logger.warning('Warning: %d %s longer than %s %s skipped. ' + 'As a result, the output won\'t be parallel with the input. ' + 'Increase the maximum length (--max-seq-len M:N) or trim your training data.', + num_discarded, + utils.inflect('sentence', num_discarded), + args.max_seq_len, + utils.inflect('was', num_discarded)) + + +if __name__ == "__main__": + main() diff --git a/sockeye/scoring.py b/sockeye/scoring.py new file mode 100644 index 000000000..1ca15e3b6 --- /dev/null +++ b/sockeye/scoring.py @@ -0,0 +1,266 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +""" +Code for scoring. +""" +import logging +import multiprocessing as mp +import os +import pickle +import random +import shutil +import time +from functools import reduce +from typing import Any, Dict, List, Optional, Tuple, Union + +import mxnet as mx +import numpy as np +from math import sqrt + +from . import constants as C +from . import data_io +from . import inference +from . import model +from . import utils +from . import vocab + +from .output_handler import OutputHandler +from .inference import TranslatorInput, TranslatorOutput + +logger = logging.getLogger(__name__) + + +class ScoringModel(model.SockeyeModel): + """ + ScoringModel is a TrainingModel (which is in turn a SockeyeModel) that scores a pair of sentences. + That is, it full unrolls over source and target sequences, running the encoder and decoder, but stopping short of computing a loss and backpropagating. + It is analogous to TrainingModel, but more limited. + + :param config: Configuration object holding details about the model. + :param model_dir: Directory containing the trained model. + :param context: The context(s) that MXNet will be run in (GPU(s)/CPU). + :param provide_data: List of input data descriptions. + :param provide_label: List of label descriptions. + :param default_bucket_key: Default bucket key. + :param score_type: The type of score to output (negative logprob or logprob). + :param length_penalty: The length penalty class to use. + """ + + def __init__(self, + config: model.ModelConfig, + model_dir: str, + context: List[mx.context.Context], + provide_data: List[mx.io.DataDesc], + provide_label: List[mx.io.DataDesc], + bucketing: bool, + default_bucket_key: Tuple[int, int], + score_type: str, + length_penalty: inference.LengthPenalty, + softmax_temperature: Optional[float] = None) -> None: + super().__init__(config) + self.context = context + self.bucketing = bucketing + self.score_type = score_type + self.length_penalty = length_penalty + self.softmax_temperature = softmax_temperature + + # Create the computation graph + self._initialize(provide_data, provide_label, default_bucket_key) + + # Load model parameters into graph + params_fname = os.path.join(model_dir, C.PARAMS_BEST_NAME) + super().load_params_from_file(params_fname) + self.module.set_params(arg_params=self.params, + aux_params=self.aux_params, + allow_missing=False) + + def _initialize(self, + provide_data: List[mx.io.DataDesc], + provide_label: List[mx.io.DataDesc], + default_bucket_key: Tuple[int, int]) -> None: + """ + Initializes model components, creates scoring symbol and module, and binds it. + + :param provide_data: List of data descriptors. + :param provide_label: List of label descriptors. + :param default_bucket_key: The default maximum (source, target) lengths. + """ + source = mx.sym.Variable(C.SOURCE_NAME) + source_words = source.split(num_outputs=self.config.config_embed_source.num_factors, + axis=2, squeeze_axis=True)[0] + source_length = utils.compute_lengths(source_words) + target = mx.sym.Variable(C.TARGET_NAME) + target_length = utils.compute_lengths(target) + + # labels shape: (batch_size, target_length) (usually the maximum target sequence length) + labels = mx.sym.Variable(C.TARGET_LABEL_NAME) + + data_names = [C.SOURCE_NAME, C.TARGET_NAME] + label_names = [C.TARGET_LABEL_NAME] + + # check provide_{data,label} names + provide_data_names = [d[0] for d in provide_data] + utils.check_condition(provide_data_names == data_names, + "incompatible provide_data: %s, names should be %s" % (provide_data_names, data_names)) + provide_label_names = [d[0] for d in provide_label] + utils.check_condition(provide_label_names == label_names, + "incompatible provide_label: %s, names should be %s" % (provide_label_names, label_names)) + + def sym_gen(seq_lens): + """ + Returns a (grouped) symbol containing the summed score for each sentence, as well as the entire target distributions for each word. + Also returns data and label names for the BucketingModule. + """ + source_seq_len, target_seq_len = seq_lens + + # source embedding + (source_embed, + source_embed_length, + source_embed_seq_len) = self.embedding_source.encode(source, source_length, source_seq_len) + + # target embedding + (target_embed, + target_embed_length, + target_embed_seq_len) = self.embedding_target.encode(target, target_length, target_seq_len) + + # encoder + # source_encoded: (batch_size, source_encoded_length, encoder_depth) + (source_encoded, + source_encoded_length, + source_encoded_seq_len) = self.encoder.encode(source_embed, + source_embed_length, + source_embed_seq_len) + + # decoder + # target_decoded: (batch-size, target_len, decoder_depth) + target_decoded = self.decoder.decode_sequence(source_encoded, source_encoded_length, source_encoded_seq_len, + target_embed, target_embed_length, target_embed_seq_len) + + # output layer + # logits: (batch_size * target_seq_len, target_vocab_size) + logits = self.output_layer(mx.sym.reshape(data=target_decoded, shape=(-3, 0))) + # logits after reshape: (batch_size, target_seq_len, target_vocab_size) + logits = mx.sym.reshape(data=logits, shape=(-4, -1, target_embed_seq_len, 0)) + + if self.softmax_temperature is not None: + logits = logits / self.softmax_temperature + + # Compute the softmax along the final dimension. + # target_dists: (batch_size, target_seq_len, target_vocab_size) + target_dists = mx.sym.softmax(data=logits, axis=2, name=C.SOFTMAX_NAME) + + # Select the label probability, then take their logs. + # probs and scores: (batch_size, target_seq_len) + probs = mx.sym.pick(target_dists, labels) + scores = mx.sym.log(probs) + if self.score_type == C.SCORING_TYPE_NEGLOGPROB: + scores = -1 * scores + + # Sum, then apply length penalty. The call to `mx.sym.where` masks out invalid values from scores. + # zeros and sums: (batch_size,) + zeros = mx.sym.zeros_like(scores) + sums = mx.sym.sum(mx.sym.where(labels != 0, scores, zeros), axis=1) / (self.length_penalty(target_length - 1)) + + # Return the sums and the target distributions + # sums: (batch_size,) target_dists: (batch_size, target_seq_len, target_vocab_size) + return mx.sym.Group([sums, target_dists]), data_names, label_names + + if self.bucketing: + logger.info("Using bucketing. Default max_seq_len=%s", default_bucket_key) + self.module = mx.mod.BucketingModule(sym_gen=sym_gen, + logger=logger, + default_bucket_key=default_bucket_key, + context=self.context) + else: + symbol, _, __ = sym_gen(default_bucket_key) + self.module = mx.mod.Module(symbol=symbol, + data_names=data_names, + label_names=label_names, + logger=logger, + context=self.context) + + self.module.bind(data_shapes=provide_data, + label_shapes=provide_label, + for_training=False, + force_rebind=False, + grad_req='null') + + def run(self, batch: mx.io.DataBatch) -> List[mx.nd.NDArray]: + """ + Runs the forward pass and returns the outputs. + + :param batch: The batch to run. + :return: The grouped symbol (probs and target dists) and lists containing the data names and label names. + """ + self.module.forward(batch, is_train=False) + return self.module.get_outputs() + + +class Scorer: + """ + Scorer class takes a ScoringModel and uses it to score a stream of parallel sentences. + It also takes the vocabularies so that the original sentences can be printed out, if desired. + + :param model: The model to score with. + :param source_vocabs: The source vocabularies. + :param target_vocab: The target vocabulary. + """ + def __init__(self, + model: ScoringModel, + source_vocabs: List[vocab.Vocab], + target_vocab: vocab.Vocab) -> None: + self.source_vocab_inv = vocab.reverse_vocab(source_vocabs[0]) + self.target_vocab_inv = vocab.reverse_vocab(target_vocab) + self.model = model + + self.exclude_list = set([source_vocabs[0][C.BOS_SYMBOL], target_vocab[C.EOS_SYMBOL], C.PAD_ID]) + + def score(self, + score_iter, + score_type: str, + output_handler: OutputHandler): + + total_time = 0. + tic = time.time() + sentence_no = 0 + for i, batch in enumerate(score_iter): + + batch_tic = time.time() + + # Run the model and get the outputs + scores = self.model.run(batch)[0] + + batch_time = time.time() - batch_tic + total_time += batch_time + + for source, target, score in zip(batch.data[0], batch.data[1], scores): + + # The "zeros" padding method will have filled remainder batches with zeros, so we can skip them here + if source[0][0] == C.PAD_ID: + break + + sentence_no += 1 + + # Transform arguments in preparation for printing + source_ids = [int(x) for x in source[:, 0].asnumpy().tolist()] + source_tokens = list(data_io.ids2tokens(source_ids, self.source_vocab_inv, self.exclude_list)) + target_ids = [int(x) for x in target.asnumpy().tolist()] + target_string = C.TOKEN_SEPARATOR.join( + data_io.ids2tokens(target_ids, self.target_vocab_inv, self.exclude_list)) + score = score.asscalar() + + # Output handling routines require us to make use of inference classes. + output_handler.handle(TranslatorInput(sentence_no, source_tokens), + TranslatorOutput(sentence_no, target_string, None, None, score), + batch_time) diff --git a/sockeye/train.py b/sockeye/train.py index 49bd06d4d..c22f21361 100644 --- a/sockeye/train.py +++ b/sockeye/train.py @@ -244,6 +244,7 @@ def create_data_iters_and_vocabs(args: argparse.Namespace, validation_sources = [args.validation_source] + args.validation_source_factors validation_sources = [str(os.path.abspath(source)) for source in validation_sources] + validation_target = str(os.path.abspath(args.validation_target)) either_raw_or_prepared_error_msg = "Either specify a raw training corpus with %s and %s or a preprocessed corpus " \ "with %s." % (C.TRAINING_ARG_SOURCE, @@ -258,7 +259,7 @@ def create_data_iters_and_vocabs(args: argparse.Namespace, train_iter, validation_iter, data_config, source_vocabs, target_vocab = data_io.get_prepared_data_iters( prepared_data_dir=args.prepared_data, validation_sources=validation_sources, - validation_target=str(os.path.abspath(args.validation_target)), + validation_target=validation_target, shared_vocab=shared_vocab, batch_size=args.batch_size, batch_by_words=batch_by_words, @@ -332,7 +333,7 @@ def create_data_iters_and_vocabs(args: argparse.Namespace, sources=sources, target=os.path.abspath(args.target), validation_sources=validation_sources, - validation_target=os.path.abspath(args.validation_target), + validation_target=validation_target, source_vocabs=source_vocabs, target_vocab=target_vocab, source_vocab_paths=source_vocab_paths, diff --git a/sockeye/utils.py b/sockeye/utils.py index e6d2a2ae3..875b53f53 100644 --- a/sockeye/utils.py +++ b/sockeye/utils.py @@ -941,3 +941,20 @@ def split(data: mx.nd.NDArray, if num_outputs == 1: return [ndarray_or_list] return ndarray_or_list + + +def inflect(word: str, + count: int): + """ + Minimal inflection module. + + :param word: The word to inflect. + :param count: The count. + :return: The word, perhaps inflected for number. + """ + if word in ['time', 'sentence']: + return word if count == 1 else word + 's' + elif word == 'was': + return 'was' if count == 1 else 'were' + else: + return word + '(s)' diff --git a/test/common.py b/test/common.py index d137b8ba7..aa7ff553a 100644 --- a/test/common.py +++ b/test/common.py @@ -30,7 +30,9 @@ import sockeye.evaluate import sockeye.extract_parameters import sockeye.lexicon +import sockeye.model import sockeye.prepare_data +import sockeye.score import sockeye.train import sockeye.translate import sockeye.utils @@ -210,6 +212,10 @@ def tmp_digits_dataset(prefix: str, _TRANSLATE_PARAMS_RESTRICT = "--restrict-lexicon {lexicon} --restrict-lexicon-topk {topk}" +_SCORE_PARAMS_COMMON = "--use-cpu --model {model} --source {source} --target {target} --output {output}" + +_SCORE_WITH_FACTORS_COMMON = " --source-factors {source_factors}" + _EVAL_PARAMS_COMMON = "--hypotheses {hypotheses} --references {references} --metrics {metrics} {quiet}" _EXTRACT_PARAMS = "--input {input} --names target_output_bias --list-all --output {output}" @@ -334,12 +340,13 @@ def run_train_translate(train_params: str, logger.info("Translating with parameters %s.", translate_params) # Translate corpus with the 1st params out_path = os.path.join(work_dir, "out.txt") - params = "{} {} {}".format(sockeye.translate.__file__, - _TRANSLATE_PARAMS_COMMON.format(model=model_path, - input=test_source_path, - output=out_path, - quiet=quiet_arg), - translate_params) + translate_score_path = os.path.join(work_dir, "out.scores.txt") + params = "{} {} {} --output-type translation_with_score".format(sockeye.translate.__file__, + _TRANSLATE_PARAMS_COMMON.format(model=model_path, + input=test_source_path, + output=out_path, + quiet=quiet_arg), + translate_params) if test_source_factor_paths is not None: params += _TRANSLATE_WITH_FACTORS_COMMON.format(input_factors=" ".join(test_source_factor_paths)) @@ -347,6 +354,21 @@ def run_train_translate(train_params: str, with patch.object(sys, "argv", params.split()): sockeye.translate.main() + # Break out translation and score + with open(out_path) as out_fh: + outputs = out_fh.readlines() + with open(out_path, 'w') as out_translate, open(translate_score_path, 'w') as out_scores: + for output in outputs: + output = output.strip() + # blank lines on test input will have only one field output (-inf for the score) + try: + score, translation = output.split('\t') + except ValueError: + score = output + translation = "" + print(translation, file=out_translate) + print(score, file=out_scores) + # Test target constraints if use_target_constraints: """ @@ -403,6 +425,89 @@ def run_train_translate(train_params: str, # for negative constraints, ensure the constraints is *not* in the constrained output assert restriction not in constrained_out + # Test scoring by ensuring that the sockeye.scoring module produces the same scores when scoring the output + # of sockeye.translate. However, since this training is on very small datasets, the output of sockeye.translate + # is often pure garbage or empty and cannot be scored. So we only try to score if we have some valid output + # to work with. + + # Skip if there are invalid tokens in the output, or if no valid outputs were found + translate_output_is_valid = True + with open(out_path) as out_fh: + sentences = list(map(lambda x: x.rstrip(), out_fh.readlines())) + # At least one output must be non-empty + found_valid_output = any(sentences) + + # There must be no bad tokens + found_bad_tokens = any([bad_token in ' '.join(sentences) for bad_token in C.VOCAB_SYMBOLS]) + + translate_output_is_valid = found_valid_output and not found_bad_tokens + + # Only run scoring under these conditions. Why? + # - scoring isn't compatible with prepared data because that loses the source ordering + # - scoring doesn't support skipping softmax (which can be enabled explicitly or implicitly by using a beam size of 1) + # - translate splits up too-long sentences and translates them in sequence, invalidating the score, so skip that + # - scoring requires valid translation output to compare against + if not use_prepared_data \ + and '--beam-size 1' not in translate_params \ + and '--max-input-len' not in translate_params \ + and translate_output_is_valid: + + ## Score + # We use the translation parameters, but have to remove irrelevant arguments from it. + # Currently, the only relevant flag passed is the --softmax-temperature flag. + score_params = '' + if 'softmax-temperature' in translate_params: + params = translate_params.split(C.TOKEN_SEPARATOR) + for i, param in enumerate(params): + if param == '--softmax-temperature': + score_params = '--softmax-temperature {}'.format(params[i + 1]) + break + + scores_output_file = out_path + '.score' + params = "{} {} {}".format(sockeye.score.__file__, + _SCORE_PARAMS_COMMON.format(model=model_path, + source=test_source_path, + target=out_path, + output=scores_output_file), + score_params) + + if test_source_factor_paths is not None: + params += _SCORE_WITH_FACTORS_COMMON.format(source_factors=" ".join(test_source_factor_paths)) + + with patch.object(sys, "argv", params.split()): + sockeye.score.main() + + # Compare scored output to original translation output. There are a few tricks: for blank source sentences, + # inference will report a score of -inf, so skip these. Second, we don't know if the scores include the + # generation of and have had length normalization applied. So, skip all sentences that are as long + # as the maximum length, in order to safely exclude them. + with open(translate_score_path) as in_translate, open(out_path) as in_words, open(scores_output_file) as in_score: + model_config = sockeye.model.SockeyeModel.load_config(os.path.join(model_path, C.CONFIG_NAME)) + max_len = model_config.config_data.max_seq_len_target + + # Filter out sockeye.translate sentences that had -inf or were too long (which sockeye.score will have skipped) + translate_scores = [] + translate_lens = [] + score_scores = in_score.readlines() + for score, sent in zip(in_translate.readlines(), in_words.readlines()): + if score != '-inf\n' and len(sent.split()) < max_len: + translate_scores.append(score) + translate_lens.append(len(sent.split())) + + assert len(translate_scores) == len(score_scores) + + # Compare scores (using 0.002 which covers common noise comparing e.g., 1.234 and 1.235) + for translate_score, translate_len, score_score in zip(translate_scores, translate_lens, score_scores): + # Skip sentences that are close to the maximum length to avoid confusion about whether + # the length penalty was applied + if translate_len >= max_len - 2: + continue + + translate_score = float(translate_score) + score_score = float(score_score) + + assert abs(translate_score - score_score) < 0.002 + # Translate corpus with the 2nd params if translate_params_equiv is not None: out_path_equiv = os.path.join(work_dir, "out_equiv.txt") diff --git a/test/unit/test_data_io.py b/test/unit/test_data_io.py index 608a2f0f2..66034d9e4 100644 --- a/test/unit/test_data_io.py +++ b/test/unit/test_data_io.py @@ -389,7 +389,7 @@ def test_get_batch_indices(): assert 0 <= start_pos < len(dataset.source[buck_idx]) - batch_size + 1 # check that all indices are used for a filled-up dataset - dataset = dataset.fill_up(bucket_batch_sizes, fill_up='replicate') + dataset = dataset.fill_up(bucket_batch_sizes, policy='replicate') indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes) all_bucket_indices = set(list(range(len(dataset)))) computed_bucket_indices = set([i for i, j in indices]) diff --git a/tutorials/README.md b/tutorials/README.md index ac7e88c32..342c812f2 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -17,3 +17,4 @@ introduce different concepts and parameters used for training and translation. 1. [Domain adaptation of NMT models](adapt) 1. [Decoding with lexical constraints](constraints) 1. [Process per core translation](process_per_core_translation) +1. [Scoring](scoring.md) diff --git a/tutorials/constraints/README.md b/tutorials/constraints/README.md index 033d5e0ab..ef57fecc9 100644 --- a/tutorials/constraints/README.md +++ b/tutorials/constraints/README.md @@ -55,6 +55,8 @@ For example: This will output tab-delimited pairs of (score, translation). As always, don't forget to apply source- and target-side preprocessing to your input and your constraint. +However, it is probably better to use [Sockeye's scoring module](../scoring.md) directly, since it makes use of the training time computation graph and is therefore much faster. + ## Negative constraints Negative constraints---phrases that must *not* appear in the output---are also supported. diff --git a/tutorials/scoring.md b/tutorials/scoring.md new file mode 100644 index 000000000..a70062ac4 --- /dev/null +++ b/tutorials/scoring.md @@ -0,0 +1,38 @@ +# Scoring existing translations + +Sockeye provides a fast scoring module that permits the scoring of existing translations. +It works by making use of the training computation graph, but turning off caching of gradients and loss computation. +Just like when training models, the scorer works with raw plain-text data passed in via `--source` and `--target`. +It can easily therefore taken any pretrained model, just like in inference. + +## Example + +To score a source and target dataset, first make sure that all source and target preprocessing have been applied. +Then run this command: + + python3 -m sockeye.score -m MODEL --source SOURCE --target TARGET + +Sockeye will output a score (a negative log probability) for each sentence pair. + +## Command-line arguments + +The scorer takes a number of arguments: + +- `--score-type logprob`. Use this to get log probabilities instead of negative log probabilities. +- `--batch-size X`. Word-based batching is used. + You can use this flag to change the batch size from its default of 500. + If you run out of memory, try lowering this. +- `--output-type {score,pair_with_score}`. The output type: either the score alone, or the score with the translation pair. + Fields will be separated by a tab. +- `--max-seq-len M:N`. The maximum sequences length (`M` the source length, `N` the target). +- `--softmax-temperature X`. Scales the logits by dividing by this argument before computing softmax. +- `--length-penalty-alpha`, `--length-penalty-beta`. Parameters for length normalization. + Set `--length-penalty-alpha 0` to disable normalization. + +## Caveat emptor + +Some things to watch out for: + +- Scoring reads the maximum sentence lengths from the model. + Sentences longer than these will be skipped, meaning the scored output will not be parallel with the input. + A warning message will be printed to STDERR, but beware. diff --git a/typechecked-files b/typechecked-files index 69fcced40..3de9ef2df 100644 --- a/typechecked-files +++ b/typechecked-files @@ -26,6 +26,8 @@ sockeye/output_handler.py sockeye/prepare_data.py sockeye/rnn.py sockeye/rnn_attention.py +sockeye/score.py +sockeye/scoring.py sockeye/train.py sockeye/training.py sockeye/transformer.py @@ -39,4 +41,4 @@ sockeye/image_captioning/checkpoint_decoder.py sockeye/image_captioning/encoder.py sockeye/image_captioning/extract_features.py sockeye/image_captioning/utils.py -sockeye/image_captioning/visualize.py \ No newline at end of file +sockeye/image_captioning/visualize.py