Skip to content

Commit

Permalink
Update generate_seq2seq.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hhou435 committed Mar 8, 2021
1 parent 9e0c00f commit 7a5193b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions scripts/generate_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def forward(self, src, seg, tgt):

infer_opts(parser)

parser.add_argument("--target", choices=["seq2seq","t5"], default="t5",
parser.add_argument("--target", choices=["seq2seq", "t5"], default="t5",
help="The training target of the pretraining model.")
parser.add_argument("--share_relative_position_embedding", action="store_true",
help="Add bias on output_layer for lm target.")
Expand Down Expand Up @@ -92,7 +92,7 @@ def forward(self, src, seg, tgt):

with open(args.test_path, mode="r", encoding="utf-8") as f:
line = f.readline().strip()
src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN]+args.tokenizer.tokenize(line)+[SEP_TOKEN])
src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(line) + [SEP_TOKEN])
seg = [1] * len(src)
tgt = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN])
beginning_length = len(src)
Expand All @@ -111,4 +111,4 @@ def forward(self, src, seg, tgt):

f.write(line + "\n")
generated_sentence = "".join(args.tgt_tokenizer.convert_ids_to_tokens([token_id.item() for token_id in tgt_tensor[0]]))
f.write(generated_sentence)
f.write(generated_sentence)

0 comments on commit 7a5193b

Please sign in to comment.