diff --git a/train.py b/train.py index 8b01bef614..a923fbb380 100644 --- a/train.py +++ b/train.py @@ -120,6 +120,12 @@ def parse_args(): help='The option of train profiler. If profiler_options is not None, the train ' \ 'profiler is enabled. Refer to the paddleseg/utils/train_profiler.py for details.' ) + parser.add_argument( + '--device', + dest='device', + help='Device place, which can be GPU, XPU, CPU', + default='gpu', + type=str) return parser.parse_args() @@ -137,8 +143,12 @@ def main(args): ['-' * 48]) logger.info(info) - place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[ - 'GPUs used'] else 'cpu' + if args.device == 'gpu' and env_info['Paddle compiled with cuda'] and and env_info['GPUs used']: + place = 'gpu' + elif args.device == 'xpu' and paddle.is_compiled_with_xpu(): + place = 'xpu' + else: + place = 'cpu' paddle.set_device(place) if not args.cfg: