
Skip softmax and topk #519

Merged
merged 8 commits into awslabs:master on Sep 7, 2018

Conversation

bricksdont
Contributor

@bricksdont bricksdont commented Aug 31, 2018

Skip softmax by default for greedy decoding, and add an option to skip topk:

  • this change only affects inference
  • only affects greedy decoding

Skipping softmax is significantly faster; skipping topk is currently significantly slower, but it could become faster if the MXNet implementation of argmin improves.
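The reason softmax can be dropped for greedy decoding is that it is strictly monotonic per row; a plain-NumPy sketch (illustrative only, not Sockeye code):

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Softmax is strictly increasing per row, so for beam size 1 the argmax
# over raw logits selects the same token as the argmax over probabilities,
# which is why the normalization can be skipped in the greedy path.
logits = np.array([[2.0, -1.0, 0.5],
                   [0.1, 3.2, -0.7]])
assert (logits.argmax(axis=-1) == softmax(logits).argmax(axis=-1)).all()
```

The same argument applies to any strictly increasing transform of the scores, so skipping it changes the score values but never the greedy pick.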

Pull Request Checklist

  • Changes are complete (if posting work-in-progress code, prefix your pull request title with '[WIP]' until you can check this box.)
  • Unit tests pass (pytest)
  • Were system tests modified? If so did you run these at least 5 times to account for the variation across runs?
  • System tests pass (pytest test/system)
  • Passed code style checking (./style-check.sh)
  • You have considered writing a test
  • Updated major/minor version in sockeye/__init__.py. Major version bump if this is a backwards incompatible change.
  • Updated CHANGELOG.md

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.

@bricksdont bricksdont changed the title [WIP] Skip softmax and topk Skip softmax and topk Sep 3, 2018
Contributor

@fhieber fhieber left a comment

Nice change, I do have 2 major comments on this:

  • I'm a little bit concerned about the user interface of this. There is a lot of silent configuration change happening for various cases (depending on beam size and whether the user loads multiple models (ensembling)). I would prefer being explicit about incompatible settings and fail early, for example, if the user wants to use ensembling but requests to skip softmax. Likewise for greedy decoding, I would expect an error if --skip-topk and --beam-size >1 are set.
  • We should go for the current best performance settings, that is, use topk for k=1 instead of argmin IF that is faster and add a comment or TODO in the code.

Get the single lowest element per sentence from a `scores` matrix. Expects that
beam size is 1, for greedy decoding.

NOTE(mathmu): The current implementation of argmin in MXNet is much slower than topk with k=1.
Contributor

In that case I think we should use topk and add a TODO to change to argmin once it's sped up.
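As a sanity check on the equivalence (not a benchmark), a plain-NumPy sketch of why the two kernels are interchangeable for greedy decoding, with np.argpartition standing in for a topk-style selection:

```python
import numpy as np

# Greedy decoding needs only the single lowest score per row. Both routes
# below pick the same index; they just use different kernels, which is why
# their relative speed (topk with k=1 vs. argmin) decides the default.
scores = np.array([[0.7, 0.1, 2.3],
                   [1.5, 0.2, 0.9]])

via_argmin = scores.argmin(axis=-1)
# argpartition with kth=0 places the index of the row minimum at position 0.
via_topk = np.argpartition(scores, 0, axis=-1)[:, 0]

assert (via_argmin == via_topk).all()
```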

Contributor Author

@bricksdont bricksdont Sep 4, 2018

The default behaviour is still to use topk for beam size 1; the top-1 path is opt-in via the CLI option --skip-topk.

Contributor

@fhieber fhieber Sep 6, 2018

but aren't you saying that this is slower than using mx.nd.topk? What's the point of this option if it's slower?

Edit: nevermind

" --rnn-attention-num-hidden 8 --batch-size 2 --loss cross-entropy --optimized-metric perplexity --max-updates 2"
" --checkpoint-frequency 2 --optimizer adam --initial-learning-rate 0.01 --batch-type sentence "
" --decode-and-evaluate 0",
"--beam-size 2 --softmax-temperature 0.01",
Contributor

I don't see the --skip-softmax here; also, using --softmax-temperature in this test is quite confusing then :)

Contributor Author

--skip-softmax is not an option anymore, but the default behaviour. Sorry about the confusion; it is something I changed with the third commit.

Contributor

you should update the comment then. Also: It's --beam-size 2 and therefore not greedy ;)

if self.beam_size == 1 and self.skip_softmax:
neg_probs = -probs[0]
else:
neg_probs = -mx.nd.log(probs[0]) # pylint: disable=invalid-unary-operand-type
Contributor

these are still negative log probs, no?

Contributor Author

If I understand correctly, if log is not applied they will not be in log space.

Contributor

sure, so neither name can be entirely correct :)

Contributor Author

Oh I see - yes you are right. Want me to change it?
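Naming aside, the greedy pick is the same either way, since log is strictly monotonic; a plain-NumPy sketch (not Sockeye code):

```python
import numpy as np

probs = np.array([[0.6, 0.3, 0.1],
                  [0.2, 0.5, 0.3]])

# log is strictly increasing, so minimizing -p and minimizing -log(p)
# select the same token per row; skipping the log only saves an
# elementwise op in the greedy path. Only the raw score values differ.
assert ((-probs).argmin(axis=-1) == (-np.log(probs)).argmin(axis=-1)).all()
```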

@@ -79,6 +83,8 @@ def __init__(self,
self.max_input_length, self.get_max_output_length = models_max_input_output_length([self],
max_output_length_num_stds,
forced_max_output_len=forced_max_output_len)
self.skip_softmax = skip_softmax
Contributor

you should probably add a check_condition on the beam_size

@@ -394,6 +406,13 @@ def load_models(context: mx.context.Context,
if checkpoints is None:
checkpoints = [None] * len(model_folders)

# skip softmax for a single model,
if len(model_folders) == 1:
skip_softmax = True
Contributor

probably makes sense to only do this for beam_size == 1 explicitly

forced_max_output_len: Optional[int] = None) -> None:
forced_max_output_len: Optional[int] = None,
skip_softmax: bool = False,
skip_topk: bool = False) -> None:
Contributor

skip_topk is only used in Translator but not in InferenceModel. Therefore it should not be a member of InferenceModel, but rather be passed directly to Translator.

" --rnn-attention-num-hidden 8 --batch-size 2 --loss cross-entropy --optimized-metric perplexity --max-updates 2"
" --checkpoint-frequency 2 --optimizer adam --initial-learning-rate 0.01 --batch-type sentence "
" --decode-and-evaluate 0",
"--beam-size 2 --softmax-temperature 0.01",
Contributor

you should update the comment then. Also: It's --beam-size 2 and therefore not greedy ;)

super().__init__(config)
self.params_fname = params_fname
self.context = context
self.beam_size = beam_size
utils.check_condition(beam_size < self.config.vocab_target_size,
'The beam size must be smaller than the target vocabulary size.')
if skip_softmax:
utils.check_condition(beam_size == 1, 'Skipping softmax does not have any effect for beam size > 1')
Contributor

as this will raise a RuntimeError, maybe: "Softmax skipping can only be enabled with beam size of 1."

Contributor

also, skip_softmax is actually no longer a user-facing parameter, so in that case I'd say make it an assertion.
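A minimal sketch of the fail-fast guard being discussed, assuming a helper along the lines of sockeye.utils.check_condition; the SockeyeError name is an assumption for illustration:

```python
class SockeyeError(Exception):
    """Stand-in error type; the actual exception class is assumed."""

def check_condition(condition: bool, error_message: str) -> None:
    # Fail fast with a descriptive message instead of silently
    # reconfiguring the decoder.
    if not condition:
        raise SockeyeError(error_message)

# The invariant under discussion: softmax skipping only makes sense
# for greedy decoding, i.e. a single model with beam size 1.
beam_size = 1
skip_softmax = True
if skip_softmax:
    check_condition(beam_size == 1,
                    "Softmax skipping can only be enabled with a beam size of 1.")
```

Failing early like this surfaces incompatible settings to the user instead of silently changing behaviour.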

if len(model_folders) == 1 and beam_size == 1:
skip_softmax = True
else:
# but not for an ensemble
Contributor

@tdomhan tdomhan Sep 5, 2018

maybe log whether softmax skipping is enabled or not

@@ -981,7 +1001,8 @@ def __init__(self,
restrict_lexicon: Optional[lexicon.TopKLexicon] = None,
avoid_list: Optional[str] = None,
store_beam: bool = False,
strip_unknown_words: bool = False) -> None:
strip_unknown_words: bool = False,
skip_topk: bool = False,) -> None:
Contributor

you probably want to remove that comma

@@ -1005,6 +1026,9 @@ def __init__(self,
self.interpolation_func = self._get_interpolation_func(ensemble_mode)
self.beam_size = self.models[0].beam_size
self.batch_size = self.models[0].batch_size
# skip softmax for a single model, but not for an ensemble
self.skip_softmax = True if len(models) == 1 else False
Contributor

let's decide on skip_softmax in just a single location, probably in load_models and reuse it after.

@@ -1352,9 +1380,14 @@ def _decode_step(self,
# Compute logits and softmax with restricted vocabulary
if self.restrict_lexicon:
logits = model.output_layer(decoder_outputs, out_w, out_b)
probs = mx.nd.softmax(logits)
if self.beam_size == 1 and self.skip_softmax:
Contributor

wouldn't it be enough to use skip_softmax, as it should only be true with beam_size 1?

Contributor

@tdomhan tdomhan left a comment

thanks for iterating on this!

@@ -1005,6 +1028,9 @@ def __init__(self,
self.interpolation_func = self._get_interpolation_func(ensemble_mode)
self.beam_size = self.models[0].beam_size
self.batch_size = self.models[0].batch_size
# skip softmax for a single model, but not for an ensemble
self.skip_softmax = self.models[0].skip_softmax
self.skip_topk = skip_topk
Contributor

maybe a final check_condition to cover the case of someone creating a Translator without load_models.

if skip_softmax:
  check_condition(len(self.models) == 1 and self.beam_size == 1, "...")

@@ -1377,10 +1412,13 @@ def _combine_predictions(self,

# combine model predictions and convert to neg log probs
if len(self.models) == 1:
neg_logprobs = -mx.nd.log(probs[0]) # pylint: disable=invalid-unary-operand-type
if self.beam_size == 1 and self.skip_softmax:
Contributor

is this check necessary?

Contributor

@tdomhan tdomhan left a comment

lgtm

else:
# but not for an ensemble or beam search
skip_softmax = False
logger.info("Disabled skipping softmax for several models or beam size larger than 1.")
Contributor

I think this warning can be confusing if a user does NOT set skip_softmax and runs with the default settings (beam_size=5). Maybe only print this warning when skip_softmax==True? Or remove it entirely

Get the single lowest element per sentence from a `scores` matrix. Expects that
beam size is 1, for greedy decoding.

NOTE(mathmu): The current implementation of argmin in MXNet is much slower than topk with k=1.
Contributor

@fhieber fhieber Sep 6, 2018

but aren't you saying that this is slower than using mx.nd.topk? What's the point of this option if it's slower?

Edit: nevermind

@fhieber fhieber merged commit 7e8fc97 into awslabs:master Sep 7, 2018
@bricksdont bricksdont mentioned this pull request Sep 26, 2018