Skip softmax and topk #519

Merged: 8 commits, Sep 7, 2018
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.18.53]
### Added
- Always skip softmax by default for greedy decoding, but only for single models.
- Added option `--skip-topk` for greedy decoding.

## [1.18.52]
### Fixed
- Fixed bug in constrained decoding to make sure best hypothesis satisfies all constraints.
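To make the conditions above concrete, here is a small illustrative sketch (the helper name and signature are invented for this note and do not exist in Sockeye): softmax is dropped automatically whenever exactly one model is loaded with beam size 1, while `--skip-topk` is an explicit flag that translate.py only accepts for that same single-model, greedy setting.

```python
def fast_greedy_flags(num_models: int, beam_size: int, skip_topk: bool) -> dict:
    """Illustrative only; this helper is not part of the Sockeye API."""
    # load_models() enables skip_softmax automatically for a single model decoded greedily.
    skip_softmax = (num_models == 1 and beam_size == 1)
    # translate.py rejects --skip-topk unless exactly one model and --beam-size 1 are used.
    assert not skip_topk or (num_models == 1 and beam_size == 1)
    return {"skip_softmax": skip_softmax, "skip_topk": skip_topk}

print(fast_greedy_flags(num_models=1, beam_size=1, skip_topk=True))
# {'skip_softmax': True, 'skip_topk': True}
```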
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.18.52'
__version__ = '1.18.53'
5 changes: 5 additions & 0 deletions sockeye/arguments.py
@@ -1148,6 +1148,11 @@ def add_inference_args(params):
' Default: %d without batching '
'and %d * batch_size with batching.' % (C.CHUNK_SIZE_NO_BATCHING,
C.CHUNK_SIZE_PER_BATCH_SEGMENT))
decode_params.add_argument('--skip-topk',
default=False,
action='store_true',
help='Use argmax instead of topk for greedy decoding (when --beam-size 1). '
'Default: %(default)s.')
decode_params.add_argument('--ensemble-mode',
type=str,
default='linear',
70 changes: 55 additions & 15 deletions sockeye/inference.py
@@ -55,6 +55,7 @@ class InferenceModel(model.SockeyeModel):
:param decoder_return_logit_inputs: Decoder returns inputs to logit computation instead of softmax over target
vocabulary. Used when logits/softmax are handled separately.
:param cache_output_layer_w_b: Cache weights and biases for logit computation.
:param skip_softmax: If True, does not compute softmax for greedy decoding.
"""

def __init__(self,
@@ -67,18 +68,22 @@ def __init__(self,
max_output_length_num_stds: int = C.DEFAULT_NUM_STD_MAX_OUTPUT_LENGTH,
decoder_return_logit_inputs: bool = False,
cache_output_layer_w_b: bool = False,
forced_max_output_len: Optional[int] = None) -> None:
forced_max_output_len: Optional[int] = None,
skip_softmax: bool = False) -> None:
super().__init__(config)
self.params_fname = params_fname
self.context = context
self.beam_size = beam_size
utils.check_condition(beam_size < self.config.vocab_target_size,
'The beam size must be smaller than the target vocabulary size.')
if skip_softmax:
assert beam_size == 1, 'Skipping softmax does not have any effect for beam size > 1'
self.batch_size = batch_size
self.softmax_temperature = softmax_temperature
self.max_input_length, self.get_max_output_length = models_max_input_output_length([self],
max_output_length_num_stds,
forced_max_output_len=forced_max_output_len)
self.skip_softmax = skip_softmax
Contributor: you should probably add a check_condition on the beam_size

self.encoder_module = None # type: Optional[mx.mod.BucketingModule]
self.encoder_default_bucket_key = None # type: Optional[int]
@@ -236,7 +241,11 @@ def sym_gen(bucket_key: Tuple[int, int]):
logits = self.output_layer(target_decoded)
if self.softmax_temperature is not None:
logits = logits / self.softmax_temperature
outputs = mx.sym.softmax(data=logits, name=C.SOFTMAX_NAME)
if self.skip_softmax:
# skip softmax for greedy decoding
outputs = logits
else:
outputs = mx.sym.softmax(data=logits, name=C.SOFTMAX_NAME)

data_names = [C.TARGET_NAME] + state_names
label_names = [] # type: List[str]
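A brief aside on why dropping the softmax is safe on this greedy path (a minimal numpy sketch with assumed shapes, not the MXNet symbols used above): softmax only rescales scores within each row, so the token chosen by argmax over raw logits is the same as the one chosen over normalized probabilities, and negated logits agree with negative log-probabilities on the argmin used later in the search.

```python
import numpy as np

# Minimal sketch over a (batch, vocab) score matrix; plain numpy, not Sockeye code.
logits = np.array([[2.0, -1.0, 0.5],
                   [0.1,  3.2, -0.7]])
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# Softmax is strictly monotonic per row, so the greedy winner is unchanged.
assert (logits.argmax(axis=1) == probs.argmax(axis=1)).all()
# -logits and -log(probs) differ only by a per-row constant, so the argmin also agrees.
assert ((-logits).argmin(axis=1) == (-np.log(probs)).argmin(axis=1)).all()
```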
@@ -394,6 +403,14 @@ def load_models(context: mx.context.Context,
if checkpoints is None:
checkpoints = [None] * len(model_folders)

# skip softmax for a single model,
if len(model_folders) == 1 and beam_size == 1:
skip_softmax = True
logger.info("Enabled skipping softmax for a single model and greedy decoding.")
else:
# but not for an ensemble or beam search
skip_softmax = False

for model_folder, checkpoint in zip(model_folders, checkpoints):
model_source_vocabs = vocab.load_source_vocabs(model_folder)
model_target_vocab = vocab.load_target_vocab(model_folder)
@@ -420,13 +437,16 @@
batch_size=batch_size,
softmax_temperature=softmax_temperature,
decoder_return_logit_inputs=decoder_return_logit_inputs,
cache_output_layer_w_b=cache_output_layer_w_b)
cache_output_layer_w_b=cache_output_layer_w_b,
skip_softmax=skip_softmax)
utils.check_condition(inference_model.num_source_factors == len(model_source_vocabs),
"Number of loaded source vocabularies (%d) does not match "
"number of source factors for model '%s' (%d)" % (len(model_source_vocabs), model_folder,
inference_model.num_source_factors))
models.append(inference_model)



utils.check_condition(vocab.are_identical(*target_vocabs), "Target vocabulary ids do not match")
first_model_vocabs = source_vocabs[0]
for fi in range(len(first_model_vocabs)):
@@ -966,6 +986,7 @@ class Translator:
:param avoid_list: Global list of phrases to exclude from the output.
:param store_beam: If True, store the beam search history and return it in the TranslatorOutput.
:param strip_unknown_words: If True, removes any <unk> symbols from outputs.
:param skip_topk: If True, uses argmax instead of topk for greedy decoding.
"""

def __init__(self,
@@ -981,7 +1002,8 @@ def __init__(self,
restrict_lexicon: Optional[lexicon.TopKLexicon] = None,
avoid_list: Optional[str] = None,
store_beam: bool = False,
strip_unknown_words: bool = False) -> None:
strip_unknown_words: bool = False,
skip_topk: bool = False) -> None:
self.context = context
self.length_penalty = length_penalty
self.beam_prune = beam_prune
@@ -1005,6 +1027,12 @@ def __init__(self,
self.interpolation_func = self._get_interpolation_func(ensemble_mode)
self.beam_size = self.models[0].beam_size
self.batch_size = self.models[0].batch_size
# skip softmax for a single model, but not for an ensemble
self.skip_softmax = self.models[0].skip_softmax
if self.skip_softmax:
utils.check_condition(len(self.models) == 1 and self.beam_size == 1, "Skipping softmax cannot be enabled for several models, or a beam size > 1.")

self.skip_topk = skip_topk
# after models are loaded we ensured that they agree on max_input_length, max_output_length and batch size
self._max_input_length = self.models[0].max_input_length
if bucket_source_width > 0:
@@ -1027,11 +1055,15 @@ def __init__(self,
self._update_scores.hybridize()

# topk function used in beam search
self._topk = partial(utils.topk,
k=self.beam_size,
batch_size=self.batch_size,
offset=self.offset,
use_mxnet_topk=self.context != mx.cpu()) # MXNet implementation is faster on GPUs
if self.skip_topk:
self._top = partial(utils.top1,
offset=self.offset)
else:
self._top = partial(utils.topk,
k=self.beam_size,
batch_size=self.batch_size,
offset=self.offset,
use_mxnet_topk=self.context != mx.cpu()) # MXNet implementation is faster on GPUs

self._sort_by_index = SortByIndex()
self._sort_by_index.initialize(ctx=self.context)
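The functools.partial wiring above gives both selectors the same one-argument call shape, so the decoding loop can simply call `self._top(scores)` either way. A toy standalone sketch of that pattern (numpy stand-ins with simplified signatures, not the real `utils.top1`/`utils.topk`):

```python
from functools import partial
import numpy as np

# Toy stand-ins; only the call shape matters here, not the exact return contract.
def pick_argmin(scores, offset):
    best = scores.argmin(axis=1)
    return offset, best, scores[np.arange(len(scores)), best]

def pick_topk(scores, k, offset):
    best = np.argsort(scores, axis=1)[:, :k]
    return offset, best, np.take_along_axis(scores, best, axis=1)

offset = np.arange(2)
_top = partial(pick_argmin, offset=offset)        # greedy path (--skip-topk)
# _top = partial(pick_topk, k=4, offset=offset)   # regular beam-search path

hyps, words, values = _top(np.array([[0.3, 0.1], [0.7, 0.2]]))
print(hyps, words, values)   # [0 1] [1 1] [0.1 0.2]
```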
@@ -1352,9 +1384,14 @@ def _decode_step(self,
# Compute logits and softmax with restricted vocabulary
if self.restrict_lexicon:
logits = model.output_layer(decoder_outputs, out_w, out_b)
probs = mx.nd.softmax(logits)
if self.skip_softmax:
# skip softmax for greedy decoding and single model
probs = logits
else:
probs = mx.nd.softmax(logits)
else:
# Otherwise decoder outputs are already target vocab probs
# Otherwise decoder outputs are already target vocab probs,
# or logits if beam size is 1
probs = decoder_outputs
model_probs.append(probs)
model_attention_probs.append(attention_probs)
@@ -1377,10 +1414,13 @@ def _combine_predictions(self,

# combine model predictions and convert to neg log probs
if len(self.models) == 1:
neg_logprobs = -mx.nd.log(probs[0]) # pylint: disable=invalid-unary-operand-type
if self.skip_softmax:
neg_probs = -probs[0]
else:
neg_probs = -mx.nd.log(probs[0]) # pylint: disable=invalid-unary-operand-type
else:
neg_logprobs = self.interpolation_func(probs)
return neg_logprobs, attention_prob_score
neg_probs = self.interpolation_func(probs)
return neg_probs, attention_prob_score

def _beam_search(self,
source: mx.nd.NDArray,
@@ -1525,7 +1565,7 @@ def _beam_search(self,

# (3) Get beam_size winning hypotheses for each sentence block separately. Only look as
# far as the active beam size for each sentence.
best_hyp_indices, best_word_indices, scores_accumulated = self._topk(scores)
best_hyp_indices, best_word_indices, scores_accumulated = self._top(scores)

# Constraints for constrained decoding are processed sentence by sentence
if any(raw_constraint_list):
7 changes: 6 additions & 1 deletion sockeye/translate.py
@@ -52,6 +52,10 @@ def run_translate(args: argparse.Namespace):
if args.checkpoints is not None:
check_condition(len(args.checkpoints) == len(args.models), "must provide checkpoints for each model")

if args.skip_topk:
check_condition(args.beam_size == 1, "--skip-topk has no effect if beam size is larger than 1")
check_condition(len(args.models) == 1, "--skip-topk has no effect for decoding with more than 1 model")

log_basic_info(args)

output_handler = get_output_handler(args.output_type,
@@ -102,7 +106,8 @@ def run_translate(args: argparse.Namespace):
restrict_lexicon=restrict_lexicon,
avoid_list=args.avoid_list,
store_beam=store_beam,
strip_unknown_words=args.strip_unknown_words)
strip_unknown_words=args.strip_unknown_words,
skip_topk=args.skip_topk)
read_and_translate(translator=translator,
output_handler=output_handler,
chunk_size=args.chunk_size,
22 changes: 22 additions & 0 deletions sockeye/utils.py
@@ -252,6 +252,28 @@ def variance(self) -> float:
return self._M2 / self._count


def top1(scores: mx.nd.NDArray,
offset: mx.nd.NDArray) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, mx.nd.NDArray]:
"""
Get the single lowest element per sentence from a `scores` matrix. Expects that
beam size is 1, for greedy decoding.

NOTE(mathmu): The current implementation of argmin in MXNet is much slower than topk with k=1.

Contributor: In that case I think we should use topk then and add a TODO to change to argmin once it's sped up.

Contributor Author (@bricksdont, Sep 4, 2018): The default behaviour is still to use topk for beam size 1; using top1 is a CLI option: --skip-topk.

Contributor (@fhieber, Sep 6, 2018): But aren't you saying that this is slower than using mx.nd.topk? What's the point of this option if it's slower? Edit: nevermind.

:param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size)
:param offset: Array to add to the hypothesis indices for offsetting in batch decoding.
:return: The row indices, column indices and values of the smallest items in the matrix.
"""
best_word_indices = mx.nd.cast(mx.nd.argmin(scores, axis=1), dtype='int32')
values = scores[mx.nd.arange(scores.shape[0], dtype='int32', ctx=scores.context), best_word_indices]

values = values.reshape((-1, 1))

# for top1, the best hyp indices are equal to the plain offset

return offset, best_word_indices, values


def topk(scores: mx.nd.NDArray,
k: int,
batch_size: int,
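For reference, a small numpy mock-up of the `top1` contract sketched above (assumed shapes; not the MXNet implementation): with one hypothesis per sentence, the best hypothesis indices are just the 0..batch_size-1 offset, the word indices come from a per-row argmin, and the values are the corresponding scores reshaped into a column.

```python
import numpy as np

# Numpy mock-up of top1 for a batch of 2 sentences and a vocabulary of 3.
scores = np.array([[0.9, 0.1, 2.3],
                   [1.5, 0.8, 0.2]])
offset = np.arange(scores.shape[0], dtype=np.int32)          # one hypothesis per sentence
best_word_indices = scores.argmin(axis=1).astype(np.int32)   # greedy pick per row
values = scores[offset, best_word_indices].reshape(-1, 1)    # lowest score per sentence

print(offset)              # [0 1]
print(best_word_indices)   # [1 2]
print(values.ravel())      # [0.1 0.2]
```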
8 changes: 8 additions & 0 deletions test/integration/test_seq_copy_int.py
@@ -33,6 +33,14 @@
" --decode-and-evaluate 0",
"--beam-size 2 --softmax-temperature 0.01",
True, False, False),
# "Vanilla" LSTM encoder-decoder with attention, greedy and skip topk
("--encoder rnn --decoder rnn --num-layers 1 --rnn-cell-type lstm --rnn-num-hidden 8 --num-embed 4 "
" --rnn-attention-type mlp"
" --rnn-attention-num-hidden 8 --batch-size 2 --loss cross-entropy --optimized-metric perplexity --max-updates 2"
" --checkpoint-frequency 2 --optimizer adam --initial-learning-rate 0.01 --batch-type sentence "
" --decode-and-evaluate 0",
"--beam-size 1 --softmax-temperature 0.01 --skip-topk",
True, False, False),
# "Kitchen sink" LSTM encoder-decoder with attention
("--encoder rnn --decoder rnn --num-layers 3:2 --rnn-cell-type lstm --rnn-num-hidden 8"
" --rnn-residual-connections"
3 changes: 2 additions & 1 deletion test/unit/test_arguments.py
@@ -217,7 +217,8 @@ def test_training_arg(test_params, expected_params):
length_penalty_alpha=1.0,
length_penalty_beta=0.0,
strip_unknown_words=False,
override_dtype=None)),
override_dtype=None,
skip_topk=False)),
])
def test_inference_args(test_params, expected_params):
_test_args(test_params, expected_params, arguments.add_inference_args)