Skip to content

Commit

Permalink
[Feature] Support modulated_deform_conv with cambricon MLU backend
Browse files Browse the repository at this point in the history
  • Loading branch information
mengpenghui committed Nov 15, 2022
1 parent 652b1bf commit 3f72469
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 22 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ We implement common ops used in detection, segmentation, etc.
| MaskedConv | ||| |
| MergeCells | || | |
| MinAreaPolygon | || | |
| ModulatedDeformConv2d ||| | |
| ModulatedDeformConv2d ||| | |
| MultiScaleDeformableAttn | || | |
| NMS |||| |
| NMSRotated ||| | |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| MaskedConv | ||| |
| MergeCells | || | |
| MinAreaPolygon | || | |
| ModulatedDeformConv2d ||| | |
| ModulatedDeformConv2d ||| | |
| MultiScaleDeformableAttn | || | |
| NMS |||| |
| NMSRotated ||| | |
Expand Down
23 changes: 12 additions & 11 deletions mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .min_area_polygons import min_area_polygons
from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
ModulatedDeformConv2dPack_MLU,
modulated_deform_conv2d)
from .multi_scale_deform_attn import MultiScaleDeformableAttention
from .nms import batched_nms, nms, nms_match, nms_quadri, nms_rotated, soft_nms
Expand Down Expand Up @@ -80,17 +81,17 @@
'get_compiler_version', 'get_compiling_cuda_version',
'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
'box_iou_rotated', 'box_iou_quadri', 'RoIPointPool3d', 'nms_rotated',
'knn', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
'fused_bias_leakyrelu', 'rotated_feature_align', 'RiRoIAlignRotated',
'riroi_align_rotated', 'RoIAlignRotated', 'roi_align_rotated',
'pixel_group', 'QueryAndGroup', 'GroupAll', 'grouping_operation',
'contour_expand', 'three_nn', 'three_interpolate',
'ModulatedDeformConv2dPack_MLU', 'modulated_deform_conv2d', 'batched_nms',
'nms', 'soft_nms', 'nms_match', 'RoIAlign', 'roi_align', 'RoIPool',
'roi_pool', 'SyncBatchNorm', 'Conv2d', 'ConvTranspose2d', 'Linear',
'MaxPool2d', 'CrissCrossAttention', 'PSAMask', 'point_sample',
'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', 'SAConv2d', 'TINShift',
'tin_shift', 'assign_score_withk', 'box_iou_rotated', 'box_iou_quadri',
'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query', 'upfirdn2d',
'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', 'rotated_feature_align',
'RiRoIAlignRotated', 'riroi_align_rotated', 'RoIAlignRotated',
'roi_align_rotated', 'pixel_group', 'QueryAndGroup', 'GroupAll',
'grouping_operation', 'contour_expand', 'three_nn', 'three_interpolate',
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
'gather_points', 'furthest_point_sample', 'nms_quadri',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
Expand Down
71 changes: 70 additions & 1 deletion mmcv/ops/modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single

from mmcv.utils import deprecated_api_warning
from mmcv.utils import IS_MLU_AVAILABLE, deprecated_api_warning
from ..cnn import CONV_LAYERS
from ..utils import ext_loader, print_log

if IS_MLU_AVAILABLE:
from torchvision.ops import deform_conv2d as tv_deform_conv2d

ext_module = ext_loader.load_ext(
'_ext',
['modulated_deform_conv_forward', 'modulated_deform_conv_backward'])
Expand Down Expand Up @@ -284,3 +287,69 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)


@CONV_LAYERS.register_module(
'DCNv2' if IS_MLU_AVAILABLE else 'DCNv2_disabled',
force=True if IS_MLU_AVAILABLE else False)
class ModulatedDeformConv2dPack_MLU(nn.modules.Module):

def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
deform_groups: int = 1,
bias: Union[bool, str] = True):
super(ModulatedDeformConv2dPack_MLU, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deform_groups = deform_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
bias=True)
self.init_weights()

def init_weights(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()

def forward(self, x):
out = self.conv_offset(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return tv_deform_conv2d(
x,
offset,
self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
mask=mask)
21 changes: 13 additions & 8 deletions tests/test_ops/test_modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
import torch

from mmcv.utils import TORCH_VERSION, digit_version
from mmcv.utils import IS_MLU_AVAILABLE, TORCH_VERSION, digit_version

try:
# If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
Expand Down Expand Up @@ -39,24 +39,29 @@

class TestMdconv:

def _test_mdconv(self, dtype=torch.float, device='cuda'):
def _test_mdconv(self,
dtype=torch.float,
device='mlu' if IS_MLU_AVAILABLE else 'cuda'):
if not torch.cuda.is_available() and device == 'cuda':
pytest.skip('test requires GPU')
from mmcv.ops import ModulatedDeformConv2dPack
if not torch.mlu.is_available() and device == 'mlu':
pytest.skip('test requires MLU')
if device == 'mlu':
from mmcv.ops import \
ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack
else:
from mmcv.ops import ModulatedDeformConv2dPack

input = torch.tensor(input_t, dtype=dtype, device=device)
input.requires_grad = True

dcn = ModulatedDeformConv2dPack(
1,
1,
kernel_size=(2, 2),
stride=1,
padding=1,
deform_groups=1,
bias=False)

if device == 'cuda':
dcn.cuda()
bias=False).to(device)

dcn.weight.data.fill_(1.)
dcn.type(dtype)
Expand Down

0 comments on commit 3f72469

Please sign in to comment.