Skip to content

Commit

Permalink
add use xpu option, *test=kunlun
Browse files Browse the repository at this point in the history
  • Loading branch information
ykkk2333 committed Mar 1, 2022
1 parent fec42fd commit 6d2bdc8
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def parse_args():
help='The option of train profiler. If profiler_options is not None, the train ' \
'profiler is enabled. Refer to the paddleseg/utils/train_profiler.py for details.'
)
parser.add_argument(
'--device',
dest='device',
help='Device place, which can be GPU, XPU, CPU',
default='gpu',
type=str)

return parser.parse_args()

Expand All @@ -137,8 +143,12 @@ def main(args):
['-' * 48])
logger.info(info)

place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
'GPUs used'] else 'cpu'
if args.device == 'gpu' and env_info['Paddle compiled with cuda'] and and env_info['GPUs used']:
place = 'gpu'
elif args.device == 'xpu' and paddle.is_compiled_with_xpu():
place = 'xpu'
else:
place = 'cpu'

paddle.set_device(place)
if not args.cfg:
Expand Down

0 comments on commit 6d2bdc8

Please sign in to comment.