diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md
index 8a4bfe88eb..822ee15589 100644
--- a/docs/en/understand_mmcv/ops.md
+++ b/docs/en/understand_mmcv/ops.md
@@ -32,7 +32,7 @@ We implement common ops used in detection, segmentation, etc.
 | MaskedConv               |     |  √  |  √  |     |  √  |
 | MergeCells               |     |  √  |     |     |     |
 | MinAreaPolygon           |     |  √  |     |     |     |
-| ModulatedDeformConv2d    |  √  |  √  |     |     |  √  |
+| ModulatedDeformConv2d    |  √  |  √  |  √  |     |  √  |
 | MultiScaleDeformableAttn |     |  √  |  √  |     |     |
 | NMS                      |  √  |  √  |  √  |     |  √  |
 | NMSRotated               |  √  |  √  |     |     |     |
diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md
index 7cf062ef79..23d9b6e5fd 100644
--- a/docs/zh_cn/understand_mmcv/ops.md
+++ b/docs/zh_cn/understand_mmcv/ops.md
@@ -32,7 +32,7 @@ MMCV 提供了检测、分割等任务中常用的算子
 | MaskedConv               |     |  √  |  √  |     |  √  |
 | MergeCells               |     |  √  |     |     |     |
 | MinAreaPolygon           |     |  √  |     |     |     |
-| ModulatedDeformConv2d    |  √  |  √  |     |     |  √  |
+| ModulatedDeformConv2d    |  √  |  √  |  √  |     |  √  |
 | MultiScaleDeformableAttn |     |  √  |  √  |     |     |
 | NMS                      |  √  |  √  |  √  |     |  √  |
 | NMSRotated               |  √  |  √  |     |     |     |
diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py
index 8c76b47670..bcb9a5a4da 100755
--- a/mmcv/ops/__init__.py
+++ b/mmcv/ops/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import IS_MLU_AVAILABLE
 from .active_rotated_filter import active_rotated_filter
 from .assign_score_withk import assign_score_withk
 from .ball_query import ball_query
@@ -106,3 +107,8 @@
     'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
     'PrRoIPool', 'prroi_pool'
 ]
+
+if IS_MLU_AVAILABLE:
+    from .modulated_deform_conv import \
+        ModulatedDeformConv2dPack_MLU  # noqa:F401
+    __all__.append('ModulatedDeformConv2dPack_MLU')
diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py
index 7970d5323e..6a5173cb4f 100644
--- a/mmcv/ops/modulated_deform_conv.py
+++ b/mmcv/ops/modulated_deform_conv.py
@@ -8,7 +8,7 @@
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair, _single
 
-from mmcv.utils import deprecated_api_warning
+from mmcv.utils import IS_MLU_AVAILABLE, deprecated_api_warning
 from ..cnn import CONV_LAYERS
 from ..utils import ext_loader, print_log
 
@@ -352,3 +352,96 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
         super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                       strict, missing_keys, unexpected_keys,
                                       error_msgs)
+
+
+if IS_MLU_AVAILABLE:
+    from torchvision.ops import deform_conv2d as tv_deform_conv2d
+
+    @CONV_LAYERS.register_module('DCNv2', force=True)
+    class ModulatedDeformConv2dPack_MLU(nn.Module):
+        """This class is the DCNv2 implementation for MLU devices.
+
+        The MLU backend of this operator is implemented in torchvision,
+        so this class reuses the mmcv registration mechanism and
+        dispatches to the torchvision implementation of DCNv2.
+
+        Args:
+            in_channels (int): Same as nn.Conv2d.
+            out_channels (int): Same as nn.Conv2d.
+            kernel_size (int or tuple[int]): Same as nn.Conv2d.
+            stride (int): Same as nn.Conv2d, while tuple is not supported.
+            padding (int): Same as nn.Conv2d, while tuple is not supported.
+            dilation (int): Same as nn.Conv2d, while tuple is not supported.
+            groups (int): Same as nn.Conv2d.
+            deform_groups (int): The number of deformable groups used by
+                the offset and mask branch.
+            bias (bool or str): If specified as `auto`, it will be decided
+                by the norm_cfg. Bias will be set as True if norm_cfg is
+                None, otherwise False.
+        """
+
+        def __init__(self,
+                     in_channels: int,
+                     out_channels: int,
+                     kernel_size: Union[int, Tuple[int]],
+                     stride: int = 1,
+                     padding: int = 0,
+                     dilation: int = 1,
+                     groups: int = 1,
+                     deform_groups: int = 1,
+                     bias: Union[bool, str] = True):
+            super().__init__()
+            self.in_channels = in_channels
+            self.out_channels = out_channels
+            self.kernel_size = _pair(kernel_size)
+            self.stride = _pair(stride)
+            self.padding = _pair(padding)
+            self.dilation = _pair(dilation)
+            self.groups = groups
+            self.deform_groups = deform_groups
+            # Match nn.Conv2d: grouped conv weights have shape
+            # (out_channels, in_channels // groups, *kernel_size).
+            self.weight = nn.Parameter(
+                torch.Tensor(out_channels, in_channels // groups,
+                             *self.kernel_size))
+            if bias:
+                self.bias = nn.Parameter(torch.Tensor(out_channels))
+            else:
+                self.register_parameter('bias', None)
+            self.conv_offset = nn.Conv2d(
+                self.in_channels,
+                self.deform_groups * 3 * self.kernel_size[0] *
+                self.kernel_size[1],
+                kernel_size=self.kernel_size,
+                stride=self.stride,
+                padding=self.padding,
+                bias=True)
+            self.init_weights()
+
+        def init_weights(self):
+            n = self.in_channels
+            for k in self.kernel_size:
+                n *= k
+            stdv = 1. / math.sqrt(n)
+            self.weight.data.uniform_(-stdv, stdv)
+            if self.bias is not None:
+                self.bias.data.zero_()
+            # Zero-init the offset/mask branch so the module initially
+            # behaves like a plain convolution.
+            self.conv_offset.weight.data.zero_()
+            self.conv_offset.bias.data.zero_()
+
+        def forward(self, x):
+            out = self.conv_offset(x)
+            o1, o2, mask = torch.chunk(out, 3, dim=1)
+            offset = torch.cat((o1, o2), dim=1)
+            mask = torch.sigmoid(mask)
+            return tv_deform_conv2d(
+                x,
+                offset,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                mask=mask)
diff --git a/tests/test_ops/test_modulated_deform_conv.py b/tests/test_ops/test_modulated_deform_conv.py
index 3b9070491a..927489df6d 100644
--- a/tests/test_ops/test_modulated_deform_conv.py
+++ b/tests/test_ops/test_modulated_deform_conv.py
@@ -5,7 +5,7 @@
 import pytest
 import torch
 
-from mmcv.utils import TORCH_VERSION, digit_version
+from mmcv.utils import IS_MLU_AVAILABLE, TORCH_VERSION, digit_version
 
 try:
     # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
@@ -42,10 +42,14 @@ class TestMdconv:
     def _test_mdconv(self, dtype=torch.float, device='cuda'):
         if not torch.cuda.is_available() and device == 'cuda':
             pytest.skip('test requires GPU')
-        from mmcv.ops import ModulatedDeformConv2dPack
+        if device == 'mlu':
+            from mmcv.ops import \
+                ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack
+        else:
+            from mmcv.ops import ModulatedDeformConv2dPack
+
         input = torch.tensor(input_t, dtype=dtype, device=device)
         input.requires_grad = True
-
         dcn = ModulatedDeformConv2dPack(
             1,
             1,
@@ -53,10 +57,7 @@ def _test_mdconv(self, dtype=torch.float, device='cuda'):
             stride=1,
             padding=1,
             deform_groups=1,
-            bias=False)
-
-        if device == 'cuda':
-            dcn.cuda()
+            bias=False).to(device)
         dcn.weight.data.fill_(1.)
         dcn.type(dtype)
 
@@ -114,9 +115,11 @@ def _test_amp_mdconv(self, input_dtype=torch.float):
     def test_mdconv(self):
         self._test_mdconv(torch.double, device='cpu')
         self._test_mdconv(torch.float, device='cpu')
-        self._test_mdconv(torch.double)
-        self._test_mdconv(torch.float)
-        self._test_mdconv(torch.half)
+
+        device = 'mlu' if IS_MLU_AVAILABLE else 'cuda'
+        self._test_mdconv(torch.double, device=device)
+        self._test_mdconv(torch.float, device=device)
+        self._test_mdconv(torch.half, device=device)
 
         # test amp when torch version >= '1.6.0', the type of
         # input data for mdconv might be torch.float or torch.half
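A minimal usage sketch of the new class, assuming a Cambricon torch_mlu build of PyTorch where the `mlu` device is registered; the channel counts and tensor shapes below are purely illustrative:

```python
import torch

from mmcv.ops import ModulatedDeformConv2dPack_MLU

# 3x3 DCNv2 layer: offsets and the modulation mask are predicted by the
# internal conv_offset branch and forwarded to torchvision's deform_conv2d.
dcn = ModulatedDeformConv2dPack_MLU(
    in_channels=3,
    out_channels=8,
    kernel_size=3,
    stride=1,
    padding=1,
    deform_groups=1,
    bias=False).to('mlu')

x = torch.randn(2, 3, 16, 16).to('mlu')
x.requires_grad = True
out = dcn(x)          # (2, 8, 16, 16): spatial size preserved by padding=1
out.sum().backward()  # gradients flow through offsets, mask and weight
```

Because the class is registered under `DCNv2` with `force=True`, downstream configs that build conv layers from `dict(type='DCNv2')` pick up this implementation on MLU without any code change.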