Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scoring #538

Merged
merged 60 commits into from
Sep 28, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
4fc041f
made validation source optional, parameterized fill_up
mjpost Sep 17, 2018
af1e7bd
moved training-specific checks
mjpost Sep 17, 2018
1322b31
added scoring up to generating outputs, almost finished
mjpost Sep 17, 2018
27dd1cb
works but not polished
mjpost Sep 18, 2018
7e8767a
added length penalty as command-line arguments
mjpost Sep 18, 2018
29ac7a4
added 'repeat_last' fillup strategy (wrong approach)
mjpost Sep 18, 2018
d468770
added zero fill_up strategy, no_permute on batch iterator
mjpost Sep 18, 2018
359bc60
print source sentence with generalized ids2tokens()
mjpost Sep 18, 2018
ebd4e10
fixed test cases
mjpost Sep 18, 2018
8994f5c
Merge branch 'master' into scoring
mjpost Sep 18, 2018
0079c71
style checks
mjpost Sep 18, 2018
210a123
moved training-specific check to training
mjpost Sep 18, 2018
5015d54
context
mjpost Sep 18, 2018
48d00cc
turned on bucketing
mjpost Sep 18, 2018
927b92b
turned off more data permuting
mjpost Sep 18, 2018
581805e
set batch size to reasonable default for fully-unrolled graph
mjpost Sep 19, 2018
aa417dd
pulled out bucketing args, removed options from scoring
mjpost Sep 19, 2018
28a5e4b
added --score-type and --output
mjpost Sep 19, 2018
ddd5d88
Merge branch 'master' into scoring
mjpost Sep 19, 2018
307e56a
documentation
mjpost Sep 19, 2018
98bb1f1
style check
mjpost Sep 19, 2018
c1f4035
versioning
mjpost Sep 19, 2018
6118643
fixed easy issues from @fhieber's CR
mjpost Sep 20, 2018
16864dd
Merge remote-tracking branch 'github/master' into scoring
mjpost Sep 20, 2018
6a73883
removed batch preparation
mjpost Sep 20, 2018
bd2a526
moved summing to the computation graph
mjpost Sep 20, 2018
717f18b
now using output_handler
mjpost Sep 20, 2018
fa1d9b8
added warning if sentences get skipped
mjpost Sep 20, 2018
47ddceb
typo
mjpost Sep 20, 2018
6f1c4e8
cleanup and docs
mjpost Sep 20, 2018
ff324fb
documentation and cleanup
mjpost Sep 21, 2018
8805a62
Simplified fill_up, improved comments
mjpost Sep 22, 2018
aef30ae
bugfix with length penalty alpha = 0
mjpost Sep 22, 2018
0ff7feb
bugfixes in names
mjpost Sep 22, 2018
39294ac
combined run_forward() and get_outputs()
mjpost Sep 22, 2018
bbd8c13
uncontroversial reversions
mjpost Sep 24, 2018
9f0a71c
Merge remote-tracking branch 'github/master' into scoring
mjpost Sep 25, 2018
ef6b9f3
stuffing args for scoring, no_permute renamed permute
mjpost Sep 25, 2018
9e1736d
added --output and --softmax-temperature
mjpost Sep 25, 2018
5269a9c
changed float width to 3
mjpost Sep 25, 2018
09c46c2
removed prepare data option, reverted train.py
mjpost Sep 25, 2018
fb601a3
added test cases, bugfix with source factors
mjpost Sep 25, 2018
e049fea
bugfix — get first item from group
mjpost Sep 25, 2018
5f950be
style fixes
mjpost Sep 25, 2018
3c9401d
style changes
mjpost Sep 25, 2018
7c48db9
missed one
mjpost Sep 25, 2018
68f943c
don't score when --skip-topk
mjpost Sep 25, 2018
7db9bef
debugging travis
mjpost Sep 25, 2018
d67488b
seq max seq len very large to pass test
mjpost Sep 25, 2018
d834a7d
reading maxlen from config and skipping some test lines
mjpost Sep 25, 2018
b731159
debugging output since still failing
mjpost Sep 25, 2018
8abcf1b
skipping test outputs with vocab symbols
mjpost Sep 26, 2018
3ae7d7e
debugging travis
mjpost Sep 26, 2018
f39ba62
more systematic testing for when to try to score
mjpost Sep 26, 2018
776fcbb
don't score when translate beam == 1 or length close to max
mjpost Sep 26, 2018
e9808d2
restored skipping topk
mjpost Sep 26, 2018
065abc1
updated documentation
mjpost Sep 26, 2018
b3dd468
entry point for sockeye.score
mjpost Sep 27, 2018
b23bb01
Merge branch 'master' into scoring
fhieber Sep 27, 2018
2da12bc
proper spacing
fhieber Sep 28, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
pulled out bucketing args, removed options from scoring
  • Loading branch information
mjpost committed Sep 19, 2018
commit aa417ddd82a8e90d8525f4f723e1e23b6f51acb3
1 change: 0 additions & 1 deletion sockeye/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,6 @@ def add_score_cli_args(params):
add_training_data_args(params, required=False)
add_prepared_data_args(params)
add_vocab_args(params)
add_bucketing_args(params)
add_scoring_args(params)
add_device_args(params)
mjpost marked this conversation as resolved.
Show resolved Hide resolved
add_logging_args(params)
Expand Down
20 changes: 7 additions & 13 deletions sockeye/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def create_scoring_model(config: model.ModelConfig,
context=context,
provide_data=score_iter.provide_data,
default_bucket_key=score_iter.default_bucket_key,
bucketing=bucketing)
bucketing=False)

return scoring_model

Expand All @@ -97,13 +97,6 @@ def score(args: argparse.Namespace):

utils.log_basic_info(args)

max_seq_len_source, max_seq_len_target = args.max_seq_len
# The maximum length is the length before we add the BOS/EOS symbols
max_seq_len_source = max_seq_len_source + C.SPACE_FOR_XOS
max_seq_len_target = max_seq_len_target + C.SPACE_FOR_XOS
logger.info("Adjusting maximum length to reserve space for a BOS/EOS marker. New maximum length: (%d, %d)",
max_seq_len_source, max_seq_len_target)

with ExitStack() as exit_stack:
context = utils.determine_context(device_ids=args.device_ids,
use_cpu=args.use_cpu,
Expand All @@ -116,25 +109,26 @@ def score(args: argparse.Namespace):
"size that is a multiple of %d." % len(context))
logger.info("Scoring Device(s): %s", ", ".join(str(c) for c in context))

model_config = model.SockeyeModel.load_config(os.path.join(args.model, C.CONFIG_NAME))
max_seq_len_source = model_config.config_data.max_seq_len_source
max_seq_len_target = model_config.config_data.max_seq_len_target

score_iter, _, config_data, source_vocabs, target_vocab, data_info = train.create_data_iters_and_vocabs(
args=args,
mjpost marked this conversation as resolved.
Show resolved Hide resolved
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
shared_vocab=args.shared_vocab,
resume_training=True,
mjpost marked this conversation as resolved.
Show resolved Hide resolved
output_folder=args.model,
bucketing=False,
mjpost marked this conversation as resolved.
Show resolved Hide resolved
fill_up='zeros',
no_permute=True)

max_seq_len_source = config_data.max_seq_len_source
max_seq_len_target = config_data.max_seq_len_target

model_config = model.SockeyeModel.load_config(os.path.join(args.model, C.CONFIG_NAME))

mjpost marked this conversation as resolved.
Show resolved Hide resolved
scoring_model = create_scoring_model(config=model_config,
model_dir=args.model,
context=context,
bucketing=not args.no_bucketing,
bucketing=False,
score_iter=score_iter)

scorer = scoring.Scorer(scoring_model, source_vocabs, target_vocab,
Expand Down
8 changes: 6 additions & 2 deletions sockeye/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ def create_data_iters_and_vocabs(args: argparse.Namespace,
validation_sources: Optional[List[str]] = None,
mjpost marked this conversation as resolved.
Show resolved Hide resolved
validation_target: Optional[str] = None,
output_folder: Optional[str] = None,
bucketing: bool = True,
bucket_width: int = 10,
fill_up: str = C.DEFAULT_FILL_UP_STRATEGY,
no_permute: bool = False) -> Tuple['data_io.BaseParallelSampleIter',
'data_io.BaseParallelSampleIter',
Expand Down Expand Up @@ -349,8 +351,8 @@ def create_data_iters_and_vocabs(args: argparse.Namespace,
no_permute=no_permute,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
bucketing=not args.no_bucketing,
bucket_width=args.bucket_width)
bucketing=bucketing,
bucket_width=bucket_width)

return train_iter, validation_iter, config_data, source_vocabs, target_vocab, data_info

Expand Down Expand Up @@ -812,6 +814,8 @@ def train(args: argparse.Namespace):
validation_sources=[args.validation_source] + args.validation_source_factors,
validation_target=args.validation_target,
output_folder=output_folder,
bucketing=not args.no_bucketing,
bucket_width=args.bucket_width,
fill_up=args.fill_up)
max_seq_len_source = config_data.max_seq_len_source
max_seq_len_target = config_data.max_seq_len_target
Expand Down