Speedup the lemmatizer by using batched greedy decoding by default
qipeng committed Apr 16, 2020
1 parent 4341e4b commit 00cfa46
Showing 1 changed file with 53 additions and 0 deletions.
stanza/models/common/seq2seq_model.py (53 additions, 0 deletions)
@@ -163,8 +163,61 @@ def get_log_prob(self, logits):
            return log_probs
        return log_probs.view(logits.size(0), logits.size(1), logits.size(2))

    def predict_greedy(self, src, src_mask, pos=None, beam_size=1):
        """ Predict with greedy decoding. """
        enc_inputs = self.embedding(src)
        batch_size = enc_inputs.size(0)
        if self.use_pos:
            assert pos is not None, "Missing POS input for seq2seq lemmatizer."
            pos_inputs = self.pos_drop(self.pos_embedding(pos))
            enc_inputs = torch.cat([pos_inputs.unsqueeze(1), enc_inputs], dim=1)
            pos_src_mask = src_mask.new_zeros([batch_size, 1])
            src_mask = torch.cat([pos_src_mask, src_mask], dim=1)
        src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))

        # encode source
        h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
        ## add pos-aware transformation to hn
        #if self.use_pos:
        #    assert pos is not None
        #    pos_inputs = self.pos_embedding(pos)
        #    hn = self.enc2dec(torch.cat([hn, pos_inputs], dim=1))

        if self.edit:
            edit_logits = self.edit_clf(hn)
        else:
            edit_logits = None

        # greedy decode by step
        dec_inputs = self.embedding(self.SOS_tensor)
        dec_inputs = dec_inputs.expand(batch_size, dec_inputs.size(0), dec_inputs.size(1))

        done = [False for _ in range(batch_size)]
        total_done = 0
        max_len = 0
        output_seqs = [[] for _ in range(batch_size)]

        while total_done < batch_size and max_len < self.max_dec_len:
            log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask)
            assert log_probs.size(1) == 1, "Output must have exactly one step."
            _, preds = log_probs.squeeze(1).max(1, keepdim=True)
            dec_inputs = self.embedding(preds)  # update decoder inputs
            max_len += 1
            for i in range(batch_size):
                if not done[i]:
                    token = preds.data[i][0]
                    if token == constant.EOS_ID:
                        done[i] = True
                        total_done += 1
                    else:
                        output_seqs[i].append(token)
        return output_seqs, edit_logits
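
The loop above decodes every example in the batch in lockstep: one argmax per step, with per-example done flags so finished sequences stop accumulating tokens while the rest keep going. Below is a minimal, self-contained sketch of the same bookkeeping; greedy_decode, step_fn, EOS_ID, VOCAB_SIZE, and MAX_DEC_LEN are hypothetical stand-ins for the model's decoder and stanza's constants, not actual stanza APIs.

import torch

EOS_ID = 3        # placeholder; stanza's real value lives in its constant module
VOCAB_SIZE = 10   # placeholder vocabulary size
MAX_DEC_LEN = 20  # placeholder length cap, analogous to self.max_dec_len

def greedy_decode(step_fn, batch_size):
    """Batched greedy decoding with per-example early stopping.

    step_fn(prev_tokens) -> log probs of shape [batch_size, vocab] for one step.
    """
    prev = torch.zeros(batch_size, dtype=torch.long)  # assume token 0 is SOS
    done = [False] * batch_size
    total_done = 0
    output_seqs = [[] for _ in range(batch_size)]
    steps = 0
    while total_done < batch_size and steps < MAX_DEC_LEN:
        log_probs = step_fn(prev)        # [batch_size, vocab]
        preds = log_probs.argmax(dim=1)  # one greedy pick per example
        prev = preds                     # feed predictions back as next inputs
        steps += 1
        for i in range(batch_size):
            if not done[i]:
                token = preds[i].item()
                if token == EOS_ID:
                    done[i] = True
                    total_done += 1
                else:
                    output_seqs[i].append(token)
    return output_seqs

# Toy step function: random scores, so each sequence ends whenever EOS happens
# to win the argmax.
torch.manual_seed(0)
print(greedy_decode(lambda prev: torch.randn(prev.size(0), VOCAB_SIZE), batch_size=4))

The whole batch advances together and the loop exits as soon as every sequence has emitted EOS or the length cap is hit, so a batch costs roughly as many decoder steps as its longest output; that, plus skipping beam bookkeeping entirely, is presumably where the speedup over per-example beam search comes from.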

    def predict(self, src, src_mask, pos=None, beam_size=5):
        """ Predict with beam search. """
        if beam_size == 1:
            return self.predict_greedy(src, src_mask, pos=pos)

        enc_inputs = self.embedding(src)
        batch_size = enc_inputs.size(0)
        if self.use_pos:
(rest of predict unchanged; diff truncated)
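
With this change, callers get the fast path simply by asking for a beam of one, which the commit makes the lemmatizer's default. A hypothetical usage sketch follows; model, src, src_mask, and pos are assumed to be a trained Seq2SeqModel and its prepared input tensors, not objects defined in this diff.

# `model`, `src`, `src_mask`, and `pos` are assumed to exist already.
output_seqs, edit_logits = model.predict(src, src_mask, pos=pos, beam_size=1)
# beam_size=1 routes through predict_greedy above; any larger beam falls
# through to the original beam-search path.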
