Scoring #538

Merged
merged 60 commits into master from scoring
Sep 28, 2018

Changes from 1 commit

60 commits
4fc041f
made validation source optional, parameterized fill_up
mjpost Sep 17, 2018
af1e7bd
moved training-specific checks
mjpost Sep 17, 2018
1322b31
added scoring up to generating outputs, almost finished
mjpost Sep 17, 2018
27dd1cb
works but not polished
mjpost Sep 18, 2018
7e8767a
added length penalty as command-line arguments
mjpost Sep 18, 2018
29ac7a4
added 'repeat_last' fillup strategy (wrong approach)
mjpost Sep 18, 2018
d468770
added zero fill_up strategy, no_permute on batch iterator
mjpost Sep 18, 2018
359bc60
print source sentence with generalized ids2tokens()
mjpost Sep 18, 2018
ebd4e10
fixed test cases
mjpost Sep 18, 2018
8994f5c
Merge branch 'master' into scoring
mjpost Sep 18, 2018
0079c71
style checks
mjpost Sep 18, 2018
210a123
moved training-specific check to training
mjpost Sep 18, 2018
5015d54
context
mjpost Sep 18, 2018
48d00cc
turned on bucketing
mjpost Sep 18, 2018
927b92b
turned off more data permuting
mjpost Sep 18, 2018
581805e
set batch size to reasonable default for fully-unrolled graph
mjpost Sep 19, 2018
aa417dd
pulled out bucketing args, removed options from scoring
mjpost Sep 19, 2018
28a5e4b
added --score-type and --output
mjpost Sep 19, 2018
ddd5d88
Merge branch 'master' into scoring
mjpost Sep 19, 2018
307e56a
documentation
mjpost Sep 19, 2018
98bb1f1
style check
mjpost Sep 19, 2018
c1f4035
versioning
mjpost Sep 19, 2018
6118643
fixed easy issues from @fhieber's CR
mjpost Sep 20, 2018
16864dd
Merge remote-tracking branch 'github/master' into scoring
mjpost Sep 20, 2018
6a73883
removed batch preparation
mjpost Sep 20, 2018
bd2a526
moved summing to the computation graph
mjpost Sep 20, 2018
717f18b
now using output_handler
mjpost Sep 20, 2018
fa1d9b8
added warning if sentences get skipped
mjpost Sep 20, 2018
47ddceb
typo
mjpost Sep 20, 2018
6f1c4e8
cleanup and docs
mjpost Sep 20, 2018
ff324fb
documentation and cleanup
mjpost Sep 21, 2018
8805a62
Simplified fill_up, improved comments
mjpost Sep 22, 2018
aef30ae
bugfix with length penalty alpha = 0
mjpost Sep 22, 2018
0ff7feb
bugfixes in names
mjpost Sep 22, 2018
39294ac
combined run_forward() and get_outputs()
mjpost Sep 22, 2018
bbd8c13
uncontroversial reversions
mjpost Sep 24, 2018
9f0a71c
Merge remote-tracking branch 'github/master' into scoring
mjpost Sep 25, 2018
ef6b9f3
stuffing args for scoring, no_permute renamed permute
mjpost Sep 25, 2018
9e1736d
added --output and --softmax-temperature
mjpost Sep 25, 2018
5269a9c
changed float width to 3
mjpost Sep 25, 2018
09c46c2
removed prepare data option, reverted train.py
mjpost Sep 25, 2018
fb601a3
added test cases, bugfix with source factors
mjpost Sep 25, 2018
e049fea
bugfix — get first item from group
mjpost Sep 25, 2018
5f950be
style fixes
mjpost Sep 25, 2018
3c9401d
style changes
mjpost Sep 25, 2018
7c48db9
missed one
mjpost Sep 25, 2018
68f943c
don't score when --skip-topk
mjpost Sep 25, 2018
7db9bef
debugging travis
mjpost Sep 25, 2018
d67488b
seq max seq len very large to pass test
mjpost Sep 25, 2018
d834a7d
reading maxlen from config and skipping some test lines
mjpost Sep 25, 2018
b731159
debugging output since still failing
mjpost Sep 25, 2018
8abcf1b
skipping test outputs with vocab symbols
mjpost Sep 26, 2018
3ae7d7e
debugging travis
mjpost Sep 26, 2018
f39ba62
more systematic testing for when to try to score
mjpost Sep 26, 2018
776fcbb
don't score when translate beam == 1 or length close to max
mjpost Sep 26, 2018
e9808d2
restored skipping topk
mjpost Sep 26, 2018
065abc1
updated documentation
mjpost Sep 26, 2018
b3dd468
entry point for sockeye.score
mjpost Sep 27, 2018
b23bb01
Merge branch 'master' into scoring
fhieber Sep 27, 2018
2da12bc
proper spacing
fhieber Sep 28, 2018
documentation and cleanup
mjpost committed Sep 21, 2018
commit ff324fbdb324e4a1b8b931a9f19ee3b9b77b0b43
7 changes: 4 additions & 3 deletions sockeye/arguments.py
@@ -789,7 +789,8 @@ def add_training_args(params):
'Default: %(default)s.')
train_params.add_argument('--fill-up',
type=str,
default=C.DEFAULT_FILL_UP_STRATEGY,
default=C.FILL_UP_DEFAULT,
choices=C.FILL_UP_CHOICES,
help=argparse.SUPPRESS)

train_params.add_argument('--loss',
@@ -1100,14 +1101,14 @@ def add_score_cli_args(params):
params.add_argument('--length-penalty-alpha',
default=1.0,
type=float,
help='Alpha factor for the length penalty used in beam search: '
help='Alpha factor for the length penalty used in scoring: '
'(beta + len(Y))**alpha/(beta + 1)**alpha. A value of 0.0 will therefore turn off '
'length normalization. Default: %(default)s')

params.add_argument('--length-penalty-beta',
default=0.0,
type=float,
help='Beta factor for the length penalty used in beam search: '
help='Beta factor for the length penalty used in scoring: '
'(beta + len(Y))**alpha/(beta + 1)**alpha. Default: %(default)s')

params.add_argument('--output-type',
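Both help strings now say "scoring" rather than "beam search"; the formula itself is unchanged. As a quick reference, it reduces to a one-liner (a standalone sketch; the function name is illustrative, not part of this diff):

    def length_penalty(length: int, alpha: float = 1.0, beta: float = 0.0) -> float:
        """(beta + length)**alpha / (beta + 1)**alpha, as in the help text above."""
        return ((beta + length) ** alpha) / ((beta + 1) ** alpha)

    print(length_penalty(10))             # 10.0 with the defaults alpha=1.0, beta=0.0
    print(length_penalty(10, alpha=0.0))  # 1.0, i.e. length normalization is turned off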
6 changes: 5 additions & 1 deletion sockeye/constants.py
@@ -398,7 +398,11 @@
DATA_CONFIG = "data.config"
PREPARED_DATA_VERSION_FILE = "data.version"
PREPARED_DATA_VERSION = 2
DEFAULT_FILL_UP_STRATEGY = 'replicate'

FILL_UP_REPLICATE = 'replicate'
FILL_UP_ZEROS = 'zeros'
FILL_UP_DEFAULT = FILL_UP_REPLICATE
FILL_UP_CHOICES = [FILL_UP_REPLICATE, FILL_UP_ZEROS]

# reranking
RERANK_BLEU = "bleu"
40 changes: 22 additions & 18 deletions sockeye/data_io.py
@@ -769,7 +769,7 @@ def get_training_data_iters(sources: List[str],
bucketing: bool,
bucket_width: int,
no_permute: bool = False) -> Tuple['BaseParallelSampleIter',
'BaseParallelSampleIter',
Optional['BaseParallelSampleIter'],
'DataConfig', 'DataInfo']:
"""
Returns data iterators for training and validation data.
@@ -1027,18 +1027,18 @@ def ids2strids(ids: Iterable[int]) -> str:

def ids2tokens(token_ids: Iterable[int],
vocab_inv: Dict[int, str],
exclude_list: Set[int] = set()) -> List[str]:
exclude_set: Set[int] = set()) -> List[str]:
"""
Transforms a list of token IDs into a list of words, exluding any IDs in `exclude_list`.
Transforms a list of token IDs into a list of words, exluding any IDs in `exclude_set`.

:param token_ids: The list of token IDs.
:param vocab_inv: The inverse vocabulary.
:param exclude_list: The list of token IDs to exclude.
:param exclude_set: The list of token IDs to exclude.
:return: The list of words.
"""

tokens = [vocab_inv[token] for token in token_ids]
return [tok for token_id, tok in zip(token_ids, tokens) if token_id not in exclude_list]
return [tok for token_id, tok in zip(token_ids, tokens) if token_id not in exclude_set]


class SequenceReader(Iterable):
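A quick illustration of the renamed function as defined above (the toy vocabulary is made up for the example):

    vocab_inv = {0: '<pad>', 1: '<s>', 2: '</s>', 3: 'hello', 4: 'world'}

    # Pad and sentence-boundary IDs are filtered out; only real words remain.
    print(ids2tokens([1, 3, 4, 2, 0, 0], vocab_inv, exclude_set={0, 1, 2}))
    # ['hello', 'world']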
@@ -1283,16 +1283,19 @@ def fill_up(self,
# 'zeros' instead repeats the last element and then writes zeros over everything.
if num_samples % bucket_batch_size != 0:
rest = bucket_batch_size - num_samples % bucket_batch_size
if fill_up == 'replicate':
if fill_up == C.FILL_UP_REPLICATE:
logger.info("Filling bucket %s from size %d to %d by sampling with replacement",
bucket, num_samples, bucket_batch_size)
desired_indices_np = rs.randint(num_samples, size=rest)
desired_indices = mx.nd.array(desired_indices_np)

elif fill_up == 'zeros':
logger.info("Filling bucket %s from size %d to %d by repeating the last element %d %s",
bucket, num_samples, bucket_batch_size, rest, inflect('time', rest))
desired_indices_np = np.array([num_samples-1] * rest)
elif fill_up == C.FILL_UP_ZEROS:
logger.info("Filling bucket %s from size %d to %d with zeros",
bucket, num_samples, bucket_batch_size)
desired_indices_np = np.full((rest), num_samples - 1)
# data_source = [np.full((num_samples, source_len, num_factors), self.pad_id, dtype=self.dtype)
# for (source_len, target_len), num_samples in zip(self.buckets, num_samples_per_bucket)]

desired_indices = mx.nd.array(desired_indices_np)

else:
@@ -1306,9 +1309,9 @@ def fill_up(self,
label[bucket_idx] = mx.nd.concat(bucket_label, bucket_label.take(desired_indices), dim=0)

if fill_up == 'zeros':
source[bucket_idx][num_samples:,:,:] = C.PAD_ID
target[bucket_idx][num_samples:,:] = C.PAD_ID
label[bucket_idx][num_samples:,:] = C.PAD_ID
source[bucket_idx][num_samples:, :, :] = C.PAD_ID
target[bucket_idx][num_samples:, :] = C.PAD_ID
label[bucket_idx][num_samples:, :] = C.PAD_ID

return ParallelDataSet(source, target, label)
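The contrast between the two strategies in this hunk can be seen in a NumPy-only sketch (shapes and values are simplified stand-ins for the real buckets):

    import numpy as np

    PAD_ID = 0
    bucket_batch_size, num_samples = 8, 5      # bucket is three rows short
    data = np.arange(1, num_samples + 1)       # stand-in for five real sentences
    rest = bucket_batch_size - num_samples % bucket_batch_size

    # 'replicate': top up the bucket by sampling existing rows with replacement.
    rs = np.random.RandomState(1)
    replicated = np.concatenate([data, data[rs.randint(num_samples, size=rest)]])

    # 'zeros': repeat the last row, then overwrite the filler region with padding,
    # so downstream code can tell where the genuine data ends.
    zeros = np.concatenate([data, data[np.full(rest, num_samples - 1)]])
    zeros[num_samples:] = PAD_ID

    print(replicated)  # e.g. [1 2 3 4 5 4 4 1] (random rows repeated)
    print(zeros)       # [1 2 3 4 5 0 0 0]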

@@ -1391,6 +1394,8 @@ class MetaBaseParallelSampleIter(ABC):
class BaseParallelSampleIter(mx.io.DataIter):
"""
Base parallel sample iterator.

:param no_permute: Turn off random shuffling of parallel data.
"""
__metaclass__ = MetaBaseParallelSampleIter

@@ -1402,6 +1407,7 @@ def __init__(self,
target_data_name,
label_name,
num_factors: int = 1,
no_permute: bool = False,
dtype='float32') -> None:
super().__init__(batch_size=batch_size)

@@ -1412,6 +1418,7 @@ def __init__(self,
self.target_data_name = target_data_name
self.label_name = label_name
self.num_factors = num_factors
self.no_permute = no_permute
self.dtype = dtype

# "Staging area" that needs to fit any size batch we're using by total number of elements.
@@ -1478,12 +1485,11 @@ def __init__(self,
dtype='float32') -> None:
super().__init__(buckets=buckets, batch_size=batch_size, bucket_batch_sizes=bucket_batch_sizes,
source_data_name=source_data_name, target_data_name=target_data_name,
label_name=label_name, num_factors=num_factors, dtype=dtype)
label_name=label_name, num_factors=num_factors, no_permute=no_permute, dtype=dtype)
assert len(shards_fnames) > 0
self.shards_fnames = list(shards_fnames)
self.shard_index = -1
self.fill_up = fill_up
self.no_permute = no_permute

self.reset()

@@ -1572,7 +1578,7 @@ def __init__(self,
dtype='float32') -> None:
super().__init__(buckets=buckets, batch_size=batch_size, bucket_batch_sizes=bucket_batch_sizes,
source_data_name=source_data_name, target_data_name=target_data_name,
label_name=label_name, num_factors=num_factors, dtype=dtype)
label_name=label_name, num_factors=num_factors, no_permute=no_permute, dtype=dtype)

# create independent lists to be shuffled
self.data = ParallelDataSet(list(data.source), list(data.target), list(data.label))
@@ -1586,8 +1592,6 @@ def __init__(self,
self.data_permutations = [mx.nd.arange(0, max(1, self.data.source[i].shape[0]))
for i in range(len(self.data))]

self.no_permute = no_permute

self.reset()

def reset(self):
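The body of reset() is collapsed in this view; what the new flag arranges can be sketched as a free function (an assumption-level simplification of the per-bucket permutation logic, not the verbatim method):

    import numpy as np

    def epoch_order(num_samples: int, no_permute: bool) -> np.ndarray:
        """Visit order for one epoch."""
        if no_permute:
            # scoring: identity order keeps outputs aligned with the input file
            return np.arange(num_samples)
        # training: a fresh shuffle each epoch
        return np.random.permutation(num_samples)

    print(epoch_order(5, no_permute=True))  # [0 1 2 3 4]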
2 changes: 1 addition & 1 deletion sockeye/output_handler.py
@@ -147,7 +147,7 @@ def handle(self,

class PairWithScoreOutputHandler(OutputHandler):
"""
Output handler to write translation score along with sntence input and output (tab-delimited).
Output handler to write translation score along with sentence input and output (tab-delimited).

:param stream: Stream to write translations to (e.g., sys.stdout).
"""
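The format the docstring describes is simple; a minimal sketch of such a handler's write (the helper name and the three-decimal score width are assumptions, the latter following the "changed float width to 3" commit):

    import sys

    def write_pair_with_score(stream, score: float, source: str, translation: str) -> None:
        # one tab-delimited line per sentence: score, input, output
        stream.write("{:.3f}\t{}\t{}\n".format(score, source, translation))

    write_pair_with_score(sys.stdout, -4.712, "ein Test", "a test")
    # prints: "-4.712\tein Test\ta test"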
5 changes: 5 additions & 0 deletions sockeye/score.py
@@ -71,6 +71,11 @@ def score(args: argparse.Namespace):
else:
max_seq_len_source, max_seq_len_target = args.max_seq_len

# This call has a number of different parameters compared to training which reflect our need to get scores
# one-for-one and in order with the input data.
# Bucketing and permuting need to be turned off in order to preserve the ordering of sentences.
# The 'zeros' fill_up strategy fills underfilled buckets with zeros which can then be used to find the last item.
# Finally, 'resume_training' needs to be set to True because it causes the model to be loaded instead of initialized.
score_iter, _, config_data, source_vocabs, target_vocab, data_info = train.create_data_iters_and_vocabs(
args=args,
max_seq_len_source=max_seq_len_source,
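The remainder of the call is collapsed here. Going by the new comment and the signatures visible elsewhere in this PR, the flags it sets plausibly look like the following (a sketch, not the hidden code; argument order and any omitted parameters are guesses):

    score_iter, _, config_data, source_vocabs, target_vocab, data_info = train.create_data_iters_and_vocabs(
        args=args,
        max_seq_len_source=max_seq_len_source,
        max_seq_len_target=max_seq_len_target,
        resume_training=True,      # load the trained model rather than initializing a new one
        bucketing=False,           # a single bucket preserves sentence order
        fill_up=C.FILL_UP_ZEROS,   # zero rows mark where real data ends in the last batch
        no_permute=True)           # no shuffling, so scores stay aligned with inputs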
3 changes: 2 additions & 1 deletion sockeye/scoring.py
@@ -224,6 +224,7 @@ def score(self,
score_type: str,
output_handler: OutputHandler):

total_time = 0.
tic = time.time()
sentence_no = 0
for i, batch in enumerate(score_iter):
@@ -233,8 +234,8 @@
self.model.run_forward(batch)
scores, __ = self.model.get_outputs()

total_time = time.time() - tic
batch_time = time.time() - batch_tic
total_time += batch_time

for source, target, score in zip(batch.data[0], batch.data[1], scores):

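The fix replaces a total recomputed from one start timestamp with a running sum of per-batch times, so that only per-batch work is counted. The pattern in isolation (a generic sketch, not Sockeye code):

    import time

    def process(batch):           # stand-in for run_forward() + get_outputs()
        time.sleep(0.01)

    total_time = 0.0
    for batch in range(3):
        batch_tic = time.time()
        process(batch)
        batch_time = time.time() - batch_tic
        total_time += batch_time  # accumulate only time spent inside the model

    print("%.3f seconds in the model" % total_time)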
2 changes: 1 addition & 1 deletion sockeye/train.py
@@ -224,7 +224,7 @@ def create_data_iters_and_vocabs(args: argparse.Namespace,
output_folder: Optional[str] = None,
bucketing: bool = True,
bucket_width: int = 10,
fill_up: str = C.DEFAULT_FILL_UP_STRATEGY,
fill_up: str = C.FILL_UP_DEFAULT,
no_permute: bool = False) -> Tuple['data_io.BaseParallelSampleIter',
'data_io.BaseParallelSampleIter',
'data_io.DataConfig',
Expand Down