[Feature] Support engine with NPU backend. #572

Merged · 20 commits · Oct 24, 2022
3 changes: 2 additions & 1 deletion .github/workflows/merge_stage_test.yml
@@ -159,7 +159,8 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: pip install pip --upgrade
# Windows CI could fail if we call `pip install pip --upgrade` directly.
run: python -m pip install pip --upgrade
- name: Install PyTorch
run: pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
- name: Build MMEngine from source
3 changes: 2 additions & 1 deletion .github/workflows/pr_stage_test.yml
@@ -108,7 +108,8 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: pip install pip --upgrade
# Windows CI could fail if we call `pip install pip --upgrade` directly.
run: python -m pip install pip --upgrade
- name: Install PyTorch
run: pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
- name: Build MMEngine from source
1 change: 1 addition & 0 deletions docs/en/api/device.rst
@@ -13,5 +13,6 @@ mmengine.device
get_device
get_max_cuda_memory
is_cuda_available
is_npu_available
is_mlu_available
is_mps_available
1 change: 1 addition & 0 deletions docs/zh_cn/api/device.rst
@@ -13,5 +13,6 @@ mmengine.device
get_device
get_max_cuda_memory
is_cuda_available
is_npu_available
is_mlu_available
is_mps_available
4 changes: 2 additions & 2 deletions mmengine/device/__init__.py
@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
is_mlu_available, is_mps_available)
is_mlu_available, is_mps_available, is_npu_available)

__all__ = [
'get_max_cuda_memory', 'get_device', 'is_cuda_available',
'is_mlu_available', 'is_mps_available'
'is_mlu_available', 'is_mps_available', 'is_npu_available'
]
15 changes: 13 additions & 2 deletions mmengine/device/utils.py
@@ -32,6 +32,15 @@ def is_cuda_available() -> bool:
return torch.cuda.is_available()


def is_npu_available() -> bool:
"""Returns True if Ascend PyTorch and npu devices exist."""
try:
import torch_npu # noqa: F401
except Exception:
return False
return hasattr(torch, 'npu') and torch.npu.is_available()


def is_mlu_available() -> bool:
"""Returns True if Cambricon PyTorch and mlu devices exist."""
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
@@ -49,9 +58,11 @@ def get_device() -> str:
"""Returns the currently existing device type.

Returns:
str: cuda | mlu | mps | cpu.
str: cuda | npu | mlu | mps | cpu.
"""
if is_cuda_available():
if is_npu_available():
return 'npu'
elif is_cuda_available():
return 'cuda'
elif is_mlu_available():
return 'mlu'
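For context, a minimal sketch of how these helpers are typically consumed downstream (the module and batch below are illustrative; only `get_device` and `is_npu_available` come from this change):

```python
import torch

from mmengine.device import get_device, is_npu_available

# get_device() now prefers 'npu' when torch_npu is importable and an
# Ascend device is visible, before falling back to cuda/mlu/mps/cpu.
device = get_device()
model = torch.nn.Linear(8, 2).to(device)   # illustrative module
inputs = torch.randn(4, 8).to(device)      # illustrative batch
outputs = model(inputs)

if is_npu_available():
    # Downstream code can key NPU-specific behaviour off this helper.
    print('running on an Ascend NPU')
```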
12 changes: 9 additions & 3 deletions mmengine/dist/dist.py
@@ -20,6 +20,7 @@
get_comm_device, cast_data_device)
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.device import is_npu_available


def _get_reduce_op(name: str) -> torch_dist.ReduceOp:
@@ -411,7 +412,11 @@ def _broadcast_object_list(object_list: List[Any],
group_backend = get_backend(group)
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
current_device = torch.device('cpu')
if is_nccl_backend:
is_hccl_backend = group_backend == 'hccl'
if is_hccl_backend:
current_device = torch.npu.current_device()
object_sizes_tensor = object_sizes_tensor.to(current_device)
elif is_nccl_backend:
# See note about using torch.cuda.current_device() here in
# docstring. We cannot simply use my_rank since rank == device is
# not necessarily true.
@@ -430,7 +435,7 @@ def _broadcast_object_list(object_list: List[Any],
dtype=torch.uint8,
)

if is_nccl_backend:
if is_nccl_backend or is_hccl_backend:
object_tensor = object_tensor.to(current_device)
torch_dist.broadcast(object_tensor, src=src, group=group)
# Deserialize objects using their stored sizes.
@@ -504,7 +509,8 @@ def broadcast_object_list(data: List[Any],
if group is None:
group = get_default_group()

if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
if digit_version(TORCH_VERSION) >= digit_version(
'1.8.0') and not is_npu_available():
torch_dist.broadcast_object_list(data, src, group)
else:
_broadcast_object_list(data, src, group)
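For context, a hedged sketch of the public API these hunks affect: under an `hccl` process group, `broadcast_object_list` now takes the `_broadcast_object_list` fallback above and stages the serialized tensors on the current NPU. The payload below is illustrative.

```python
from mmengine.dist import broadcast_object_list, get_rank

# Every rank passes a list of the same length; rank 0's objects win.
if get_rank() == 0:
    data = [{'epoch': 3, 'iter': 12000}, 'meta info']  # illustrative payload
else:
    data = [None, None]

broadcast_object_list(data, src=0)
# After the call every rank sees rank 0's objects in `data`.
```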
15 changes: 13 additions & 2 deletions mmengine/dist/utils.py
@@ -10,7 +10,7 @@
from torch import Tensor
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup
from mmengine.device import is_mlu_available
from mmengine.device import is_mlu_available, is_npu_available

from collections.abc import Iterable, Mapping

@@ -80,6 +80,14 @@ def _init_dist_pytorch(backend, **kwargs) -> None:
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
elif is_npu_available():
import torch_npu # noqa: F401
torch.npu.set_device(rank)
torch_dist.init_process_group(
backend='hccl',
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
else:
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
@@ -437,7 +445,10 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
torch.device: The device of backend.
"""
backend = get_backend(group)
if backend == torch_dist.Backend.NCCL:
if backend == 'hccl':
import torch_npu # noqa: F401
return torch.device('npu', torch.npu.current_device())
elif backend == torch_dist.Backend.NCCL:
return torch.device('cuda', torch.cuda.current_device())
elif backend == 'cncl':
import torch_mlu # noqa: F401
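For context, a sketch of the caller-side behaviour assuming the usual MMEngine entry points: launching is unchanged, but on a machine with `torch_npu` installed the process group is initialized with the `hccl` backend and `get_comm_device` resolves it to an `npu` device.

```python
from mmengine.dist import get_backend, get_comm_device, init_dist

# Typically invoked once per process, e.g. under torchrun; with an Ascend
# device present the branch above forces backend='hccl'.
init_dist(launcher='pytorch')

print(get_backend())      # 'hccl' on NPU, 'nccl' on CUDA, 'cncl' on MLU
print(get_comm_device())  # torch.device('npu', <current npu>) under hccl
```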
19 changes: 19 additions & 0 deletions mmengine/model/base_model/base_model.py
@@ -210,6 +210,25 @@ def cuda(
self._set_device(torch.device(device))
return super().cuda(device)

def npu(
self,
device: Union[int, str, torch.device, None] = None,
) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.npu`
additionally.

Returns:
nn.Module: The model itself.

Note:
The current generation of NPU (Ascend 910) does not support
running multiple cards in a single process, so the device index
here must be consistent with the current default device.
"""
device = torch.npu.current_device()
self._set_device(device)
return super().npu()

def cpu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
additionally.
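For context, a brief sketch of the intended call pattern, assuming the usual `mmengine.model.BaseModel` import path: unlike a plain `nn.Module.npu()`, this override also moves the attached data preprocessor, and it always targets the current default NPU for the reason given in the note above.

```python
import torch.nn as nn

from mmengine.model import BaseModel  # assumed import path


class ToyModel(BaseModel):  # illustrative subclass
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        return self.linear(inputs)


# Moves both the parameters and model.data_preprocessor to the default NPU.
model = ToyModel().npu()
```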
11 changes: 8 additions & 3 deletions mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -3,13 +3,18 @@

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler

from mmengine.device import is_cuda_available, is_npu_available
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from .optimizer_wrapper import OptimWrapper

if is_npu_available():
from torch.npu.amp import GradScaler
else:
from torch.cuda.amp import GradScaler


@OPTIM_WRAPPERS.register_module()
class AmpOptimWrapper(OptimWrapper):
@@ -44,8 +49,8 @@ class AmpOptimWrapper(OptimWrapper):
def __init__(self, loss_scale='dynamic', **kwargs):
assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
'`torch.cuda.amp` is only available when pytorch version >= 1.6')
assert torch.cuda.is_available(), (
'``AmpOptimizerWrapper`` is only available training on gpu')
assert is_cuda_available() or is_npu_available(), (
'``AmpOptimWrapper`` is only available when training on GPU or NPU')
super().__init__(**kwargs)
self._scale_update_param = None
if loss_scale == 'dynamic':
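For context, a hedged configuration sketch: with the conditional import above, the same `AmpOptimWrapper` config works on GPU or NPU, and the matching `GradScaler` implementation is selected at import time (the optimizer settings below are illustrative).

```python
# Illustrative optim_wrapper config fragment; AmpOptimWrapper now accepts
# NPU as well as CUDA and scales losses with the corresponding GradScaler.
optim_wrapper = dict(
    type='AmpOptimWrapper',
    loss_scale='dynamic',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4))
```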
8 changes: 8 additions & 0 deletions mmengine/optim/optimizer/builder.py
@@ -7,6 +7,7 @@
import torch.nn as nn

from mmengine.config import Config, ConfigDict
from mmengine.device import is_npu_available
from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
from .optimizer_wrapper import OptimWrapper

@@ -53,6 +54,13 @@ def build_optim_wrapper(model: nn.Module,
constructor_type = optim_wrapper_cfg.pop('constructor',
'DefaultOptimWrapperConstructor')
paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)

# Since the current generation of NPU (Ascend 910) only supports
# mixed precision training, we turn on mixed precision by default
# on NPU so that training runs correctly.
if is_npu_available():
optim_wrapper_cfg['type'] = 'AmpOptimWrapper'

optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
dict(
type=constructor_type,
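To make the default override above concrete, a sketch of the builder call (the model and hyperparameters are illustrative): on an NPU host the plain `OptimWrapper` requested below is upgraded to `AmpOptimWrapper`, since Ascend 910 training is expected to run in mixed precision.

```python
import torch.nn as nn

from mmengine.optim import build_optim_wrapper

model = nn.Linear(8, 2)  # illustrative model
optim_wrapper = build_optim_wrapper(
    model, dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)))

# Prints 'AmpOptimWrapper' on an NPU machine, 'OptimWrapper' elsewhere.
print(type(optim_wrapper).__name__)
```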
7 changes: 5 additions & 2 deletions mmengine/runner/amp.py
@@ -5,7 +5,7 @@

import torch

from mmengine.device import get_device
from mmengine.device import get_device, is_cuda_available, is_npu_available
from mmengine.logging import print_log
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
@@ -86,7 +86,10 @@ def autocast(device_type: Optional[str] = None,
logger='current',
level=logging.WARNING)

if torch.cuda.is_available():
if is_npu_available():
with torch.npu.amp.autocast(enabled=enabled):
yield
elif is_cuda_available():
with torch.cuda.amp.autocast(enabled=enabled):
yield
else:
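For context, a hedged usage sketch of `autocast` after this change, assuming it is exported from `mmengine.runner`: on an Ascend host the context manager now delegates to `torch.npu.amp.autocast` rather than requiring CUDA.

```python
import torch

from mmengine.device import get_device
from mmengine.runner import autocast  # assumed export path

device = get_device()
layer = torch.nn.Linear(8, 2).to(device)
x = torch.randn(4, 8, device=device)

with autocast():
    # Uses torch.npu.amp.autocast on Ascend and torch.cuda.amp.autocast
    # on CUDA; other devices follow the existing fallback branches.
    y = layer(x)
```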
6 changes: 4 additions & 2 deletions tests/test_device/test_device.py
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
is_mps_available)
is_mps_available, is_npu_available)


def test_get_device():
device = get_device()
if is_cuda_available():
if is_npu_available():
assert device == 'npu'
elif is_cuda_available():
assert device == 'cuda'
elif is_mlu_available():
assert device == 'mlu'