Skip softmax and topk #519
Conversation
Nice change. I have two major comments on this:
- I'm a little bit concerned about the user interface of this. There is a lot of silent configuration change happening in various cases (depending on beam size and whether the user loads multiple models for ensembling). I would prefer being explicit about incompatible settings and failing early, for example if the user wants to use ensembling but requests to skip softmax. Likewise for greedy decoding, I would expect an error if `--skip-topk` and `--beam-size > 1` are both set.
- We should go for the current best-performance settings, that is, use topk for k=1 instead of argmin IF that is faster, and add a comment or TODO in the code.
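For concreteness, a minimal sketch of the two greedy-selection paths being compared (not taken from this PR; it assumes a toy `scores` matrix of negative log probabilities, where lower is better):

import mxnet as mx

# One row per sentence, one column per target word; lower score is better.
scores = mx.nd.array([[2.3, 0.1, 4.2],
                      [1.7, 3.0, 0.4]])

# Greedy selection via topk with k=1 (currently the faster path in MXNet):
best_via_topk = mx.nd.topk(scores, axis=1, k=1, is_ascend=True)

# The same selection via argmin (reportedly slower at the time of this PR):
best_via_argmin = mx.nd.argmin(scores, axis=1)

Both return the index of the lowest-scoring word per sentence; the comment above asks for whichever path is faster, with a TODO noting the other.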
Get the single lowest element per sentence from a `scores` matrix. Expects that
beam size is 1, for greedy decoding.

NOTE(mathmu): The current implementation of argmin in MXNet is much slower than topk with k=1.
In that case I think we should use topk and add a TODO to change to argmin once it's sped up.
The default behaviour is still to use topk for beam size 1; using top1 is a CLI option: `--skip-topk`.
But aren't you saying that this is slower than using mx.nd.topk? What's the point of this option if it's slower?
Edit: never mind
" --rnn-attention-num-hidden 8 --batch-size 2 --loss cross-entropy --optimized-metric perplexity --max-updates 2" | ||
" --checkpoint-frequency 2 --optimizer adam --initial-learning-rate 0.01 --batch-type sentence " | ||
" --decode-and-evaluate 0", | ||
"--beam-size 2 --softmax-temperature 0.01", |
I don't see the `--skip-softmax` here; also, using `--softmax-temperature` in this test is quite confusing then :)
`--skip-softmax` is not an option anymore, but the default behaviour. Sorry about the confusion; it is something I changed in the third commit.
You should update the comment then. Also: it's `--beam-size 2` and therefore not greedy ;)
sockeye/inference.py
Outdated
if self.beam_size == 1 and self.skip_softmax:
    neg_probs = -probs[0]
else:
    neg_probs = -mx.nd.log(probs[0])  # pylint: disable=invalid-unary-operand-type
these are still negative log probs, no?
If I understand correctly, if log is not applied they will not be in log space.
sure, so neither name can be entirely correct :)
Oh I see - yes you are right. Want me to change it?
@@ -79,6 +83,8 @@ def __init__(self,
        self.max_input_length, self.get_max_output_length = models_max_input_output_length([self],
                                                                                           max_output_length_num_stds,
                                                                                           forced_max_output_len=forced_max_output_len)
        self.skip_softmax = skip_softmax
You should probably add a `check_condition` on the `beam_size`.
sockeye/inference.py
Outdated
@@ -394,6 +406,13 @@ def load_models(context: mx.context.Context,
    if checkpoints is None:
        checkpoints = [None] * len(model_folders)

    # skip softmax for a single model,
    if len(model_folders) == 1:
        skip_softmax = True
Probably makes sense to only do this for `beam_size == 1` explicitly.
sockeye/inference.py
Outdated
                 forced_max_output_len: Optional[int] = None) -> None:
                 forced_max_output_len: Optional[int] = None,
                 skip_softmax: bool = False,
                 skip_topk: bool = False) -> None:
`skip_topk` is only used in Translator but not in InferenceModel. Therefore it should not be a member of InferenceModel, but rather be passed directly to Translator.
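A hypothetical sketch of that separation (simplified signatures for illustration, not the actual ones from this PR):

class InferenceModel:
    # skip_topk is intentionally not stored here; the model never uses it.
    def __init__(self, skip_softmax: bool = False) -> None:
        self.skip_softmax = skip_softmax

class Translator:
    # skip_topk is only consumed at decoding time, so it is passed here directly.
    def __init__(self, skip_topk: bool = False) -> None:
        self.skip_topk = skip_topk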
" --rnn-attention-num-hidden 8 --batch-size 2 --loss cross-entropy --optimized-metric perplexity --max-updates 2" | ||
" --checkpoint-frequency 2 --optimizer adam --initial-learning-rate 0.01 --batch-type sentence " | ||
" --decode-and-evaluate 0", | ||
"--beam-size 2 --softmax-temperature 0.01", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you should update the comment then. Also: It's --beam-size 2
and therefore not greedy ;)
Force-pushed from 792581b to 73e9579.
sockeye/inference.py
Outdated
super().__init__(config)
self.params_fname = params_fname
self.context = context
self.beam_size = beam_size
utils.check_condition(beam_size < self.config.vocab_target_size,
                      'The beam size must be smaller than the target vocabulary size.')
if skip_softmax:
    utils.check_condition(beam_size == 1, 'Skipping softmax does not have any effect for beam size > 1')
As this will raise a RuntimeError, maybe: "Softmax skipping can only be enabled with a beam size of 1."
Also, skip_softmax is now no longer a user-facing parameter, so in that case I'd say make it an assertion.
sockeye/inference.py
Outdated
if len(model_folders) == 1 and beam_size == 1:
    skip_softmax = True
else:
    # but not for an ensemble
Maybe log whether softmax skipping is enabled or not.
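For instance, a minimal sketch (variable names reused from the load_models diffs above; the exact log wording is an assumption):

import logging

logger = logging.getLogger(__name__)
model_folders = ["model"]  # hypothetical single-model setup
beam_size = 1

if len(model_folders) == 1 and beam_size == 1:
    skip_softmax = True
    logger.info("Enabled skipping softmax for greedy decoding with a single model.")
else:
    skip_softmax = False
    logger.info("Softmax will not be skipped (ensembling or beam size > 1).")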
sockeye/inference.py
Outdated
@@ -981,7 +1001,8 @@ def __init__(self,
                 restrict_lexicon: Optional[lexicon.TopKLexicon] = None,
                 avoid_list: Optional[str] = None,
                 store_beam: bool = False,
                 strip_unknown_words: bool = False) -> None:
                 strip_unknown_words: bool = False,
                 skip_topk: bool = False,) -> None:
you probably want to remove that comma
sockeye/inference.py
Outdated
@@ -1005,6 +1026,9 @@ def __init__(self,
        self.interpolation_func = self._get_interpolation_func(ensemble_mode)
        self.beam_size = self.models[0].beam_size
        self.batch_size = self.models[0].batch_size
        # skip softmax for a single model, but not for an ensemble
        self.skip_softmax = True if len(models) == 1 else False
Let's decide on `skip_softmax` in just a single location, probably in `load_models`, and reuse it after.
sockeye/inference.py
Outdated
@@ -1352,9 +1380,14 @@ def _decode_step(self,
        # Compute logits and softmax with restricted vocabulary
        if self.restrict_lexicon:
            logits = model.output_layer(decoder_outputs, out_w, out_b)
            probs = mx.nd.softmax(logits)
            if self.beam_size == 1 and self.skip_softmax:
Wouldn't it be enough to use `skip_softmax`, as it should only be true with `beam_size` 1?
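In other words, the beam-size check would be redundant. A hypothetical helper illustrating the suggested simplification (not PR code):

import mxnet as mx

def output_probs(logits: mx.nd.NDArray, skip_softmax: bool) -> mx.nd.NDArray:
    # skip_softmax is only ever True for beam size 1, so no separate
    # beam_size check is needed here.
    if skip_softmax:
        return logits
    return mx.nd.softmax(logits)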
thanks for iterating on this!
sockeye/inference.py
Outdated
@@ -1005,6 +1028,9 @@ def __init__(self,
        self.interpolation_func = self._get_interpolation_func(ensemble_mode)
        self.beam_size = self.models[0].beam_size
        self.batch_size = self.models[0].batch_size
        # skip softmax for a single model, but not for an ensemble
        self.skip_softmax = self.models[0].skip_softmax
        self.skip_topk = skip_topk
Maybe add a final `check_condition` to cover the case of someone creating a Translator without `load_models`:

if skip_softmax:
    check_condition(len(self.models) == 1 and self.beam_size == 1, "...")
sockeye/inference.py
Outdated
@@ -1377,10 +1412,13 @@ def _combine_predictions(self,

        # combine model predictions and convert to neg log probs
        if len(self.models) == 1:
            neg_logprobs = -mx.nd.log(probs[0])  # pylint: disable=invalid-unary-operand-type
        if self.beam_size == 1 and self.skip_softmax:
is this check necessary?
lgtm
sockeye/inference.py
Outdated
else:
    # but not for an ensemble or beam search
    skip_softmax = False
    logger.info("Disabled skipping softmax for several models or beam size larger than 1.")
I think this warning can be confusing if a user does NOT set skip_softmax and runs with the default settings (beam_size=5). Maybe only print this warning when skip_softmax == True? Or remove it entirely.
Skip softmax by default for greedy decoding, and add an option to skip topk.
Skipping softmax is significantly faster; skipping topk is currently slower, but it could become faster if the MXNet implementation of argmin improves.
Pull Request Checklist
- Changes are complete (if posting work-in-progress code, prefix your pull request title with '[WIP]' until you can check this box).
- Unit tests pass (pytest).
- System tests pass (pytest test/system).
- Passed code style checking (./style-check.sh).
- Updated major/minor version in sockeye/__init__.py. Major version bump if this is a backwards incompatible change.

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.