Skip to content

Add NPU backend support for val and inference #2109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions timm/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,17 @@ def __init__(
)
else:
self.random_erasing = None
self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
self.is_npu = device.type == 'npu' and torch.npu.is_available()

def __iter__(self):
first = True
if self.is_cuda:
stream = torch.cuda.Stream()
stream_context = partial(torch.cuda.stream, stream=stream)
elif self.is_npu:
stream = torch.npu.Stream()
stream_context = partial(torch.npu.stream, stream=stream)
else:
stream = None
stream_context = suppress
Expand All @@ -139,7 +143,10 @@ def __iter__(self):
first = False

if stream is not None:
torch.cuda.current_stream().wait_stream(stream)
if self.is_cuda:
torch.cuda.current_stream().wait_stream(stream)
elif self.is_npu:
torch.npu.current_stream().wait_stream(stream)

input = next_input
target = next_target
Expand Down
3 changes: 3 additions & 0 deletions timm/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def init_distributed_device_so(
"xpu": "ccl",
"hpu": "hccl",
"cuda": "nccl",
"npu": "hccl",
}
dist_backend = dist_backends.get(device_type, 'gloo')
dist_url = dist_url or 'env://'
Expand Down Expand Up @@ -159,6 +160,8 @@ def init_distributed_device_so(

if device_type == 'cuda':
assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
if device_type == 'npu':
assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'

if distributed and device != 'cpu':
# Ignore manually specified device index in distributed mode and
Expand Down
9 changes: 7 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,8 +1054,11 @@ def _backward(_loss):
if model_ema is not None:
model_ema.update(model, step=num_updates)

if args.synchronize_step and device.type == 'cuda':
torch.cuda.synchronize()
if args.synchronize_step:
if device.type == 'cuda':
torch.cuda.synchronize()
elif device.type == 'npu':
torch.npu.synchronize()
time_now = time.time()
update_time_m.update(time.time() - update_start_time)
update_start_time = time_now
Expand Down Expand Up @@ -1155,6 +1158,8 @@ def validate(

if device.type == 'cuda':
torch.cuda.synchronize()
elif device.type == "npu":
torch.npu.synchronize()

losses_m.update(reduced_loss.item(), input.size(0))
top1_m.update(acc1.item(), output.size(0))
Expand Down
4 changes: 3 additions & 1 deletion validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,10 @@ def _try_run(args, initial_batch_size):
while batch_size:
args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
try:
if torch.cuda.is_available() and 'cuda' in args.device:
if 'cuda' in args.device and torch.cuda.is_available():
torch.cuda.empty_cache()
elif "npu" in args.device and torch.npu.is_available():
torch.npu.empty_cache()
results = validate(args)
return results
except RuntimeError as e:
Expand Down
Loading