Skip to content

Commit

Permalink
[Feature] Support modulated_deform_conv with cambricon MLU backend (o…
Browse files Browse the repository at this point in the history
…pen-mmlab#2411)

* [Feature] Support modulated_deform_conv with cambricon MLU backend

* fix error of torch_mlu

* modify with commit suggestion

* Update modulated_deform_conv.py

* Update mmcv/ops/modulated_deform_conv.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
  • Loading branch information
2 people authored and ckirchhoff2021 committed Dec 21, 2022
1 parent 82c9e19 commit 27a2b22
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 13 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ We implement common ops used in detection, segmentation, etc.
| MaskedConv | ||| ||
| MergeCells | || | | |
| MinAreaPolygon | || | | |
| ModulatedDeformConv2d ||| | ||
| ModulatedDeformConv2d ||| | ||
| MultiScaleDeformableAttn | ||| | |
| NMS |||| ||
| NMSRotated ||| | | |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| MaskedConv | ||| ||
| MergeCells | || | | |
| MinAreaPolygon | || | | |
| ModulatedDeformConv2d ||| | ||
| ModulatedDeformConv2d ||| | ||
| MultiScaleDeformableAttn | ||| | |
| NMS |||| ||
| NMSRotated ||| | | |
Expand Down
6 changes: 6 additions & 0 deletions mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_MLU_AVAILABLE
from .active_rotated_filter import active_rotated_filter
from .assign_score_withk import assign_score_withk
from .ball_query import ball_query
Expand Down Expand Up @@ -106,3 +107,8 @@
'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
'PrRoIPool', 'prroi_pool'
]

if IS_MLU_AVAILABLE:
from .modulated_deform_conv import \
ModulatedDeformConv2dPack_MLU # noqa:F401
__all__.append('ModulatedDeformConv2dPack_MLU')
87 changes: 86 additions & 1 deletion mmcv/ops/modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single

from mmcv.utils import deprecated_api_warning
from mmcv.utils import IS_MLU_AVAILABLE, deprecated_api_warning
from ..cnn import CONV_LAYERS
from ..utils import ext_loader, print_log

Expand Down Expand Up @@ -352,3 +352,88 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)


if IS_MLU_AVAILABLE:
from torchvision.ops import deform_conv2d as tv_deform_conv2d

@CONV_LAYERS.register_module('DCNv2', force=True)
class ModulatedDeformConv2dPack_MLU(nn.modules.Module):
"""This class is the DCNv2 implementation of the MLU device. The MLU
backend support of the operator has been implemented in torchvision.
The mmcv registration mechanism is used for multiplexing here. The
torchvision implementation of DCNv2 is called.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int): Same as nn.Conv2d, while tuple is not supported.
padding (int): Same as nn.Conv2d, while tuple is not supported.
dilation (int): Same as nn.Conv2d, while tuple is not supported.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by
the norm_cfg. Bias will be set as True if norm_cfg is None,
otherwise False.
"""

def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
deform_groups: int = 1,
bias: Union[bool, str] = True):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deform_groups = deform_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
bias=True)
self.init_weights()

def init_weights(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()

def forward(self, x):
out = self.conv_offset(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return tv_deform_conv2d(
x,
offset,
self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
mask=mask)
23 changes: 13 additions & 10 deletions tests/test_ops/test_modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
import torch

from mmcv.utils import TORCH_VERSION, digit_version
from mmcv.utils import IS_MLU_AVAILABLE, TORCH_VERSION, digit_version

try:
# If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
Expand Down Expand Up @@ -42,21 +42,22 @@ class TestMdconv:
def _test_mdconv(self, dtype=torch.float, device='cuda'):
if not torch.cuda.is_available() and device == 'cuda':
pytest.skip('test requires GPU')
from mmcv.ops import ModulatedDeformConv2dPack
if device == 'mlu':
from mmcv.ops import \
ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack
else:
from mmcv.ops import ModulatedDeformConv2dPack

input = torch.tensor(input_t, dtype=dtype, device=device)
input.requires_grad = True

dcn = ModulatedDeformConv2dPack(
1,
1,
kernel_size=(2, 2),
stride=1,
padding=1,
deform_groups=1,
bias=False)

if device == 'cuda':
dcn.cuda()
bias=False).to(device)

dcn.weight.data.fill_(1.)
dcn.type(dtype)
Expand Down Expand Up @@ -114,9 +115,11 @@ def _test_amp_mdconv(self, input_dtype=torch.float):
def test_mdconv(self):
self._test_mdconv(torch.double, device='cpu')
self._test_mdconv(torch.float, device='cpu')
self._test_mdconv(torch.double)
self._test_mdconv(torch.float)
self._test_mdconv(torch.half)

device = 'mlu' if IS_MLU_AVAILABLE else 'cuda'
self._test_mdconv(torch.double, device=device)
self._test_mdconv(torch.float, device=device)
self._test_mdconv(torch.half, device=device)

# test amp when torch version >= '1.6.0', the type of
# input data for mdconv might be torch.float or torch.half
Expand Down

0 comments on commit 27a2b22

Please sign in to comment.