From 14ea9e64a5f3ecebdd991a430c03d563083e9b3c Mon Sep 17 00:00:00 2001
From: Max Schettler
Date: Fri, 30 Sep 2022 11:43:31 +0200
Subject: [PATCH] Avoid circular dependencies on import

---
 mmdet/ops/dcn/functions/deform_conv.py                 |  4 ++--
 mmdet/ops/dcn/functions/deform_pool.py                 |  6 ++++--
 mmdet/ops/dcn/modules/deform_conv.py                   | 10 ++++++++--
 mmdet/ops/masked_conv/functions/masked_conv.py         |  3 ++-
 mmdet/ops/nms/nms_wrapper.py                           |  3 ++-
 mmdet/ops/roi_align/functions/roi_align.py             |  6 ++++--
 mmdet/ops/roi_pool/functions/roi_pool.py               |  6 ++++--
 .../sigmoid_focal_loss/functions/sigmoid_focal_loss.py |  6 ++++--
 8 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/mmdet/ops/dcn/functions/deform_conv.py b/mmdet/ops/dcn/functions/deform_conv.py
index 6af75a75..c693c563 100755
--- a/mmdet/ops/dcn/functions/deform_conv.py
+++ b/mmdet/ops/dcn/functions/deform_conv.py
@@ -2,8 +2,6 @@
 from torch.autograd import Function
 from torch.nn.modules.utils import _pair
 
-from .. import deform_conv_cuda
-
 
 class DeformConvFunction(Function):
 
@@ -18,6 +16,8 @@ def forward(ctx,
                 groups=1,
                 deformable_groups=1,
                 im2col_step=64):
+        from .. import deform_conv_cuda
+
         if input is not None and input.dim() != 4:
             raise ValueError(
                 "Expected 4D tensor as input, got {}D tensor instead.".format(
diff --git a/mmdet/ops/dcn/functions/deform_pool.py b/mmdet/ops/dcn/functions/deform_pool.py
index 65ff0efb..42190e72 100755
--- a/mmdet/ops/dcn/functions/deform_pool.py
+++ b/mmdet/ops/dcn/functions/deform_pool.py
@@ -1,8 +1,6 @@
 import torch
 from torch.autograd import Function
 
-from .. import deform_pool_cuda
-
 
 class DeformRoIPoolingFunction(Function):
 
@@ -19,6 +17,8 @@ def forward(ctx,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0):
+        from .. import deform_pool_cuda
+
         ctx.spatial_scale = spatial_scale
         ctx.out_size = out_size
         ctx.out_channels = out_channels
@@ -48,6 +48,8 @@ def forward(ctx,
 
     @staticmethod
     def backward(ctx, grad_output):
+        from .. import deform_pool_cuda
+
         if not grad_output.is_cuda:
             raise NotImplementedError
 
diff --git a/mmdet/ops/dcn/modules/deform_conv.py b/mmdet/ops/dcn/modules/deform_conv.py
index 50d15d15..e4f76b37 100755
--- a/mmdet/ops/dcn/modules/deform_conv.py
+++ b/mmdet/ops/dcn/modules/deform_conv.py
@@ -4,8 +4,6 @@
 import torch.nn as nn
 from torch.nn.modules.utils import _pair
 
-from ..functions.deform_conv import deform_conv, modulated_deform_conv
-
 
 class DeformConv(nn.Module):
 
@@ -52,6 +50,8 @@ def reset_parameters(self):
         self.weight.data.uniform_(-stdv, stdv)
 
     def forward(self, x, offset):
+        from ..functions.deform_conv import deform_conv
+
         return deform_conv(x, offset, self.weight, self.stride, self.padding,
                            self.dilation, self.groups, self.deformable_groups)
 
@@ -76,6 +76,8 @@ def init_offset(self):
         self.conv_offset.bias.data.zero_()
 
     def forward(self, x):
+        from ..functions.deform_conv import deform_conv
+
         offset = self.conv_offset(x)
         return deform_conv(x, offset, self.weight, self.stride, self.padding,
                            self.dilation, self.groups, self.deformable_groups)
@@ -123,6 +125,8 @@ def reset_parameters(self):
         self.bias.data.zero_()
 
     def forward(self, x, offset, mask):
+        from ..functions.deform_conv import modulated_deform_conv
+
         return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                      self.stride, self.padding, self.dilation,
                                      self.groups, self.deformable_groups)
@@ -148,6 +152,8 @@ def init_offset(self):
         self.conv_offset_mask.bias.data.zero_()
 
     def forward(self, x):
+        from ..functions.deform_conv import modulated_deform_conv
+
         out = self.conv_offset_mask(x)
         o1, o2, mask = torch.chunk(out, 3, dim=1)
         offset = torch.cat((o1, o2), dim=1)
diff --git a/mmdet/ops/masked_conv/functions/masked_conv.py b/mmdet/ops/masked_conv/functions/masked_conv.py
index eed32b73..04e4b8ae 100755
--- a/mmdet/ops/masked_conv/functions/masked_conv.py
+++ b/mmdet/ops/masked_conv/functions/masked_conv.py
@@ -2,13 +2,14 @@
 import torch
 from torch.autograd import Function
 from torch.nn.modules.utils import _pair
-from .. import masked_conv2d_cuda
 
 
 class MaskedConv2dFunction(Function):
 
     @staticmethod
     def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
+        from .. import masked_conv2d_cuda
+
         assert mask.dim() == 3 and mask.size(0) == 1
         assert features.dim() == 4 and features.size(0) == 1
         assert features.size()[2:] == mask.size()[1:]
diff --git a/mmdet/ops/nms/nms_wrapper.py b/mmdet/ops/nms/nms_wrapper.py
index 8ce5bc44..29b14ace 100755
--- a/mmdet/ops/nms/nms_wrapper.py
+++ b/mmdet/ops/nms/nms_wrapper.py
@@ -2,7 +2,6 @@
 import torch
 
 from . import nms_cuda, nms_cpu
-from .soft_nms_cpu import soft_nms_cpu
 
 
 def nms(dets, iou_thr, device_id=None):
@@ -50,6 +49,8 @@ def nms(dets, iou_thr, device_id=None):
 
 
 def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
+    from .soft_nms_cpu import soft_nms_cpu
+
     if isinstance(dets, torch.Tensor):
         is_tensor = True
         dets_np = dets.detach().cpu().numpy()
diff --git a/mmdet/ops/roi_align/functions/roi_align.py b/mmdet/ops/roi_align/functions/roi_align.py
index 096badd2..34dfa171 100755
--- a/mmdet/ops/roi_align/functions/roi_align.py
+++ b/mmdet/ops/roi_align/functions/roi_align.py
@@ -1,12 +1,12 @@
 from torch.autograd import Function
 
-from .. import roi_align_cuda
-
 
 class RoIAlignFunction(Function):
 
     @staticmethod
     def forward(ctx, features, rois, out_size, spatial_scale, sample_num=0):
+        from .. import roi_align_cuda
+
         if isinstance(out_size, int):
             out_h = out_size
             out_w = out_size
@@ -37,6 +37,8 @@ def forward(ctx, features, rois, out_size, spatial_scale, sample_num=0):
 
     @staticmethod
     def backward(ctx, grad_output):
+        from .. import roi_align_cuda
+
         feature_size = ctx.feature_size
         spatial_scale = ctx.spatial_scale
         sample_num = ctx.sample_num
diff --git a/mmdet/ops/roi_pool/functions/roi_pool.py b/mmdet/ops/roi_pool/functions/roi_pool.py
index 068da600..aae9d666 100755
--- a/mmdet/ops/roi_pool/functions/roi_pool.py
+++ b/mmdet/ops/roi_pool/functions/roi_pool.py
@@ -1,13 +1,13 @@
 import torch
 from torch.autograd import Function
 
-from .. import roi_pool_cuda
-
 
 class RoIPoolFunction(Function):
 
     @staticmethod
     def forward(ctx, features, rois, out_size, spatial_scale):
+        from .. import roi_pool_cuda
+
         if isinstance(out_size, int):
             out_h = out_size
             out_w = out_size
@@ -36,6 +36,8 @@ def forward(ctx, features, rois, out_size, spatial_scale):
 
     @staticmethod
     def backward(ctx, grad_output):
+        from .. import roi_pool_cuda
+
         assert grad_output.is_cuda
         spatial_scale = ctx.spatial_scale
         feature_size = ctx.feature_size
diff --git a/mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py b/mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py
index e690f763..85696ba0 100755
--- a/mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py
+++ b/mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py
@@ -1,13 +1,13 @@
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 
-from .. import sigmoid_focal_loss_cuda
-
 
 class SigmoidFocalLossFunction(Function):
 
     @staticmethod
     def forward(ctx, input, target, gamma=2.0, alpha=0.25):
+        from .. import sigmoid_focal_loss_cuda
+
         ctx.save_for_backward(input, target)
         num_classes = input.shape[1]
         ctx.num_classes = num_classes
@@ -21,6 +21,8 @@ def forward(ctx, input, target, gamma=2.0, alpha=0.25):
     @staticmethod
     @once_differentiable
     def backward(ctx, d_loss):
+        from .. import sigmoid_focal_loss_cuda
+
         input, target = ctx.saved_tensors
         num_classes = ctx.num_classes
         gamma = ctx.gamma
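
Reviewer note, not part of the commit message: every hunk above applies the same
deferred-import fix, so a minimal sketch of the failure mode and the fix may help.
The module names here (pkg, pkg.ops, pkg.ext) are hypothetical stand-ins, not
mmdet's actual layout:

    # pkg/__init__.py
    from . import ops           # importing pkg immediately executes pkg.ops

    # pkg/ops.py
    #
    # Before the fix, a top-level "from . import ext" would execute while
    # pkg is still only partially initialised. If pkg.ext (a stand-in for a
    # compiled CUDA extension wrapper) imports anything from pkg in turn,
    # Python hits the half-built module and raises ImportError at import
    # time.

    def forward(x):
        # Deferred import: resolved on the first call, by which point both
        # modules are fully initialised, so the cycle is harmless; later
        # calls are a cheap sys.modules cache hit.
        from . import ext
        return ext.run(x)

The trade-off is one sys.modules lookup per call instead of a one-time top-level
import, which is negligible next to the CUDA kernels these functions launch.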