Skip to content

Commit

Permalink
fix bug in top_k
Browse files Browse the repository at this point in the history
  • Loading branch information
lipiji committed Apr 6, 2021
1 parent 15760ed commit 4771cad
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions polish.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def init_model(m_path, device, vocab):
lm_model.eval()
return lm_model, lm_vocab, lm_args

m_path = "./model/tmp.ckpt"
m_path = "./model/songci.ckpt"
lm_model, lm_vocab, lm_args = init_model(m_path, gpu, "./model/vocab.txt")


Expand Down Expand Up @@ -86,7 +86,7 @@ def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
next_tk = []
for i in range(len(s)):
ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())
if ctk != "<c1>":
if ctk != "<c1>" and ctk != "<c2>" and ctk != "<c0>":
next_tk.append(ctk)
continue
logits = probs[len(s[i]) - 1, i]
Expand Down
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
next_tk = []
for i in range(len(s)):
ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())
if ctk != "<c1>":
if ctk != "<c1>" and ctk != "<c2>" and ctk != "<c0>":
next_tk.append(ctk)
continue
logits = probs[len(s[i]) - 1, i]
Expand Down

0 comments on commit 4771cad

Please sign in to comment.