diff --git a/stanza/models/common/seq2seq_model.py b/stanza/models/common/seq2seq_model.py
index 78d799ae54..4e05e78a1a 100644
--- a/stanza/models/common/seq2seq_model.py
+++ b/stanza/models/common/seq2seq_model.py
@@ -163,8 +163,61 @@ def get_log_prob(self, logits):
             return log_probs
         return log_probs.view(logits.size(0), logits.size(1), logits.size(2))
 
+    def predict_greedy(self, src, src_mask, pos=None, beam_size=1):
+        """ Predict with greedy decoding. """
+        enc_inputs = self.embedding(src)
+        batch_size = enc_inputs.size(0)
+        if self.use_pos:
+            assert pos is not None, "Missing POS input for seq2seq lemmatizer."
+            pos_inputs = self.pos_drop(self.pos_embedding(pos))
+            enc_inputs = torch.cat([pos_inputs.unsqueeze(1), enc_inputs], dim=1)
+            pos_src_mask = src_mask.new_zeros([batch_size, 1])
+            src_mask = torch.cat([pos_src_mask, src_mask], dim=1)
+        src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
+
+        # encode source
+        h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
+        ## add pos-aware transformation to hn
+        #if self.use_pos:
+        #    assert pos is not None
+        #    pos_inputs = self.pos_embedding(pos)
+        #    hn = self.enc2dec(torch.cat([hn, pos_inputs], dim=1))
+
+        if self.edit:
+            edit_logits = self.edit_clf(hn)
+        else:
+            edit_logits = None
+
+        # greedy decode by step
+        dec_inputs = self.embedding(self.SOS_tensor)
+        dec_inputs = dec_inputs.expand(batch_size, dec_inputs.size(0), dec_inputs.size(1))
+
+        done = [False for _ in range(batch_size)]
+        total_done = 0
+        max_len = 0
+        output_seqs = [[] for _ in range(batch_size)]
+
+        while total_done < batch_size and max_len < self.max_dec_len:
+            log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask)
+            assert log_probs.size(1) == 1, "Output must have 1-step of output."
+            _, preds = log_probs.squeeze(1).max(1, keepdim=True)
+            dec_inputs = self.embedding(preds) # update decoder inputs
+            max_len += 1
+            for i in range(batch_size):
+                if not done[i]:
+                    token = preds.data[i][0].item()
+                    if token == constant.EOS_ID:
+                        done[i] = True
+                        total_done += 1
+                    else:
+                        output_seqs[i].append(token)
+        return output_seqs, edit_logits
+
     def predict(self, src, src_mask, pos=None, beam_size=5):
         """ Predict with beam search. """
+        if beam_size == 1:
+            return self.predict_greedy(src, src_mask, pos=pos)
+
         enc_inputs = self.embedding(src)
         batch_size = enc_inputs.size(0)
         if self.use_pos:
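
For reviewers, here is a minimal, self-contained sketch of the early-stopping bookkeeping that `predict_greedy` relies on: each batch element is decoded until it emits EOS or the length cap is reached, and finished rows are frozen while the rest of the batch keeps stepping. The `fake_step` helper and the `EOS_ID`/`MAX_DEC_LEN` constants below are hypothetical stand-ins for the model's decoder and `constant` module; they are not part of this patch.

```python
import torch

EOS_ID = 3        # assumption: stands in for constant.EOS_ID
MAX_DEC_LEN = 10  # assumption: stands in for self.max_dec_len

def fake_step(step, batch_size):
    """Hypothetical stand-in for one decoder step (self.decode + argmax).

    Emits token id 4 + step, and forces EOS once `step` passes a per-row
    cutoff so that batch elements finish at different times."""
    return torch.tensor([[EOS_ID if step >= i + 2 else 4 + step]
                         for i in range(batch_size)])

batch_size = 3
done = [False] * batch_size
total_done = 0
max_len = 0
output_seqs = [[] for _ in range(batch_size)]

# Same loop structure as predict_greedy: step until every row has emitted
# EOS or the length cap is hit; rows that are already done are skipped, so
# their outputs stay frozen while the rest of the batch keeps decoding.
while total_done < batch_size and max_len < MAX_DEC_LEN:
    preds = fake_step(max_len, batch_size)  # (batch_size, 1) greedy picks
    max_len += 1
    for i in range(batch_size):
        if not done[i]:
            token = preds[i][0].item()
            if token == EOS_ID:
                done[i] = True   # note: assignment, not `==`
                total_done += 1
            else:
                output_seqs[i].append(token)

print(output_seqs)  # [[4, 5], [4, 5, 6], [4, 5, 6, 7]]
```

Routing `beam_size == 1` through this path should produce the same output as beam search with a single beam (up to tie-breaking) while skipping the beam bookkeeping entirely, which is why `predict` short-circuits to `predict_greedy`.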