Skip to content

Commit

Permalink
Move compatibility checks for inference configuration variables into API functions (Translator.init() & load_models()). Some more cleanups. (awslabs#592)
Browse files: Browse the repository at this point in the history
  • Loading branch information
fhieber authored Dec 17, 2018
1 parent 094baca commit 769e517
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 44 deletions.
35 changes: 25 additions & 10 deletions sockeye/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,8 @@ def load_models(context: mx.context.Context,

if checkpoints is None:
checkpoints = [None] * len(model_folders)
else:
utils.check_condition(len(checkpoints) == len(model_folders), "Must provide checkpoints for each model")

skip_softmax = False
# performance tweak: skip softmax for a single model, decoding with beam size 1, when not sampling and no scores are required in output.
Expand All @@ -430,6 +432,10 @@ def load_models(context: mx.context.Context,
if override_dtype is not None:
model_config.config_encoder.dtype = override_dtype
model_config.config_decoder.dtype = override_dtype
if override_dtype == C.DTYPE_FP16:
logger.warning('Experimental feature \'override_dtype=float16\' has been used. '
'This feature may be removed or change its behaviour in future. '
'DO NOT USE IT IN PRODUCTION!')

if checkpoint is None:
params_fname = os.path.join(model_folder, C.PARAMS_BEST_NAME)
Expand Down Expand Up @@ -751,7 +757,8 @@ def make_input_from_dict(sentence_id: SentenceId, input_dict: Dict) -> Translato
if isinstance(constraints, list):
constraints = [list(data_io.get_tokens(constraint)) for constraint in constraints]

return TranslatorInput(sentence_id=sentence_id, tokens=tokens, factors=factors, constraints=constraints, avoid_list=avoid_list)
return TranslatorInput(sentence_id=sentence_id, tokens=tokens, factors=factors,
constraints=constraints, avoid_list=avoid_list)

except Exception as e:
logger.exception(e, exc_info=True) if not is_python34() else logger.error(e) # type: ignore
Expand Down Expand Up @@ -875,6 +882,7 @@ class NBestTranslations:
__slots__ = ('target_ids_list',
'attention_matrices',
'scores')

def __init__(self,
target_ids_list: List[TokenIds],
attention_matrices: List[np.ndarray],
Expand Down Expand Up @@ -1168,15 +1176,21 @@ def __init__(self,
self.interpolation_func = self._get_interpolation_func(ensemble_mode)
self.beam_size = self.models[0].beam_size
self.nbest_size = nbest_size
utils.check_condition(self.beam_size >= nbest_size,
'Nbest size must be smaller or equal to beam size.')
utils.check_condition(self.beam_size >= nbest_size, 'nbest_size must be smaller or equal to beam_size.')
if self.nbest_size > 1:
utils.check_condition(self.beam_search_stop == C.BEAM_SEARCH_STOP_ALL,
"nbest_size > 1 requires beam_search_stop to be set to 'all'")
self.batch_size = self.models[0].batch_size

if any(m.skip_softmax for m in self.models):
utils.check_condition(len(self.models) == 1 and self.beam_size == 1,
"Skipping softmax cannot be enabled for ensembles or beam sizes > 1.")

self.skip_topk = skip_topk
if self.skip_topk:
utils.check_condition(self.beam_size == 1, "skip_topk has no effect if beam size is larger than 1")
utils.check_condition(len(self.models) == 1, "skip_topk has no effect for decoding with more than 1 model")

self.sample = sample
utils.check_condition(not self.sample or self.restrict_lexicon is None,
"Sampling is not available when working with a restricted lexicon.")
Expand All @@ -1195,7 +1209,8 @@ def __init__(self,
ctx=self.context, dtype='float32')

# offset for hypothesis indices in batch decoding
self.offset = mx.nd.array(np.repeat(np.arange(0, self.batch_size * self.beam_size, self.beam_size), self.beam_size),
self.offset = mx.nd.array(np.repeat(np.arange(0, self.batch_size * self.beam_size, self.beam_size),
self.beam_size),
dtype='int32', ctx=self.context)
# locations of each batch item when first dimension is (batch * beam)
self.batch_indices = mx.nd.array(np.arange(0, self.batch_size * self.beam_size, self.beam_size), dtype='int32', ctx=self.context)
Expand Down Expand Up @@ -1250,7 +1265,8 @@ def __init__(self,
for phrase in data_io.read_content(avoid_list):
phrase_ids = data_io.tokens2ids(phrase, self.vocab_target)
if self.unk_id in phrase_ids:
logger.warning("Global avoid phrase '%s' contains an %s; this may indicate improper preprocessing.", ' '.join(phrase), C.UNK_SYMBOL)
logger.warning("Global avoid phrase '%s' contains an %s; this may indicate improper preprocessing.",
' '.join(phrase), C.UNK_SYMBOL)
self.global_avoid_trie.add_phrase(phrase_ids)

self._concat_translations = partial(_concat_nbest_translations if self.nbest_size > 1 else _concat_translations,
Expand Down Expand Up @@ -1416,9 +1432,9 @@ def _get_inference_input(self,
List[Optional[constrained.RawConstraintList]],
mx.nd.NDArray]:
"""
Assembles the numerical data for the batch.
This comprises an NDArray for the source sentences, the bucket key (padded source length), and a list of
raw constraint lists, one for each sentence in the batch, an NDArray of maximum output lengths for each sentence in the batch.
Assembles the numerical data for the batch. This comprises an NDArray for the source sentences,
the bucket key (padded source length), and a list of raw constraint lists, one for each sentence in the batch,
an NDArray of maximum output lengths for each sentence in the batch.
Each raw constraint list contains phrases in the form of lists of integers in the target language vocabulary.
:param trans_inputs: List of TranslatorInputs.
Expand Down Expand Up @@ -1780,8 +1796,7 @@ def _beam_search(self,
constraints,
best_hyp_indices,
best_word_indices,
scores_accumulated,
self.context)
scores_accumulated)

# Map from restricted to full vocab ids if needed
if self.restrict_lexicon:
Expand Down
24 changes: 11 additions & 13 deletions sockeye/lexical_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,7 @@ def topk(timestep: int,
hypotheses: List[ConstrainedHypothesis],
best_ids: mx.nd.NDArray,
best_word_ids: mx.nd.NDArray,
seq_scores: mx.nd.NDArray,
context: mx.context.Context) -> Tuple[np.array, np.array, np.array, List[ConstrainedHypothesis], mx.nd.NDArray]:
seq_scores: mx.nd.NDArray) -> Tuple[np.array, np.array, np.array, List[ConstrainedHypothesis], mx.nd.NDArray]:
"""
Builds a new topk list such that the beam contains hypotheses having completed different numbers of constraints.
These items are built from three different types: (1) the best items across the whole
Expand All @@ -507,7 +506,6 @@ def topk(timestep: int,
:param best_ids: The current list of best hypotheses (shape: (beam_size,)).
:param best_word_ids: The parallel list of best word IDs (shape: (beam_size,)).
:param seq_scores: (shape: (beam_size, 1)).
:param context: The MXNet device context.
:return: A tuple containing the best hypothesis rows, the best hypothesis words, the scores,
the updated constrained hypotheses, and the updated set of inactive hypotheses.
"""
Expand All @@ -526,8 +524,7 @@ def topk(timestep: int,
hypotheses[rows],
best_ids[rows] - rows.start,
best_word_ids[rows],
seq_scores[rows],
context)
seq_scores[rows])

# offsetting since the returned smallest_k() indices were slice-relative
best_ids[rows] += rows.start
Expand All @@ -546,8 +543,8 @@ def _topk(timestep: int,
hypotheses: List[ConstrainedHypothesis],
best_ids: mx.nd.NDArray,
best_word_ids: mx.nd.NDArray,
sequence_scores: mx.nd.NDArray,
context: mx.context.Context) -> Tuple[np.array, np.array, np.array, List[ConstrainedHypothesis], mx.nd.NDArray]:
sequence_scores: mx.nd.NDArray) -> Tuple[np.array, np.array, np.array,
List[ConstrainedHypothesis], mx.nd.NDArray]:
"""
Builds a new topk list such that the beam contains hypotheses having completed different numbers of constraints.
These items are built from three different types: (1) the best items across the whole
Expand All @@ -561,7 +558,6 @@ def _topk(timestep: int,
:param best_ids: The current list of best hypotheses (shape: (beam_size,)).
:param best_word_ids: The parallel list of best word IDs (shape: (beam_size,)).
:param sequence_scores: (shape: (beam_size, 1)).
:param context: The MXNet device context.
:return: A tuple containing the best hypothesis rows, the best hypothesis words, the scores,
the updated constrained hypotheses, and the updated set of inactive hypotheses.
"""
Expand Down Expand Up @@ -608,7 +604,7 @@ def _topk(timestep: int,
sorted_candidates = sorted(candidates, key=attrgetter('score'))

# The number of hypotheses in each bank
counts = [0 for x in range(num_constraints + 1)]
counts = [0 for _ in range(num_constraints + 1)]
for cand in sorted_candidates:
counts[cand.hypothesis.num_met()] += 1

Expand All @@ -624,12 +620,14 @@ def _topk(timestep: int,
pruned_candidates.append(cand)
bank_sizes[bank] -= 1

inactive[:len(pruned_candidates)] = 0
num_pruned_candidates = len(pruned_candidates)

inactive[:num_pruned_candidates] = 0

# Pad the beam so array assignment still works
if len(pruned_candidates) < beam_size:
inactive[len(pruned_candidates):] = 1
pruned_candidates += [pruned_candidates[len(pruned_candidates) - 1]] * (beam_size - len(pruned_candidates))
if num_pruned_candidates < beam_size:
inactive[num_pruned_candidates:] = 1
pruned_candidates += [pruned_candidates[num_pruned_candidates - 1]] * (beam_size - num_pruned_candidates)

return (np.array([x.row for x in pruned_candidates]),
np.array([x.col for x in pruned_candidates]),
Expand Down
22 changes: 2 additions & 20 deletions sockeye/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,13 @@ def run_translate(args: argparse.Namespace):
file_logging=True,
path="%s.%s" % (args.output, C.LOG_NAME))

if args.checkpoints is not None:
check_condition(len(args.checkpoints) == len(args.models), "must provide checkpoints for each model")

if args.skip_topk:
check_condition(args.beam_size == 1, "--skip-topk has no effect if beam size is larger than 1")
check_condition(len(args.models) == 1, "--skip-topk has no effect for decoding with more than 1 model")
log_basic_info(args)

if args.nbest_size > 1:
check_condition(args.beam_size >= args.nbest_size,
"Size of nbest list (--nbest-size) must be smaller or equal to beam size (--beam-size).")
check_condition(args.beam_search_stop == C.BEAM_SEARCH_STOP_ALL,
"--nbest-size > 1 requires beam search to only stop after all hypotheses are finished "
"(--beam-search-stop all)")
if args.output_type != C.OUTPUT_HANDLER_NBEST:
logger.warning("For nbest translation, output handler must be '%s', overriding option --output-type.",
C.OUTPUT_HANDLER_NBEST)
C.OUTPUT_HANDLER_NBEST)
args.output_type = C.OUTPUT_HANDLER_NBEST

log_basic_info(args)

output_handler = get_output_handler(args.output_type,
args.output,
args.sure_align_threshold)
Expand All @@ -86,11 +73,6 @@ def run_translate(args: argparse.Namespace):
exit_stack=exit_stack)[0]
logger.info("Translate Device: %s", context)

if args.override_dtype == C.DTYPE_FP16:
logger.warning('Experimental feature \'--override-dtype float16\' has been used. '
'This feature may be removed or change its behaviour in future. '
'DO NOT USE IT IN PRODUCTION!')

models, source_vocabs, target_vocab = inference.load_models(
context=context,
max_input_len=args.max_input_len,
Expand Down
2 changes: 1 addition & 1 deletion sockeye/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
"""
import binascii
import errno
import portalocker
import glob
import gzip
import itertools
Expand All @@ -34,6 +33,7 @@

import mxnet as mx
import numpy as np
import portalocker

from . import __version__, constants as C
from .log import log_sockeye_version, log_mxnet_version
Expand Down

0 comments on commit 769e517

Please sign in to comment.