-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Enhancement] Add xpu option, *test=kunlun #1815
Conversation
train.py
Outdated
@@ -120,6 +120,12 @@ def parse_args(): | |||
help='The option of train profiler. If profiler_options is not None, the train ' \ | |||
'profiler is enabled. Refer to the paddleseg/utils/train_profiler.py for details.' | |||
) | |||
parser.add_argument( | |||
'--use_xpu', | |||
dest='use_xpu', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use --device argument, support gpu/cpu/xpu.
train.py
Outdated
@@ -139,6 +145,8 @@ def main(args): | |||
|
|||
place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[ | |||
'GPUs used'] else 'cpu' | |||
if args.use_xpu and paddle.is_compiled_with_xpu(): | |||
place = 'xpu' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please replace with the the following code.
If args.device == 'gpu' and env_info['Paddle compiled with cuda'] and and env_info['GPUs used']:
place = 'gpu'
elif args.device == 'xpu' and paddle.is_compiled_with_xpu():
place = 'xpu'
else:
place = 'cpu'
b46821c
to
2a443fe
Compare
train.py
Outdated
'--device', | ||
dest='device', | ||
help='Device place, which can be GPU, XPU, CPU', | ||
default=False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
default='gpu' and type=str
train.py
Outdated
@@ -137,8 +143,12 @@ def main(args): | |||
['-' * 48]) | |||
logger.info(info) | |||
|
|||
place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[ | |||
'GPUs used'] else 'cpu' | |||
if args.device == 'gpu' and env_info['Paddle compiled with cuda'] and and env_info['GPUs used']: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First run pip install pre-commit
and then git commit xx
.
The length of the above line is too long.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
add use xpu option in config, *test=kunlun