diff --git a/examples/machine_translation/seq2seq/predict.py b/examples/machine_translation/seq2seq/predict.py index 818dca16d5e91..0f724b66723fa 100644 --- a/examples/machine_translation/seq2seq/predict.py +++ b/examples/machine_translation/seq2seq/predict.py @@ -41,7 +41,7 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False): def do_predict(args): - device = paddle.set_device("gpu" if args.use_gpu else "cpu") + device = paddle.set_device(args.select_device) test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_infer_loader( args) diff --git a/examples/machine_translation/seq2seq/train.py b/examples/machine_translation/seq2seq/train.py index 834fcf19d51c8..b15ebadc33733 100644 --- a/examples/machine_translation/seq2seq/train.py +++ b/examples/machine_translation/seq2seq/train.py @@ -23,7 +23,7 @@ def do_train(args): - device = paddle.set_device("gpu" if args.use_gpu else "cpu") + device = paddle.set_device(args.select_device) # Define dataloader train_loader, eval_loader, src_vocab_size, tgt_vocab_size, eos_id = create_train_loader(