Add random sampling as alternative to beam search. #1157

Merged
merged 16 commits on Jan 7, 2019
Changes from 1 commit
fixed a few small errors
daphnei committed Jan 3, 2019
commit efb3b586e1318de71150892d9218aac9c63909ab
18 changes: 7 additions & 11 deletions onmt/translate/translator.py
@@ -290,15 +290,14 @@ def translate(
codecs.open(self.dump_beam, 'w', 'utf-8'))
return all_scores, all_predictions

def sample_with_temperature(self, logits, sampling_temp, restrict_to_topk):
def sample_with_temperature(self, logits, sampling_temp, keep_topk):
if sampling_temp == 0.0:
# To avoid divide-by-zero errors, just take the argmax.
topk_scores, topk_ids = logits.topk(1, dim=-1)
else:
logits = torch.div(logits, sampling_temp)

top_values, top_indices = torch.topk(
logits, restrict_to_topk, dim=1)
top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
kth_best = top_values[:, -1].view([-1, 1])
kth_best = kth_best.repeat([1, logits.shape[1]])
kth_best = kth_best.type(torch.cuda.FloatTensor)
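This hunk settles on `keep_topk` as the parameter name. For context, `sample_with_temperature` combines three standard decoding tricks: argmax when the temperature is 0 (avoiding the divide by zero), temperature scaling of the logits otherwise, and truncation to the k best tokens before drawing. A minimal self-contained sketch of that pattern, assuming `(batch_size, vocab_size)` logits; the mask value and the use of `torch.distributions.Categorical` for the draw are illustrative choices, not necessarily the PR's exact code:

```python
import torch

def sample_with_temperature(logits, sampling_temp, keep_topk):
    """Pick one token id per row of logits (batch_size x vocab_size)."""
    if sampling_temp == 0.0 or keep_topk == 1:
        # Temperature 0 degenerates to argmax; take it directly to
        # avoid dividing by zero.
        topk_scores, topk_ids = logits.topk(1, dim=-1)
    else:
        logits = logits / sampling_temp  # <1 sharpens, >1 flattens
        if keep_topk > 0:
            # Find the k-th best score per row and mask everything below it.
            top_values, _ = torch.topk(logits, keep_topk, dim=1)
            kth_best = top_values[:, -1].view(-1, 1)
            logits = logits.masked_fill(logits < kth_best, -10000.0)
        # Sample one id per row from the renormalized, truncated distribution.
        dist = torch.distributions.Categorical(logits=logits)
        topk_ids = dist.sample().view(-1, 1)
        topk_scores = logits.gather(dim=1, index=topk_ids)
    return topk_ids, topk_scores
```

With `sampling_temp=1.0` and `keep_topk=-1` this is plain ancestral sampling, and `sampling_temp=0.0` recovers greedy search.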
@@ -329,10 +328,8 @@ def _translate_random_sampling(
assert self.beam_size == 1

# TODO: support these blacklisted features.
assert not self.dump_beam
assert not self.use_filter_pred
assert self.block_ngram_repeat == 0
assert self.global_scorer.beta == 0

batch_size = batch.batch_size
vocab = self.fields["tgt"].vocab
@@ -347,7 +344,7 @@
use_src_map = data.data_type == 'text' and self.copy_attn

results = {}
results["predictions"] = [[] for _ in range(batch_size)] # noqa: F812
results["predictions"] = [[] * range(batch_size)] # noqa: F812
results["scores"] = [[] for _ in range(batch_size)] # noqa: F812
results["attention"] = [[] for _ in range(batch_size)] # noqa: F812
results["batch"] = batch
@@ -393,7 +390,7 @@
if step < min_length:
log_probs[:, end_token] = -1e20

# Note that what this code calls log_probs are actual logits.
# Note that what this code calls log_probs are actually logits.
topk_ids, topk_scores = self.sample_with_temperature(
log_probs, sampling_temp, keep_topk)
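The `log_probs[:, end_token] = -1e20` guard above is the usual way to enforce a minimum output length: the end-of-sequence score is pushed so low that neither argmax nor sampling can pick it. A tiny runnable illustration; the shapes and the EOS id here are made up:

```python
import torch

log_probs = torch.randn(2, 10)  # (batch_size, vocab_size) scores
end_token = 3                   # hypothetical EOS id
step, min_length = 1, 5

if step < min_length:
    log_probs[:, end_token] = -1e20  # effectively zero probability

# EOS can no longer win the argmax (nor, in practice, be sampled).
assert log_probs.argmax(dim=-1).ne(end_token).all()
```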

@@ -417,13 +414,12 @@
# there will only ever be 1 hypothesis per example.
score = topk_scores[i, 0]
pred = predictions[i, 0, 1:] # Ignore start_token.
mem_length = memory_lengths[i]
attn = (attention[:, i, 0, :mem_length]
if attention is not None else None)
m_len = memory_lengths[i]
attn = attention[:, i, 0, :m_len] if attention is not None else []

results["scores"][i].append(score)
results["predictions"][i].append(pred)
results["attention"][i].append(attn if attn is not None else [])
results["attention"][i].append(attn)

return results
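Because random sampling keeps exactly one hypothesis per example, the collection loop above always reads column 0 and trims each attention matrix to its source's true length. A toy version with stand-in tensors; all shapes are assumptions inferred from the indexing in the diff:

```python
import torch

batch_size, steps, src_len = 2, 4, 6
predictions = torch.randint(0, 10, (batch_size, 1, steps + 1))  # 1 hypothesis
topk_scores = torch.randn(batch_size, 1)
attention = torch.rand(steps, batch_size, 1, src_len)
memory_lengths = [6, 4]  # true (unpadded) source lengths

results = {name: [[] for _ in range(batch_size)]
           for name in ("scores", "predictions", "attention")}

for i in range(batch_size):
    score = topk_scores[i, 0]
    pred = predictions[i, 0, 1:]  # drop the start token
    m_len = memory_lengths[i]
    attn = attention[:, i, 0, :m_len] if attention is not None else []

    results["scores"][i].append(score)
    results["predictions"][i].append(pred)
    results["attention"][i].append(attn)
```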
