Skip softmax and topk #519
```python
@@ -252,6 +252,28 @@ def variance(self) -> float:
        return self._M2 / self._count


def top1(scores: mx.nd.NDArray,
         offset: mx.nd.NDArray) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, mx.nd.NDArray]:
    """
    Get the single lowest element per sentence from a `scores` matrix. Expects that
    beam size is 1, for greedy decoding.

    NOTE(mathmu): The current implementation of argmin in MXNet is much slower than topk with k=1.

    :param scores: Vocabulary scores for the next beam step. Shape: (batch_size * beam_size, target_vocabulary_size).
    :param offset: Array to add to the hypothesis indices for offsetting in batch decoding.
    :return: The row indices, column indices and values of the smallest items in the matrix.
    """
    best_word_indices = mx.nd.cast(mx.nd.argmin(scores, axis=1), dtype='int32')
    values = scores[mx.nd.arange(scores.shape[0], dtype='int32', ctx=scores.context), best_word_indices]

    values = values.reshape((-1, 1))

    # for top1, the best hyp indices are equal to the plain offset
    return offset, best_word_indices, values


def topk(scores: mx.nd.NDArray,
         k: int,
         batch_size: int,
```

Review comments on the `NOTE(mathmu)` line:

> In that case I think we should use `topk` then and add a TODO to change to `argmin` once it's sped up.

> The default behaviour is still to use `topk` for beam size 1; using `top1` is a CLI option:

> But aren't you saying that this is slower than using `mx.nd.topk`? What's the point of this option if it's slower? Edit: never mind.
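The row-wise argmin-and-gather that `top1` performs can be sketched in plain NumPy. This is an illustrative analogue (the function name `top1_numpy` and the sample data are made up for this sketch, not part of the PR); shapes follow the docstring above, with costs where lower is better:

```python
import numpy as np

def top1_numpy(scores: np.ndarray, offset: np.ndarray):
    """Greedy (beam size 1) selection: per-row minimum of a scores matrix.

    scores: (batch_size, vocab_size) array of costs (lower is better).
    offset: (batch_size,) array of hypothesis-index offsets for batch decoding.
    Returns (row indices, column indices, values), mirroring top1 above.
    """
    best_word_indices = np.argmin(scores, axis=1).astype(np.int32)
    # Gather each row's minimum value, then reshape to a column vector.
    values = scores[np.arange(scores.shape[0]), best_word_indices].reshape(-1, 1)
    # With beam size 1, the best hypothesis indices are just the plain offsets.
    return offset, best_word_indices, values

scores = np.array([[0.5, 0.1, 0.9],
                   [0.3, 0.7, 0.2]])
offset = np.arange(2, dtype=np.int32)
rows, cols, vals = top1_numpy(scores, offset)
# cols -> [1, 2]; vals -> [[0.1], [0.2]]; rows -> [0, 1]
```

The gather step is the same trick the MXNet code uses: index rows with `arange` and columns with the argmin result, avoiding a full sort.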
> You should probably add a `check_condition` on the `beam_size`.
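The guard the reviewer suggests could look like the following. The `check_condition` helper here is a minimal stand-in mirroring Sockeye's utility (which raises on a violated requirement); the exact placement and message are assumptions, not code from the PR:

```python
class SockeyeError(Exception):
    """Stand-in for Sockeye's error type in this sketch."""
    pass

def check_condition(condition: bool, error_message: str) -> None:
    """Minimal stand-in: raise if a required condition does not hold."""
    if not condition:
        raise SockeyeError(error_message)

# Hypothetical guard before dispatching to top1:
beam_size = 1
check_condition(beam_size == 1,
                "top1 requires beam_size 1, got %d" % beam_size)
```

A guard like this turns a silent wrong result (calling `top1` with a wider beam) into an immediate, descriptive error.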