Skip softmax and topk (awslabs#519)
Skip softmax by default for greedy decoding, and add an option to skip topk:
- this change only affects inference
- it only affects greedy decoding
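
Why this is safe for greedy decoding: softmax is a strictly monotonic, row-wise transformation of the logits, so the argmax (and therefore the greedy choice) is identical with or without normalization. A minimal NumPy sketch of that argument (illustration only, not part of this commit):

import numpy as np

def softmax(logits, axis=-1):
    # numerically stable row-wise softmax
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

logits = np.random.default_rng(0).normal(size=(4, 10))  # (batch, vocab)

# the greedy (argmax) choice is unchanged by normalization
assert (logits.argmax(axis=1) == softmax(logits).argmax(axis=1)).all()
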
bricksdont authored and fhieber committed Sep 7, 2018
1 parent ca54e53 commit 7e8fc97
Showing 8 changed files with 104 additions and 18 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.18.53]
### Added
- Softmax is now always skipped by default during greedy decoding (single models only).
- Added option `--skip-topk` for greedy decoding.

## [1.18.52]
### Fixed
- Fixed bug in constrained decoding to make sure best hypothesis satisfies all constraints.
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.18.52'
__version__ = '1.18.53'
5 changes: 5 additions & 0 deletions sockeye/arguments.py
@@ -1148,6 +1148,11 @@ def add_inference_args(params):
' Default: %d without batching '
'and %d * batch_size with batching.' % (C.CHUNK_SIZE_NO_BATCHING,
C.CHUNK_SIZE_PER_BATCH_SEGMENT))
decode_params.add_argument('--skip-topk',
default=False,
action='store_true',
help='Use argmax instead of topk for greedy decoding (when --beam-size 1).'
' Default: %(default)s.')
decode_params.add_argument('--ensemble-mode',
type=str,
default='linear',
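
As a usage note, a hypothetical command line for greedy decoding with both optimizations active (model directory and file names are placeholders; softmax skipping is enabled automatically for a single model with beam size 1, while --skip-topk must be requested explicitly):

python -m sockeye.translate \
    --models wmt_model \
    --input newstest.de \
    --output newstest.en \
    --beam-size 1 \
    --skip-topk
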
70 changes: 55 additions & 15 deletions sockeye/inference.py
@@ -55,6 +55,7 @@ class InferenceModel(model.SockeyeModel):
:param decoder_return_logit_inputs: Decoder returns inputs to logit computation instead of softmax over target
vocabulary. Used when logits/softmax are handled separately.
:param cache_output_layer_w_b: Cache weights and biases for logit computation.
:param skip_softmax: If True, does not compute softmax for greedy decoding.
"""

def __init__(self,
@@ -67,18 +68,22 @@ def __init__(self,
max_output_length_num_stds: int = C.DEFAULT_NUM_STD_MAX_OUTPUT_LENGTH,
decoder_return_logit_inputs: bool = False,
cache_output_layer_w_b: bool = False,
forced_max_output_len: Optional[int] = None) -> None:
forced_max_output_len: Optional[int] = None,
skip_softmax: bool = False) -> None:
super().__init__(config)
self.params_fname = params_fname
self.context = context
self.beam_size = beam_size
utils.check_condition(beam_size < self.config.vocab_target_size,
'The beam size must be smaller than the target vocabulary size.')
if skip_softmax:
assert beam_size == 1, 'Skipping softmax does not have any effect for beam size > 1'
self.batch_size = batch_size
self.softmax_temperature = softmax_temperature
self.max_input_length, self.get_max_output_length = models_max_input_output_length([self],
max_output_length_num_stds,
forced_max_output_len=forced_max_output_len)
self.skip_softmax = skip_softmax

self.encoder_module = None # type: Optional[mx.mod.BucketingModule]
self.encoder_default_bucket_key = None # type: Optional[int]
@@ -236,7 +241,11 @@ def sym_gen(bucket_key: Tuple[int, int]):
logits = self.output_layer(target_decoded)
if self.softmax_temperature is not None:
logits = logits / self.softmax_temperature
outputs = mx.sym.softmax(data=logits, name=C.SOFTMAX_NAME)
if self.skip_softmax:
# skip softmax for greedy decoding
outputs = logits
else:
outputs = mx.sym.softmax(data=logits, name=C.SOFTMAX_NAME)

data_names = [C.TARGET_NAME] + state_names
label_names = [] # type: List[str]
@@ -394,6 +403,14 @@ def load_models(context: mx.context.Context,
if checkpoints is None:
checkpoints = [None] * len(model_folders)

# skip softmax for a single model,
if len(model_folders) == 1 and beam_size == 1:
skip_softmax = True
logger.info("Enabled skipping softmax for a single model and greedy decoding.")
else:
# but not for an ensemble or beam search
skip_softmax = False

for model_folder, checkpoint in zip(model_folders, checkpoints):
model_source_vocabs = vocab.load_source_vocabs(model_folder)
model_target_vocab = vocab.load_target_vocab(model_folder)
@@ -420,13 +437,16 @@
batch_size=batch_size,
softmax_temperature=softmax_temperature,
decoder_return_logit_inputs=decoder_return_logit_inputs,
cache_output_layer_w_b=cache_output_layer_w_b)
cache_output_layer_w_b=cache_output_layer_w_b,
skip_softmax=skip_softmax)
utils.check_condition(inference_model.num_source_factors == len(model_source_vocabs),
"Number of loaded source vocabularies (%d) does not match "
"number of source factors for model '%s' (%d)" % (len(model_source_vocabs), model_folder,
inference_model.num_source_factors))
models.append(inference_model)

utils.check_condition(vocab.are_identical(*target_vocabs), "Target vocabulary ids do not match")
first_model_vocabs = source_vocabs[0]
for fi in range(len(first_model_vocabs)):
@@ -966,6 +986,7 @@ class Translator:
:param avoid_list: Global list of phrases to exclude from the output.
:param store_beam: If True, store the beam search history and return it in the TranslatorOutput.
:param strip_unknown_words: If True, removes any <unk> symbols from outputs.
:param skip_topk: If True, uses argmax instead of topk for greedy decoding.
"""

def __init__(self,
@@ -981,7 +1002,8 @@ def __init__(self,
restrict_lexicon: Optional[lexicon.TopKLexicon] = None,
avoid_list: Optional[str] = None,
store_beam: bool = False,
strip_unknown_words: bool = False) -> None:
strip_unknown_words: bool = False,
skip_topk: bool = False) -> None:
self.context = context
self.length_penalty = length_penalty
self.beam_prune = beam_prune
@@ -1005,6 +1027,12 @@ def __init__(self,
self.interpolation_func = self._get_interpolation_func(ensemble_mode)
self.beam_size = self.models[0].beam_size
self.batch_size = self.models[0].batch_size
# skip softmax for a single model, but not for an ensemble
self.skip_softmax = self.models[0].skip_softmax
if self.skip_softmax:
utils.check_condition(len(self.models) == 1 and self.beam_size == 1, "Skipping softmax is only supported with a single model and a beam size of 1.")

self.skip_topk = skip_topk
# after models are loaded we ensured that they agree on max_input_length, max_output_length and batch size
self._max_input_length = self.models[0].max_input_length
if bucket_source_width > 0:
@@ -1027,11 +1055,15 @@ def __init__(self,
self._update_scores.hybridize()

# topk function used in beam search
self._topk = partial(utils.topk,
k=self.beam_size,
batch_size=self.batch_size,
offset=self.offset,
use_mxnet_topk=self.context != mx.cpu()) # MXNet implementation is faster on GPUs
if self.skip_topk:
self._top = partial(utils.top1,
offset=self.offset)
else:
self._top = partial(utils.topk,
k=self.beam_size,
batch_size=self.batch_size,
offset=self.offset,
use_mxnet_topk=self.context != mx.cpu()) # MXNet implementation is faster on GPUs

self._sort_by_index = SortByIndex()
self._sort_by_index.initialize(ctx=self.context)
@@ -1352,9 +1384,14 @@ def _decode_step(self,
# Compute logits and softmax with restricted vocabulary
if self.restrict_lexicon:
logits = model.output_layer(decoder_outputs, out_w, out_b)
probs = mx.nd.softmax(logits)
if self.skip_softmax:
# skip softmax for greedy decoding and single model
probs = logits
else:
probs = mx.nd.softmax(logits)
else:
# Otherwise decoder outputs are already target vocab probs
# Otherwise decoder outputs are already target vocab probs,
# or logits if beam size is 1
probs = decoder_outputs
model_probs.append(probs)
model_attention_probs.append(attention_probs)
@@ -1377,10 +1414,13 @@ def _combine_predictions(self,

# combine model predictions and convert to neg log probs
if len(self.models) == 1:
neg_logprobs = -mx.nd.log(probs[0]) # pylint: disable=invalid-unary-operand-type
if self.skip_softmax:
neg_probs = -probs[0]
else:
neg_probs = -mx.nd.log(probs[0]) # pylint: disable=invalid-unary-operand-type
else:
neg_logprobs = self.interpolation_func(probs)
return neg_logprobs, attention_prob_score
neg_probs = self.interpolation_func(probs)
return neg_probs, attention_prob_score

def _beam_search(self,
source: mx.nd.NDArray,
@@ -1525,7 +1565,7 @@ def _beam_search(self,

# (3) Get beam_size winning hypotheses for each sentence block separately. Only look as
# far as the active beam size for each sentence.
best_hyp_indices, best_word_indices, scores_accumulated = self._topk(scores)
best_hyp_indices, best_word_indices, scores_accumulated = self._top(scores)

# Constraints for constrained decoding are processed sentence by sentence
if any(raw_constraint_list):
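
Two details of the hunks above are easy to miss: the Translator binds either utils.topk or the new utils.top1 into a single callable self._top via functools.partial, and when softmax is skipped the scores reaching that callable are negative raw logits rather than negative log-probabilities; the argmin is unchanged because softmax preserves the per-row ordering. A simplified NumPy sketch of this dispatch (names mirror the diff, but this is illustrative rather than the actual Sockeye code):

from functools import partial

import numpy as np

def top1(scores, offset):
    # beam size 1: one row per sentence, take the single lowest score per row
    best_word_indices = scores.argmin(axis=1).astype('int32')
    values = scores[np.arange(scores.shape[0]), best_word_indices].reshape((-1, 1))
    # with a beam of 1, the best hypothesis indices are simply the row offsets
    return offset, best_word_indices, values

batch_size, vocab_size = 3, 7
offset = np.arange(batch_size, dtype='int32')
neg_logits = -np.random.default_rng(1).normal(size=(batch_size, vocab_size))

_top = partial(top1, offset=offset)  # plays the role of self._top
best_hyp_indices, best_word_indices, scores_accumulated = _top(neg_logits)
assert (best_word_indices == neg_logits.argmin(axis=1)).all()
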
7 changes: 6 additions & 1 deletion sockeye/translate.py
@@ -52,6 +52,10 @@ def run_translate(args: argparse.Namespace):
if args.checkpoints is not None:
check_condition(len(args.checkpoints) == len(args.models), "must provide checkpoints for each model")

if args.skip_topk:
check_condition(args.beam_size == 1, "--skip-topk has no effect if beam size is larger than 1")
check_condition(len(args.models) == 1, "--skip-topk has no effect for decoding with more than 1 model")

log_basic_info(args)

output_handler = get_output_handler(args.output_type,
@@ -102,7 +106,8 @@ def run_translate(args: argparse.Namespace):
restrict_lexicon=restrict_lexicon,
avoid_list=args.avoid_list,
store_beam=store_beam,
strip_unknown_words=args.strip_unknown_words)
strip_unknown_words=args.strip_unknown_words,
skip_topk=args.skip_topk)
read_and_translate(translator=translator,
output_handler=output_handler,
chunk_size=args.chunk_size,
22 changes: 22 additions & 0 deletions sockeye/utils.py
@@ -252,6 +252,28 @@ def variance(self) -> float:
return self._M2 / self._count


def top1(scores: mx.nd.NDArray,
offset: mx.nd.NDArray) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, mx.nd.NDArray]:
"""
Get the single lowest element per sentence from a `scores` matrix. Expects that
beam size is 1, for greedy decoding.
NOTE(mathmu): The current implementation of argmin in MXNet much slower than topk with k=1.
:param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size)
:param offset: Array to add to the hypothesis indices for offsetting in batch decoding.
:return: The row indices, column indices and values of the smallest items in matrix.
"""
best_word_indices = mx.nd.cast(mx.nd.argmin(scores, axis=1), dtype='int32')
values = scores[mx.nd.arange(scores.shape[0], dtype='int32', ctx=scores.context), best_word_indices]

values = values.reshape((-1, 1))

# for top1, the best hyp indices are equal to the plain offset

return offset, best_word_indices, values


def topk(scores: mx.nd.NDArray,
k: int,
batch_size: int,
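
A small sanity check of the new helper, assuming sockeye.utils is importable and MXNet NDArrays are used as in the diff (illustrative only; the expected outputs are worked out by hand):

import mxnet as mx

from sockeye import utils

scores = mx.nd.array([[3.0, 1.0, 2.0, 5.0, 4.0],
                      [0.5, 0.9, 0.1, 0.7, 0.3]])  # (batch_size * beam_size, vocab) with beam size 1
offset = mx.nd.arange(0, scores.shape[0], dtype='int32')

best_hyp_indices, best_word_indices, values = utils.top1(scores, offset)
print(best_word_indices.asnumpy())  # expected: [1 2], the lowest-scoring word per sentence
print(values.asnumpy())             # expected: [[1.], [0.1]]
print(best_hyp_indices.asnumpy())   # expected: [0 1], equal to the offset for beam size 1
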
8 changes: 8 additions & 0 deletions test/integration/test_seq_copy_int.py
@@ -33,6 +33,14 @@
" --decode-and-evaluate 0",
"--beam-size 2 --softmax-temperature 0.01",
True, False, False),
# "Vanilla" LSTM encoder-decoder with attention, greedy and skip topk
("--encoder rnn --decoder rnn --num-layers 1 --rnn-cell-type lstm --rnn-num-hidden 8 --num-embed 4 "
" --rnn-attention-type mlp"
" --rnn-attention-num-hidden 8 --batch-size 2 --loss cross-entropy --optimized-metric perplexity --max-updates 2"
" --checkpoint-frequency 2 --optimizer adam --initial-learning-rate 0.01 --batch-type sentence "
" --decode-and-evaluate 0",
"--beam-size 1 --softmax-temperature 0.01 --skip-topk",
True, False, False),
# "Kitchen sink" LSTM encoder-decoder with attention
("--encoder rnn --decoder rnn --num-layers 3:2 --rnn-cell-type lstm --rnn-num-hidden 8"
" --rnn-residual-connections"
3 changes: 2 additions & 1 deletion test/unit/test_arguments.py
@@ -217,7 +217,8 @@ def test_training_arg(test_params, expected_params):
length_penalty_alpha=1.0,
length_penalty_beta=0.0,
strip_unknown_words=False,
override_dtype=None)),
override_dtype=None,
skip_topk=False)),
])
def test_inference_args(test_params, expected_params):
_test_args(test_params, expected_params, arguments.add_inference_args)
