Skip to content

Commit

Permalink
Merge 6b963f0 into dc01545
Browse files Browse the repository at this point in the history
  • Loading branch information
wangjiangben-hw authored Oct 19, 2022
2 parents dc01545 + 6b963f0 commit 402f555
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/en/api/device.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ mmengine.device
get_device
get_max_cuda_memory
is_cuda_available
is_npu_available
is_mlu_available
is_mps_available
1 change: 1 addition & 0 deletions docs/zh_cn/api/device.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ mmengine.device
get_device
get_max_cuda_memory
is_cuda_available
is_npu_available
is_mlu_available
is_mps_available
4 changes: 2 additions & 2 deletions mmengine/device/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
is_mlu_available, is_mps_available)
is_mlu_available, is_mps_available, is_npu_available)

__all__ = [
'get_max_cuda_memory', 'get_device', 'is_cuda_available',
'is_mlu_available', 'is_mps_available'
'is_mlu_available', 'is_mps_available', 'is_npu_available'
]
15 changes: 13 additions & 2 deletions mmengine/device/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ def is_cuda_available() -> bool:
return torch.cuda.is_available()


def is_npu_available() -> bool:
"""Returns True if Ascend PyTorch and npu devices exist."""
try:
import torch_npu # noqa: F401
except Exception:
return False
return hasattr(torch, 'npu') and torch.npu.is_available()


def is_mlu_available() -> bool:
"""Returns True if Cambricon PyTorch and mlu devices exist."""
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
Expand All @@ -49,9 +58,11 @@ def get_device() -> str:
"""Returns the currently existing device type.
Returns:
str: cuda | mlu | mps | cpu.
str: cuda | npu | mlu | mps | cpu.
"""
if is_cuda_available():
if is_npu_available():
return 'npu'
elif is_cuda_available():
return 'cuda'
elif is_mlu_available():
return 'mlu'
Expand Down
12 changes: 9 additions & 3 deletions mmengine/dist/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
get_comm_device, cast_data_device)
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.device import is_npu_available


def _get_reduce_op(name: str) -> torch_dist.ReduceOp:
Expand Down Expand Up @@ -411,7 +412,11 @@ def _broadcast_object_list(object_list: List[Any],
group_backend = get_backend(group)
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
current_device = torch.device('cpu')
if is_nccl_backend:
is_hccl_backend = group_backend == 'hccl'
if is_hccl_backend:
current_device = torch.npu.current_device()
object_sizes_tensor = object_sizes_tensor.to(current_device)
elif is_nccl_backend:
# See note about using torch.cuda.current_device() here in
# docstring. We cannot simply use my_rank since rank == device is
# not necessarily true.
Expand All @@ -430,7 +435,7 @@ def _broadcast_object_list(object_list: List[Any],
dtype=torch.uint8,
)

if is_nccl_backend:
if is_nccl_backend or is_hccl_backend:
object_tensor = object_tensor.to(current_device)
torch_dist.broadcast(object_tensor, src=src, group=group)
# Deserialize objects using their stored sizes.
Expand Down Expand Up @@ -504,7 +509,8 @@ def broadcast_object_list(data: List[Any],
if group is None:
group = get_default_group()

if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
if digit_version(TORCH_VERSION) >= digit_version(
'1.8.0') and not is_npu_available():
torch_dist.broadcast_object_list(data, src, group)
else:
_broadcast_object_list(data, src, group)
Expand Down
15 changes: 13 additions & 2 deletions mmengine/dist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch import Tensor
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup
from mmengine.device import is_mlu_available
from mmengine.device import is_mlu_available, is_npu_available

from collections.abc import Iterable, Mapping

Expand Down Expand Up @@ -80,6 +80,14 @@ def _init_dist_pytorch(backend, **kwargs) -> None:
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
elif is_npu_available():
import torch_npu # noqa: F401
torch.npu.set_device(rank)
torch_dist.init_process_group(
backend='hccl',
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
else:
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
Expand Down Expand Up @@ -437,7 +445,10 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
torch.device: The device of backend.
"""
backend = get_backend(group)
if backend == torch_dist.Backend.NCCL:
if backend == 'hccl':
import torch_npu # noqa: F401
return torch.device('npu', torch.npu.current_device())
elif backend == torch_dist.Backend.NCCL:
return torch.device('cuda', torch.cuda.current_device())
elif backend == 'cncl':
import torch_mlu # noqa: F401
Expand Down
19 changes: 19 additions & 0 deletions mmengine/model/base_model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,25 @@ def cuda(
self._set_device(torch.device(device))
return super().cuda(device)

def npu(
self,
device: Union[int, str, torch.device, None] = None,
) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.npu`
additionally.
Returns:
nn.Module: The model itself.
Note:
This generation of NPU(Ascend910) does not support
the use of multiple cards in a single process,
so the index here needs to be consistent with the default device
"""
device = torch.npu.current_device()
self._set_device(device)
return super().npu()

def cpu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
additionally.
Expand Down
11 changes: 8 additions & 3 deletions mmengine/optim/optimizer/amp_optimizer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler

from mmengine.device import is_cuda_available, is_npu_available
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from .optimizer_wrapper import OptimWrapper

if is_npu_available():
from torch.npu.amp import GradScaler
else:
from torch.cuda.amp import GradScaler


@OPTIM_WRAPPERS.register_module()
class AmpOptimWrapper(OptimWrapper):
Expand Down Expand Up @@ -44,8 +49,8 @@ class AmpOptimWrapper(OptimWrapper):
def __init__(self, loss_scale='dynamic', **kwargs):
assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
'`torch.cuda.amp` is only available when pytorch version >= 1.6')
assert torch.cuda.is_available(), (
'``AmpOptimizerWrapper`` is only available training on gpu')
assert is_cuda_available() or is_npu_available(), (
'``AmpOptimizerWrapper`` is only available training on gpu or npu')
super().__init__(**kwargs)
self._scale_update_param = None
if loss_scale == 'dynamic':
Expand Down
8 changes: 8 additions & 0 deletions mmengine/optim/optimizer/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch.nn as nn

from mmengine.config import Config, ConfigDict
from mmengine.device import is_npu_available
from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
from .optimizer_wrapper import OptimWrapper

Expand Down Expand Up @@ -53,6 +54,13 @@ def build_optim_wrapper(model: nn.Module,
constructor_type = optim_wrapper_cfg.pop('constructor',
'DefaultOptimWrapperConstructor')
paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)

# Since the current generation of NPU(Ascend 910) only supports
# mixed precision training, here we turn on mixed precision by default
# on the NPU to make the training normal
if is_npu_available():
optim_wrapper_cfg['type'] = 'AmpOptimWrapper'

optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
dict(
type=constructor_type,
Expand Down
7 changes: 5 additions & 2 deletions mmengine/runner/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from mmengine.device import get_device
from mmengine.device import get_device, is_cuda_available, is_npu_available
from mmengine.logging import print_log
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
Expand Down Expand Up @@ -86,7 +86,10 @@ def autocast(device_type: Optional[str] = None,
logger='current',
level=logging.WARNING)

if torch.cuda.is_available():
if is_npu_available():
with torch.npu.amp.autocast(enabled=enabled):
yield
elif is_cuda_available():
with torch.cuda.amp.autocast(enabled=enabled):
yield
else:
Expand Down
6 changes: 4 additions & 2 deletions tests/test_device/test_device.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
is_mps_available)
is_mps_available, is_npu_available)


def test_get_device():
device = get_device()
if is_cuda_available():
if is_npu_available():
assert device == 'npu'
elif is_cuda_available():
assert device == 'cuda'
elif is_mlu_available():
assert device == 'mlu'
Expand Down

0 comments on commit 402f555

Please sign in to comment.