Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 24, 2023
1 parent 85ae19b commit 4668463
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions benchmark/training/training_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
rename_profile_file,
timeit,
torch_profile,
xpu_profile
xpu_profile,
)

supported_sets = {
Expand All @@ -35,8 +35,9 @@

device_conditions = {
'cuda': (lambda: torch.cuda.is_available()),
'mps': (lambda: (hasattr(torch.backends, 'mps')
and torch.backends.mps.is_available())),
'mps':
(lambda:
(hasattr(torch.backends, 'mps') and torch.backends.mps.is_available())),
'xpu': (lambda: torch.xpu.is_available()),
}

Expand Down Expand Up @@ -231,7 +232,8 @@ def run(args: argparse.ArgumentParser):
lr=0.001)

if args.device == 'xpu':
model, optimizer = ipex.optimize(model, optimizer=optimizer)
model, optimizer = ipex.optimize(
model, optimizer=optimizer)

progress_bar = False if args.no_progress_bar else True
train = train_hetero if hetero else train_homo
Expand Down Expand Up @@ -283,7 +285,8 @@ def run(args: argparse.ArgumentParser):

if args.profile:
if args.device == 'xpu':
profile = xpu_profile(args.export_chrome_trace)
profile = xpu_profile(
args.export_chrome_trace)
else:
profile = torch_profile(
args.export_chrome_trace, csv_data,
Expand Down

0 comments on commit 4668463

Please sign in to comment.