Skip to content

Commit

Permalink
fix device handling MultipleDatasetIterator
Browse files Browse the repository at this point in the history
  • Loading branch information
francoishernandez committed May 3, 2019
1 parent 253130b commit 3a8c9fc
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion onmt/inputters/inputter.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ def __init__(self,
self.batch_size_fn = max_tok_len \
if opt.batch_type == "tokens" else None
self.batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1
self.device = device
self.device = "cuda" if device >= 0 else "cpu"
# Temporarily load one shard to retrieve sort_key for data_type
temp_dataset = torch.load(self.iterables[0]._paths[0])
self.sort_key = temp_dataset.sort_key
Expand Down

0 comments on commit 3a8c9fc

Please sign in to comment.