Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sampling #586

Merged
merged 34 commits into from
Dec 14, 2018
Merged
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
bc1b9f5
implemented decoding by sampling
Oct 30, 2018
bdcacbf
Merge branch 'master' into sampling
mjpost Oct 30, 2018
189508b
Sampling rewrite
Nov 4, 2018
af3945b
Merge branch 'sampling-rewrite' into 'sampling'
Nov 4, 2018
583c57d
Merge branch 'master' into sampling
mjpost Nov 27, 2018
aadbf59
fixed sampling, simplified use of "inactive"
mjpost Nov 27, 2018
6ff5793
fixed util.topk() when MXNet is being used
mjpost Nov 27, 2018
b3778f3
added sampling from the top n vocab items
mjpost Nov 28, 2018
4c51720
removed print statement
mjpost Nov 28, 2018
45016f7
pulled out batch_indices
mjpost Nov 28, 2018
2ebea0e
Merge branch 'master' into sampling
mjpost Nov 28, 2018
eea8d95
bugfix in unravel
mjpost Nov 28, 2018
6d5484c
Merge branch 'sampling' into sampling_topn
mjpost Nov 28, 2018
a032e86
fixed test case
mjpost Nov 28, 2018
3d4e8b3
inverted conditional block to get rid of mypy error
mjpost Nov 28, 2018
192e9a4
removed DEFAULT_RANDOM_SEED
mjpost Nov 28, 2018
6d054f4
fixed negative constraints with sampling
Nov 28, 2018
a4ed533
Merge remote-tracking branch 'origin/sampling_topn' into sampling
mjpost Nov 28, 2018
d785bbf
only update target_dists when sampling
mjpost Nov 28, 2018
28e5986
Merge branch 'master' of github.com:awslabs/sockeye into sampling
mjpost Nov 29, 2018
97b3389
added documentation and incremented version
mjpost Nov 29, 2018
9c5070f
fixing code review items from @fhieber
mjpost Nov 29, 2018
8ac34f8
Merge branch 'master' into sampling
mjpost Nov 29, 2018
ae00ba9
removed stray comment
mjpost Nov 29, 2018
d18b830
Merge branch 'master' into sampling
mjpost Dec 3, 2018
535f02f
added sampling test cases
mjpost Dec 5, 2018
869a7a0
check for restrict_lexicon
mjpost Dec 5, 2018
4055a02
simplified conditional
mjpost Dec 5, 2018
fdea4fc
cleanup test case
mjpost Dec 6, 2018
a3cb97d
reverted pre-computation of `skip_softmax`
mjpost Dec 6, 2018
9f0d70d
Merge branch 'master' into sampling
mjpost Dec 10, 2018
ba128c8
got rid of asscalar(), added some comments
mjpost Dec 10, 2018
0b6fe53
Merge remote-tracking branch 'amazon/master' into sampling
mjpost Dec 12, 2018
5a0bf23
Merge remote-tracking branch 'amazon/master' into sampling
mjpost Dec 13, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
inverted conditional block to get rid of mypy error
  • Loading branch information
mjpost committed Nov 28, 2018
commit 3d4e8b3f371e5e3b91b035c380457af45d41f433
19 changes: 10 additions & 9 deletions sockeye/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,15 +1173,7 @@ def __init__(self,
# Vocabulary selection leads to different vocabulary sizes across requests. Hence, we cannot use a
# statically-shaped HybridBlock for the topk operation in this case; resorting to imperative topk
# function in this case.
if self.restrict_lexicon:
if self.skip_topk:
self._top = partial(utils.top1, offset=self.offset) # type: Callable
else:
self._top = partial(utils.topk,
k=self.beam_size,
offset=self.offset,
use_mxnet_topk=True) # type: Callable
else:
if not self.restrict_lexicon:
if self.skip_topk:
self._top = Top1(k=self.beam_size,
batch_size=self.batch_size) # type: mx.gluon.HybridBlock
Expand All @@ -1195,8 +1187,17 @@ def __init__(self,
batch_size=self.batch_size,
vocab_size=len(self.vocab_target)) # type: mx.gluon.HybridBlock


self._top.initialize(ctx=self.context)
self._top.hybridize(static_alloc=True, static_shape=True)
else:
mjpost marked this conversation as resolved.
Show resolved Hide resolved
if self.skip_topk:
self._top = partial(utils.top1, offset=self.offset) # type: Callable
else:
self._top = partial(utils.topk,
k=self.beam_size,
offset=self.offset,
use_mxnet_topk=True) # type: Callable

self._sort_by_index = SortByIndex()
self._sort_by_index.initialize(ctx=self.context)
Expand Down