Commit 97760f2

da03 committed Mar 3, 2018
1 parent e625e1c
Showing 4 changed files with 20 additions and 4 deletions.
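In short: preprocessing gains -src_vocab and -tgt_vocab options, so pre-built source and target vocabulary files (one word per line) can be loaded instead of inducing the vocabularies from the training data; a test covering both options is added.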
onmt/io/IO.py: 2 additions & 2 deletions
@@ -255,7 +255,7 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
     src_vocab = None
     if len(src_vocab_path) > 0:
         src_vocab = set([])
-        print ('Loading source vocab from %s' % src_vocab_path)
+        print('Loading source vocab from %s' % src_vocab_path)
         assert os.path.exists(src_vocab_path), \
             'src vocab %s not found!' % src_vocab_path
         with open(src_vocab_path) as f:
@@ -266,7 +266,7 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
     tgt_vocab = None
     if len(tgt_vocab_path) > 0:
         tgt_vocab = set([])
-        print ('Loading target vocab from %s' % tgt_vocab_path)
+        print('Loading target vocab from %s' % tgt_vocab_path)
         assert os.path.exists(tgt_vocab_path), \
             'tgt vocab %s not found!' % tgt_vocab_path
         with open(tgt_vocab_path) as f:
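Both hunks end at the truncated with open(...) line, so the actual read loop is hidden behind the fold. As a rough sketch of what loading such a file involves (the helper name read_vocab_file is hypothetical, not from this commit; it assumes one token per line, as the -src_vocab help text states):

    import codecs

    def read_vocab_file(path):
        # Hypothetical sketch: collect one token per line into a set.
        vocab = set()
        with codecs.open(path, 'r', 'utf-8') as f:
            for line in f:
                word = line.strip()
                if word:
                    vocab.add(word)
        return vocab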
opts.py: 2 additions & 2 deletions
@@ -152,10 +152,10 @@ def preprocess_opts(parser):
     # Dictionary options, for text corpus

     group = parser.add_argument_group('Vocab')
-    group.add_argument('-src_vocab',
+    group.add_argument('-src_vocab', default="",
                        help="""Path to an existing source vocabulary. Format:
                        one word per line.""")
-    group.add_argument('-tgt_vocab',
+    group.add_argument('-tgt_vocab', default="",
                        help="""Path to an existing target vocabulary. Format:
                        one word per line.""")
     group.add_argument('-features_vocabs_prefix', type=str, default='',
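Because both options now default to the empty string, they stay inert unless set. A hypothetical preprocessing run that supplies both files might look like the following (all paths are illustrative):

    python preprocess.py \
        -train_src data/src-train.txt -train_tgt data/tgt-train.txt \
        -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt \
        -save_data data/demo \
        -src_vocab data/src.vocab.txt -tgt_vocab data/tgt.vocab.txt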
preprocess.py: 2 additions & 0 deletions
@@ -157,8 +157,10 @@ def build_save_dataset(corpus_type, fields, opt):
 def build_save_vocab(train_dataset, fields, opt):
     fields = onmt.io.build_vocab(train_dataset, fields, opt.data_type,
                                  opt.share_vocab,
+                                 opt.src_vocab,
                                  opt.src_vocab_size,
                                  opt.src_words_min_frequency,
+                                 opt.tgt_vocab,
                                  opt.tgt_vocab_size,
                                  opt.tgt_words_min_frequency)

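The call site above implies that onmt.io.build_vocab now receives each vocabulary path alongside the corresponding size and minimum-frequency limits. A sketch of the implied signature (inferred from this call and the IO.py hunk header, not copied verbatim):

    def build_vocab(train_dataset_files, fields, data_type, share_vocab,
                    src_vocab_path, src_vocab_size, src_words_min_frequency,
                    tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency):
        ...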
test/test_preprocess.py: 14 additions & 0 deletions
@@ -3,6 +3,7 @@
 import unittest
 import glob
 import os
+import codecs
 from collections import Counter

 import torchtext
@@ -39,6 +40,13 @@ def __init__(self, *args, **kwargs):
     def dataset_build(self, opt):
         fields = onmt.io.get_fields("text", 0, 0)

+        if hasattr(opt, 'src_vocab') and len(opt.src_vocab) > 0:
+            with codecs.open(opt.src_vocab, 'w', 'utf-8') as f:
+                f.write('a\nb\nc\nd\ne\nf\n')
+        if hasattr(opt, 'tgt_vocab') and len(opt.tgt_vocab) > 0:
+            with codecs.open(opt.tgt_vocab, 'w', 'utf-8') as f:
+                f.write('a\nb\nc\nd\ne\nf\n')
+
         train_data_files = preprocess.build_save_dataset('train', fields, opt)

         preprocess.build_save_vocab(train_data_files, fields, opt)
@@ -48,6 +56,10 @@ def dataset_build(self, opt):
         # Remove the generated *pt files.
         for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'):
             os.remove(pt)
+        if hasattr(opt, 'src_vocab') and os.path.exists(opt.src_vocab):
+            os.remove(opt.src_vocab)
+        if hasattr(opt, 'tgt_vocab') and os.path.exists(opt.tgt_vocab):
+            os.remove(opt.tgt_vocab)

     def test_merge_vocab(self):
         va = torchtext.vocab.Vocab(Counter('abbccc'))
@@ -109,6 +121,8 @@ def test_method(self):
      ('share_vocab', True)],
     [('dynamic_dict', True),
      ('max_shard_size', 500000)],
+    [('src_vocab', '/tmp/src_vocab.txt'),
+     ('tgt_vocab', '/tmp/tgt_vocab.txt')],
     ]

 for p in test_databuild:
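The new parameter set writes the two vocabulary files under /tmp, runs the dataset and vocab build with them, and removes them during cleanup. Assuming the repository's standard unittest layout, the case can presumably be run from the repository root with:

    python -m unittest test.test_preprocess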
