Skip to content

Commit

Permalink
do not overwrite pt vocab when preprocessing again (OpenNMT#1447)
Browse files Browse the repository at this point in the history
  • Loading branch information
francoishernandez authored and vince62s committed May 24, 2019
1 parent 7f1fc81 commit 1b3cc33
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,20 @@ def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
filter_pred = None

if corpus_type == "train":
if opt.src_vocab:
existing_fields = None
if opt.src_vocab != "":
try:
logger.info("Using existing vocabulary...")
src_vocab = torch.load(opt.src_vocab)
existing_fields = torch.load(opt.src_vocab)
except torch.serialization.pickle.UnpicklingError:
logger.info("Building vocab from text file...")

src_vocab, src_vocab_size = _load_vocab(
opt.src_vocab, "src", counters,
opt.src_words_min_frequency)
else:
src_vocab = None

if opt.tgt_vocab:
if opt.tgt_vocab != "":
tgt_vocab, tgt_vocab_size = _load_vocab(
opt.tgt_vocab, "tgt", counters,
opt.tgt_words_min_frequency)
Expand All @@ -93,7 +93,7 @@ def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
sort_key=inputters.str2sortkey[opt.data_type],
filter_pred=filter_pred
)
if corpus_type == "train":
if corpus_type == "train" and existing_fields is None:
for ex in dataset.examples:
for name, field in fields.items():
try:
Expand Down Expand Up @@ -129,12 +129,15 @@ def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
gc.collect()

if corpus_type == "train":
fields = _build_fields_vocab(
fields, counters, opt.data_type,
opt.share_vocab, opt.vocab_size_multiple,
opt.src_vocab_size, opt.src_words_min_frequency,
opt.tgt_vocab_size, opt.tgt_words_min_frequency)
vocab_path = opt.save_data + '.vocab.pt'
if existing_fields is None:
fields = _build_fields_vocab(
fields, counters, opt.data_type,
opt.share_vocab, opt.vocab_size_multiple,
opt.src_vocab_size, opt.src_words_min_frequency,
opt.tgt_vocab_size, opt.tgt_words_min_frequency)
else:
fields = existing_fields
torch.save(fields, vocab_path)


Expand Down

0 comments on commit 1b3cc33

Please sign in to comment.