diff --git a/onmt/train_single.py b/onmt/train_single.py index fb7ef8ee3e..e55c89955a 100755 --- a/onmt/train_single.py +++ b/onmt/train_single.py @@ -99,15 +99,14 @@ def main(opt, device_id): opt, device_id, model, fields, optim, model_saver=model_saver) train_iterables = [] - for train_id in opt.data_ids: - if train_id: + if len(opt.data_ids) > 1: + for train_id in opt.data_ids: shard_base = "train_" + train_id - else: - shard_base = "train" - iterable = build_dataset_iter(shard_base, fields, opt, multi=True) - train_iterables.append(iterable) - - train_iter = MultipleDatasetIterator(train_iterables, device_id, opt) + iterable = build_dataset_iter(shard_base, fields, opt, multi=True) + train_iterables.append(iterable) + train_iter = MultipleDatasetIterator(train_iterables, device_id, opt) + else: + train_iter = build_dataset_iter("train", fields, opt) valid_iter = build_dataset_iter( "valid", fields, opt, is_train=False)