
Fixed tensor2tensor language modeling decode #1282

Merged
merged 11 commits on Jan 11, 2019
86 changes: 79 additions & 7 deletions tensor2tensor/utils/decoding.py
@@ -407,13 +407,18 @@ def input_fn(params):
      return dataset
  else:
    def input_fn():
      input_gen = _decode_batch_input_fn(
          num_decode_batches, sorted_inputs,
          inputs_vocab, decode_hp.batch_size,
          decode_hp.max_input_size, task_id=decode_hp.multiproblem_task_id)
      gen_fn = make_input_fn_from_generator(input_gen)
      example = gen_fn()
      return _decode_input_tensor_to_features_dict(example, hparams)
      if has_input:
        input_gen = _decode_batch_input_fn(
            num_decode_batches, sorted_inputs,
            inputs_vocab, decode_hp.batch_size,
            decode_hp.max_input_size, task_id=decode_hp.multiproblem_task_id)
      else:
        # Problems without an "inputs" feature (e.g. language modeling) use
        # the no-padding generator defined below.
        input_gen = _decode_batch_input_fn_no_padding(
            sorted_inputs=sorted_inputs,
            max_batch_size=decode_hp.batch_size,
            vocabulary=inputs_vocab,
            max_input_size=decode_hp.max_input_size,
            decode_hp=decode_hp)
      gen_fn = make_input_fn_from_generator(input_gen)
      example = gen_fn()
      return _decode_input_tensor_to_features_dict(example, hparams)
  decodes = []
  result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)

@@ -643,6 +648,73 @@ def _decode_batch_input_fn(num_decode_batches, sorted_inputs, vocabulary,
"inputs": np.array(final_batch_inputs).astype(np.int32),
}

def _decode_batch_input_fn_no_padding(sorted_inputs, max_batch_size,
                                      vocabulary, max_input_size, decode_hp):
  """Generator to produce batches of same-length inputs (batch size varies)."""

  # First reverse all the input sentences so that if you're going to get OOMs,
  # you'll see it in the first batch.
  sorted_inputs.reverse()

  # Group consecutive inputs of the same (possibly truncated) token length.
  last_batch_length = None
  batch_lengths, batch_indices = [], []
  for batch_index, elm in enumerate(sorted_inputs):
    # Token count based on splitting on single spaces, capped at
    # max_input_size when a maximum is set.
    this_batch_length = len(elm.split(" "))
    if max_input_size > 0:
      this_batch_length = min(this_batch_length, max_input_size)
    if this_batch_length != last_batch_length:
      batch_lengths.append(this_batch_length)
      batch_indices.append(batch_index)
      last_batch_length = this_batch_length
  batch_indices.append(len(sorted_inputs))

  # Ensure no batch exceeds the maximum batch size.
  batch_sizes = np.diff(batch_indices)
  final_batch_sizes = []
  final_batch_lengths = []
  for ii, bs in enumerate(batch_sizes):
    if bs < max_batch_size:
      final_batch_sizes.append(bs)
      final_batch_lengths.append(batch_lengths[ii])
    else:
      full_batches = bs // max_batch_size
      partial_batch = bs % max_batch_size
      for _ in range(full_batches):
        final_batch_sizes.append(max_batch_size)
        final_batch_lengths.append(batch_lengths[ii])
      if partial_batch > 0:
        final_batch_sizes.append(partial_batch)
        final_batch_lengths.append(batch_lengths[ii])
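
  # Worked example (illustrative values, not from the original change): with
  # max_batch_size = 2 and reversed input lengths [5, 3, 3, 3], the grouping
  # loop above gives batch_lengths = [5, 3] and batch_indices = [0, 1, 4],
  # hence batch_sizes = [1, 3]. The splitting loop then produces
  # final_batch_sizes = [1, 2, 1] and final_batch_lengths = [5, 3, 3], so each
  # emitted batch holds only inputs of a single length.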

  # Continue with variable batch sizes; no padding is needed.
  last_index = 0
  for b, batch_size in enumerate(final_batch_sizes):
    tf.logging.info("Decoding batch %d" % b)
    # Every input in the batch has the same length (already capped at
    # max_input_size above); add one position for the <EOS> token.
    batch_length = final_batch_lengths[b] + 1
    batch_inputs = []
    for inputs in sorted_inputs[last_index:last_index + batch_size]:
      input_ids = vocabulary.encode(inputs)
      if max_input_size > 0:
        # For language modeling problems, the most recent inputs are often the
        # most important, so keep the tail of the sequence.
        input_ids = input_ids[-max_input_size:]
      # Unlike _decode_batch_input_fn, no padding or <EOS> id is appended here
      # (language modeling problems).
      batch_inputs.append(input_ids)
    last_index += batch_size

    final_batch_inputs = []
    # Ensure a consistent length within the batch.
    for in_ids in batch_inputs:
      assert len(in_ids) == batch_length
      final_batch_inputs.append(in_ids)

    yield {
        "inputs": np.array(final_batch_inputs).astype(np.int32),
    }


def _interactive_input_fn(hparams, decode_hp):
"""Generator that reads from the terminal and yields "interactive inputs".