Scoring (awslabs#538)
This implements scoring of translations given source, by fully reusing the training computation graph, per @bricksdont's original suggestion.
mjpost authored and fhieber committed Sep 28, 2018
1 parent 3e554ad commit 5a50d96
Showing 18 changed files with 885 additions and 93 deletions.
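
The description above says that scoring fully reuses the training computation graph: instead of searching for a new translation, the model runs its usual forward pass on an existing (source, translation) pair and accumulates the log-probabilities it assigns to the given target tokens. A minimal NumPy sketch of that idea (illustrative only, with hypothetical per-token probabilities; not code from this commit):

```python
import numpy as np

def sentence_score(token_probs, score_type="neglogprob"):
    """Turn per-token probabilities P(y_t | y_<t, x) into a sentence-level score.

    token_probs are the probabilities the model assigns to each token of the
    *given* target sentence during a normal forward pass (hypothetical values here).
    """
    log_prob = float(np.sum(np.log(token_probs)))  # log P(Y | X)
    return -log_prob if score_type == "neglogprob" else log_prob

# Hypothetical probabilities for a 4-token target (including </s>):
probs = [0.7, 0.4, 0.9, 0.8]
print(sentence_score(probs))             # negative log-probability (the default score type)
print(sentence_score(probs, "logprob"))  # raw log-probability
```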
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,8 @@ Note that Sockeye has checks in place to not translate with an old model that wa
Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.18.57]
### Added
- Added `sockeye.score` CLI for quickly scoring existing translations ([documentation](tutorials/scoring.md)).
### Fixed
- Entry-point clean-up after the contrib/ rename

7 changes: 7 additions & 0 deletions docs/modules.rst
@@ -176,6 +176,13 @@ sockeye.rnn_attention module
:members:
:show-inheritance:

sockeye.score module
--------------------

.. automodule:: sockeye.score
:members:
:show-inheritance:

sockeye.train module
--------------------

1 change: 1 addition & 0 deletions setup.py
@@ -82,6 +82,7 @@ def get_requirements(filename):
'sockeye-lexicon = sockeye.lexicon:main',
'sockeye-init-embed = sockeye.init_embedding:main',
'sockeye-prepare-data = sockeye.prepare_data:main',
'sockeye-score = sockeye.score:main',
'sockeye-train = sockeye.train:main',
'sockeye-translate = sockeye.translate:main',
'sockeye-vocab = sockeye.vocab:main',
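
For readers less familiar with setuptools, the new line follows the existing `console_scripts` pattern: each entry maps a shell command to a module-level `main` function. A stripped-down sketch of how such a mapping is declared (illustrative only, not the full Sockeye `setup.py`):

```python
from setuptools import setup, find_packages

setup(
    name='sockeye',
    packages=find_packages(),
    entry_points={
        'console_scripts': [
            # After installation, typing `sockeye-score` on the command line
            # calls main() in sockeye/score.py.
            'sockeye-score = sockeye.score:main',
            'sockeye-train = sockeye.train:main',
            'sockeye-translate = sockeye.translate:main',
        ],
    },
)
```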
90 changes: 72 additions & 18 deletions sockeye/arguments.py
@@ -760,33 +760,37 @@ def add_model_parameters(params):
"(and all convolutional weight matrices for CNN decoders). Default: %(default)s.")


def add_batch_args(params, default_batch_size=4096):
params.add_argument('--batch-size', '-b',
type=int_greater_or_equal(1),
default=default_batch_size,
help='Mini-batch size. Note that depending on the batch-type this either refers to '
'words or sentences. '
'Sentence: each batch contains X sentences, number of words varies. '
'Word: each batch contains (approximately) X words, number of sentences varies. '
'Default: %(default)s.')
params.add_argument("--batch-type",
type=str,
default=C.BATCH_TYPE_WORD,
choices=[C.BATCH_TYPE_SENTENCE, C.BATCH_TYPE_WORD],
help="Sentence: each batch contains X sentences, number of words varies."
"Word: each batch contains (approximately) X target words, "
"number of sentences varies. Default: %(default)s.")


def add_training_args(params):
train_params = params.add_argument_group("Training parameters")

add_batch_args(train_params)

train_params.add_argument('--decoder-only',
action='store_true',
help='Pre-train a decoder. This is currently for RNN decoders only. '
'Default: %(default)s.')

train_params.add_argument('--batch-size', '-b',
type=int_greater_or_equal(1),
default=4096,
help='Mini-batch size. Note that depending on the batch-type this either refers to '
'words or sentences.'
'Sentence: each batch contains X sentences, number of words varies. '
'Word: each batch contains (approximately) X words, number of sentences varies. '
'Default: %(default)s.')
train_params.add_argument("--batch-type",
type=str,
default=C.BATCH_TYPE_WORD,
choices=[C.BATCH_TYPE_SENTENCE, C.BATCH_TYPE_WORD],
help="Sentence: each batch contains X sentences, number of words varies."
"Word: each batch contains (approximately) X target words, "
"number of sentences varies. Default: %(default)s.")

train_params.add_argument('--fill-up',
type=str,
default='replicate',
default=C.FILL_UP_DEFAULT,
choices=C.FILL_UP_CHOICES,
help=argparse.SUPPRESS)

train_params.add_argument('--loss',
@@ -1075,6 +1079,56 @@ def add_translate_cli_args(params):
add_logging_args(params)


def add_score_cli_args(params):
add_training_data_args(params, required=False)
add_vocab_args(params)
add_device_args(params)
add_logging_args(params)
add_batch_args(params, default_batch_size=500)

params = params.add_argument_group("Scoring parameters")

params.add_argument("--model", "-m", required=True,
help="Model directory containing trained model.")

params.add_argument('--max-seq-len',
type=multiple_values(num_values=2, greater_or_equal=1),
default=None,
help='Maximum sequence length in tokens. '
'Use "x:x" to specify separate values for src&tgt. Default: Read from model.')

params.add_argument('--length-penalty-alpha',
default=1.0,
type=float,
help='Alpha factor for the length penalty used in scoring: '
'(beta + len(Y))**alpha/(beta + 1)**alpha. A value of 0.0 will therefore turn off '
'length normalization. Default: %(default)s')

params.add_argument('--length-penalty-beta',
default=0.0,
type=float,
help='Beta factor for the length penalty used in scoring: '
'(beta + len(Y))**alpha/(beta + 1)**alpha. Default: %(default)s')

params.add_argument('--softmax-temperature',
type=float,
default=None,
help='Controls peakiness of model predictions. Values < 1.0 produce '
'peaked predictions, values > 1.0 produce smoothed distributions.')

params.add_argument("--output", "-o", default=None,
help="File to write output to. Default: STDOUT.")

params.add_argument('--output-type',
default=C.OUTPUT_HANDLER_SCORE,
choices=C.OUTPUT_HANDLERS_SCORING,
help='Output type. Default: %(default)s.')

params.add_argument('--score-type',
choices=C.SCORING_TYPE_CHOICES,
default=C.SCORING_TYPE_DEFAULT,
help='Score type to output. Default: %(default)s')

def add_max_output_cli_args(params):
params.add_argument('--max-output-length',
type=int,
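
The help strings in `add_score_cli_args` describe the two knobs that shape the final score: a length penalty `(beta + len(Y))**alpha / (beta + 1)**alpha` that the sentence log-probability is normalized by, and an optional softmax temperature that rescales the logits before the softmax. A hedged sketch of both, as interpreted from the help text above (not code from this diff):

```python
import numpy as np

def length_penalty(length, alpha=1.0, beta=0.0):
    # (beta + len(Y))**alpha / (beta + 1)**alpha; with alpha = 0.0 this is 1.0,
    # i.e. length normalization is turned off, matching the help text.
    return ((beta + length) ** alpha) / ((beta + 1) ** alpha)

def temperature_softmax(logits, temperature=1.0):
    # Dividing the logits by a temperature < 1.0 sharpens (peaks) the distribution,
    # while a temperature > 1.0 smooths it.
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    exp = np.exp(scaled - np.max(scaled))
    return exp / exp.sum()

# Hypothetical use: normalize a sentence log-probability by its length penalty.
log_prob = -12.3  # assumed log P(Y | X) for a 10-token target
print(log_prob / length_penalty(length=10, alpha=1.0, beta=0.0))
print(temperature_softmax([2.0, 1.0, 0.1], temperature=0.5))
```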
16 changes: 16 additions & 0 deletions sockeye/constants.py
@@ -318,18 +318,23 @@
OUTPUT_HANDLER_TRANSLATION_WITH_SCORE = "translation_with_score"
OUTPUT_HANDLER_TRANSLATION_WITH_ALIGNMENTS = "translation_with_alignments"
OUTPUT_HANDLER_TRANSLATION_WITH_ALIGNMENT_MATRIX = "translation_with_alignment_matrix"
OUTPUT_HANDLER_SCORE = "score"
OUTPUT_HANDLER_PAIR_WITH_SCORE = "pair_with_score"
OUTPUT_HANDLER_BENCHMARK = "benchmark"
OUTPUT_HANDLER_ALIGN_PLOT = "align_plot"
OUTPUT_HANDLER_ALIGN_TEXT = "align_text"
OUTPUT_HANDLER_BEAM_STORE = "beam_store"
OUTPUT_HANDLERS = [OUTPUT_HANDLER_TRANSLATION,
OUTPUT_HANDLER_SCORE,
OUTPUT_HANDLER_TRANSLATION_WITH_SCORE,
OUTPUT_HANDLER_TRANSLATION_WITH_ALIGNMENTS,
OUTPUT_HANDLER_TRANSLATION_WITH_ALIGNMENT_MATRIX,
OUTPUT_HANDLER_BENCHMARK,
OUTPUT_HANDLER_ALIGN_PLOT,
OUTPUT_HANDLER_ALIGN_TEXT,
OUTPUT_HANDLER_BEAM_STORE]
OUTPUT_HANDLERS_SCORING = [OUTPUT_HANDLER_SCORE,
OUTPUT_HANDLER_PAIR_WITH_SCORE]

# metrics
ACCURACY = 'accuracy'
@@ -394,7 +399,18 @@
PREPARED_DATA_VERSION_FILE = "data.version"
PREPARED_DATA_VERSION = 2

FILL_UP_REPLICATE = 'replicate'
FILL_UP_ZEROS = 'zeros'
FILL_UP_DEFAULT = FILL_UP_REPLICATE
FILL_UP_CHOICES = [FILL_UP_REPLICATE, FILL_UP_ZEROS]

# reranking
RERANK_BLEU = "bleu"
RERANK_CHRF = "chrf"
RERANK_METRICS = [RERANK_BLEU, RERANK_CHRF]

# scoring
SCORING_TYPE_NEGLOGPROB = 'neglogprob'
SCORING_TYPE_LOGPROB = 'logprob'
SCORING_TYPE_DEFAULT = SCORING_TYPE_NEGLOGPROB
SCORING_TYPE_CHOICES = [SCORING_TYPE_NEGLOGPROB, SCORING_TYPE_LOGPROB]
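
The new constants wire the CLI choices above to their defaults: the two score types differ only in sign, and the two scoring output handlers differ in whether the sentence pair is echoed next to its score. A hypothetical illustration of that distinction (the exact output layout is assumed here, not taken from this diff):

```python
def format_score_line(neg_log_prob, source, target,
                      output_type="score", score_type="neglogprob"):
    # 'neglogprob' is the default score type; 'logprob' is simply its negation.
    value = neg_log_prob if score_type == "neglogprob" else -neg_log_prob
    if output_type == "pair_with_score":
        # Assumed layout: score, source and target on one tab-separated line.
        return "{:.3f}\t{}\t{}".format(value, source, target)
    return "{:.3f}".format(value)

print(format_score_line(12.345, "ein kleines Beispiel", "a small example"))
print(format_score_line(12.345, "ein kleines Beispiel", "a small example",
                        output_type="pair_with_score"))
```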