Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scoring #538

Merged
merged 60 commits into from
Sep 28, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
4fc041f
made validation source optional, parameterized fill_up
mjpost Sep 17, 2018
af1e7bd
moved training-specific checks
mjpost Sep 17, 2018
1322b31
added scoring up to generating outputs, almost finished
mjpost Sep 17, 2018
27dd1cb
works but not polished
mjpost Sep 18, 2018
7e8767a
added length penalty as command-line arguments
mjpost Sep 18, 2018
29ac7a4
added 'repeat_last' fillup strategy (wrong approach)
mjpost Sep 18, 2018
d468770
added zero fill_up strategy, no_permute on batch iterator
mjpost Sep 18, 2018
359bc60
print source sentence with generalized ids2tokens()
mjpost Sep 18, 2018
ebd4e10
fixed test cases
mjpost Sep 18, 2018
8994f5c
Merge branch 'master' into scoring
mjpost Sep 18, 2018
0079c71
style checks
mjpost Sep 18, 2018
210a123
moved training-specific check to training
mjpost Sep 18, 2018
5015d54
context
mjpost Sep 18, 2018
48d00cc
turned on bucketing
mjpost Sep 18, 2018
927b92b
turned off more data permuting
mjpost Sep 18, 2018
581805e
set batch size to reasonable default for fully-unrolled graph
mjpost Sep 19, 2018
aa417dd
pulled out bucketing args, removed options from scoring
mjpost Sep 19, 2018
28a5e4b
added --score-type and --output
mjpost Sep 19, 2018
ddd5d88
Merge branch 'master' into scoring
mjpost Sep 19, 2018
307e56a
documentation
mjpost Sep 19, 2018
98bb1f1
style check
mjpost Sep 19, 2018
c1f4035
versioning
mjpost Sep 19, 2018
6118643
fixed easy issues from @fhieber's CR
mjpost Sep 20, 2018
16864dd
Merge remote-tracking branch 'github/master' into scoring
mjpost Sep 20, 2018
6a73883
removed batch preparation
mjpost Sep 20, 2018
bd2a526
moved summing to the computation graph
mjpost Sep 20, 2018
717f18b
now using output_handler
mjpost Sep 20, 2018
fa1d9b8
added warning if sentences get skipped
mjpost Sep 20, 2018
47ddceb
typo
mjpost Sep 20, 2018
6f1c4e8
cleanup and docs
mjpost Sep 20, 2018
ff324fb
documentation and cleanup
mjpost Sep 21, 2018
8805a62
Simplified fill_up, improved comments
mjpost Sep 22, 2018
aef30ae
bugfix with length penalty alpha = 0
mjpost Sep 22, 2018
0ff7feb
bugfixes in names
mjpost Sep 22, 2018
39294ac
combined run_forward() and get_outputs()
mjpost Sep 22, 2018
bbd8c13
uncontroversial reversions
mjpost Sep 24, 2018
9f0a71c
Merge remote-tracking branch 'github/master' into scoring
mjpost Sep 25, 2018
ef6b9f3
stuffing args for scoring, no_permute renamed permute
mjpost Sep 25, 2018
9e1736d
added --output and --softmax-temperature
mjpost Sep 25, 2018
5269a9c
changed float width to 3
mjpost Sep 25, 2018
09c46c2
removed prepare data option, reverted train.py
mjpost Sep 25, 2018
fb601a3
added test cases, bugfix with source factors
mjpost Sep 25, 2018
e049fea
bugfix — get first item from group
mjpost Sep 25, 2018
5f950be
style fixes
mjpost Sep 25, 2018
3c9401d
style changes
mjpost Sep 25, 2018
7c48db9
missed one
mjpost Sep 25, 2018
68f943c
don't score when --skip-topk
mjpost Sep 25, 2018
7db9bef
debugging travis
mjpost Sep 25, 2018
d67488b
seq max seq len very large to pass test
mjpost Sep 25, 2018
d834a7d
reading maxlen from config and skipping some test lines
mjpost Sep 25, 2018
b731159
debugging output since still failing
mjpost Sep 25, 2018
8abcf1b
skipping test outputs with vocab symbols
mjpost Sep 26, 2018
3ae7d7e
debugging travis
mjpost Sep 26, 2018
f39ba62
more systematic testing for when to try to score
mjpost Sep 26, 2018
776fcbb
don't score when translate beam == 1 or length close to max
mjpost Sep 26, 2018
e9808d2
restored skipping topk
mjpost Sep 26, 2018
065abc1
updated documentation
mjpost Sep 26, 2018
b3dd468
entry point for sockeye.score
mjpost Sep 27, 2018
b23bb01
Merge branch 'master' into scoring
fhieber Sep 27, 2018
2da12bc
proper spacing
fhieber Sep 28, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
turned on bucketing
  • Loading branch information
mjpost committed Sep 18, 2018
commit 48d00ccd75a8057c360c310ab9ff1194ce442b77
7 changes: 5 additions & 2 deletions sockeye/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def check_arg_compatibility(args: argparse.Namespace):
def create_scoring_model(config: model.ModelConfig,
model_dir: str,
context: List[mx.Context],
score_iter: data_io.BaseParallelSampleIter) -> scoring.ScoringModel:
score_iter: data_io.BaseParallelSampleIter,
bucketing: bool = False) -> scoring.ScoringModel:
"""
Create a scoring model and load the parameters from disk if needed.

Expand All @@ -76,7 +77,8 @@ def create_scoring_model(config: model.ModelConfig,
model_dir=model_dir,
context=context,
provide_data=score_iter.provide_data,
default_bucket_key=score_iter.default_bucket_key)
default_bucket_key=score_iter.default_bucket_key,
bucketing=bucketing)

return scoring_model

Expand Down Expand Up @@ -132,6 +134,7 @@ def score(args: argparse.Namespace):
scoring_model = create_scoring_model(config=model_config,
model_dir=args.model,
context=context,
bucketing=not args.no_bucketing,
score_iter=score_iter)

scorer = scoring.Scorer(scoring_model, source_vocabs, target_vocab,
Expand Down
27 changes: 19 additions & 8 deletions sockeye/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ def __init__(self,
model_dir: str,
context: List[mx.context.Context],
provide_data: List[mx.io.DataDesc],
bucketing: bool,
default_bucket_key: Tuple[int, int]) -> None:
super().__init__(config)
self.context = context
self.bucketing = bucketing
self._initialize(provide_data, default_bucket_key)

params_fname = os.path.join(model_dir, C.PARAMS_BEST_NAME)
Expand Down Expand Up @@ -134,16 +136,24 @@ def sym_gen(seq_lens):
# return the outputs and the data names (we don't need the labels)
return outputs, data_names, None

logger.info("Using bucketing. Default max_seq_len=%s", default_bucket_key)
self.module = mx.mod.BucketingModule(sym_gen=sym_gen,
logger=logger,
default_bucket_key=default_bucket_key,
context=self.context)
if self.bucketing:
logger.info("Using bucketing. Default max_seq_len=%s", default_bucket_key)
self.module = mx.mod.BucketingModule(sym_gen=sym_gen,
logger=logger,
default_bucket_key=default_bucket_key,
context=self.context)
else:
symbol, _, __ = sym_gen(default_bucket_key)
self.module = mx.mod.Module(symbol=symbol,
data_names=data_names,
label_names=None,
logger=logger,
context=self.context)

self.module.bind(data_shapes=provide_data,
label_shapes=None,
for_training=False,
force_rebind=True,
force_rebind=False,
grad_req=None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think grad_req should be set to 'null', not None. This could save you memory :)
https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/module/executor_group.py#L182



Expand Down Expand Up @@ -207,9 +217,10 @@ def score(self,
score_iter):

for i, batch in enumerate(score_iter):
# data_io generates labels, too, which we don't need
label, batch.provide_label = batch.provide_label, None
# data_io generates labels, too, which aren't needed in the computation graph
batch.provide_label = None
labels = batch.label[0].as_in_context(self.model.context[0])
batch.label = None
self.model.prepare_batch(batch)
mjpost marked this conversation as resolved.
Show resolved Hide resolved
self.model.run_forward(batch)
outputs = self.model.get_outputs()
Expand Down