Avoid circular dependencies on import
mxsrc committed Sep 30, 2022
1 parent 746cdce commit 14ea9e6
Showing 8 changed files with 30 additions and 14 deletions.
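
The change is the same in every file: module-level imports of the compiled CUDA extensions (deform_conv_cuda, deform_pool_cuda, masked_conv2d_cuda, roi_align_cuda, roi_pool_cuda, sigmoid_focal_loss_cuda) and of the soft_nms_cpu module are moved into the forward/backward functions that actually call them, so the extensions are imported on first use rather than when the package itself is imported. A minimal sketch of the deferred-import pattern, with a hypothetical my_op_cuda extension and file layout standing in for the real ones (names below are illustrative, not part of the commit):

    # Hypothetical file: mypkg/ops/my_op/functions/my_op.py
    from torch.autograd import Function


    class MyOpFunction(Function):

        @staticmethod
        def forward(ctx, x):
            # Deferred import: the compiled extension is only needed when the
            # op actually runs, so importing the package (for example while
            # the extensions are still being built) no longer fails.
            from .. import my_op_cuda
            return my_op_cuda.forward(x)

Because Python caches imported modules in sys.modules, the repeated import inside forward/backward costs only a dictionary lookup and a local name binding after the first call.
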
4 changes: 2 additions & 2 deletions mmdet/ops/dcn/functions/deform_conv.py
@@ -2,8 +2,6 @@
 from torch.autograd import Function
 from torch.nn.modules.utils import _pair

-from .. import deform_conv_cuda
-

 class DeformConvFunction(Function):

@@ -18,6 +16,8 @@ def forward(ctx,
                 groups=1,
                 deformable_groups=1,
                 im2col_step=64):
+        from .. import deform_conv_cuda
+
         if input is not None and input.dim() != 4:
             raise ValueError(
                 "Expected 4D tensor as input, got {}D tensor instead.".format(
6 changes: 4 additions & 2 deletions mmdet/ops/dcn/functions/deform_pool.py
@@ -1,8 +1,6 @@
 import torch
 from torch.autograd import Function

-from .. import deform_pool_cuda
-

 class DeformRoIPoolingFunction(Function):

@@ -19,6 +17,8 @@ def forward(ctx,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0):
+        from .. import deform_pool_cuda
+
         ctx.spatial_scale = spatial_scale
         ctx.out_size = out_size
         ctx.out_channels = out_channels
@@ -48,6 +48,8 @@ def forward(ctx,

     @staticmethod
     def backward(ctx, grad_output):
+        from .. import deform_pool_cuda
+
         if not grad_output.is_cuda:
             raise NotImplementedError

10 changes: 8 additions & 2 deletions mmdet/ops/dcn/modules/deform_conv.py
@@ -4,8 +4,6 @@
 import torch.nn as nn
 from torch.nn.modules.utils import _pair

-from ..functions.deform_conv import deform_conv, modulated_deform_conv
-

 class DeformConv(nn.Module):

@@ -52,6 +50,8 @@ def reset_parameters(self):
         self.weight.data.uniform_(-stdv, stdv)

     def forward(self, x, offset):
+        from ..functions.deform_conv import deform_conv
+
         return deform_conv(x, offset, self.weight, self.stride, self.padding,
                            self.dilation, self.groups, self.deformable_groups)

@@ -76,6 +76,8 @@ def init_offset(self):
         self.conv_offset.bias.data.zero_()

     def forward(self, x):
+        from ..functions.deform_conv import deform_conv
+
         offset = self.conv_offset(x)
         return deform_conv(x, offset, self.weight, self.stride, self.padding,
                            self.dilation, self.groups, self.deformable_groups)
@@ -123,6 +125,8 @@ def reset_parameters(self):
         self.bias.data.zero_()

     def forward(self, x, offset, mask):
+        from ..functions.deform_conv import modulated_deform_conv
+
         return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                      self.stride, self.padding, self.dilation,
                                      self.groups, self.deformable_groups)
@@ -148,6 +152,8 @@ def init_offset(self):
         self.conv_offset_mask.bias.data.zero_()

     def forward(self, x):
+        from ..functions.deform_conv import modulated_deform_conv
+
         out = self.conv_offset_mask(x)
         o1, o2, mask = torch.chunk(out, 3, dim=1)
         offset = torch.cat((o1, o2), dim=1)
3 changes: 2 additions & 1 deletion mmdet/ops/masked_conv/functions/masked_conv.py
@@ -2,13 +2,14 @@
 import torch
 from torch.autograd import Function
 from torch.nn.modules.utils import _pair
-from .. import masked_conv2d_cuda


 class MaskedConv2dFunction(Function):

     @staticmethod
     def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
+        from .. import masked_conv2d_cuda
+
         assert mask.dim() == 3 and mask.size(0) == 1
         assert features.dim() == 4 and features.size(0) == 1
         assert features.size()[2:] == mask.size()[1:]
3 changes: 2 additions & 1 deletion mmdet/ops/nms/nms_wrapper.py
@@ -2,7 +2,6 @@
 import torch

 from . import nms_cuda, nms_cpu
-from .soft_nms_cpu import soft_nms_cpu


 def nms(dets, iou_thr, device_id=None):
@@ -50,6 +49,8 @@ def nms(dets, iou_thr, device_id=None):


 def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
+    from .soft_nms_cpu import soft_nms_cpu
+
     if isinstance(dets, torch.Tensor):
         is_tensor = True
         dets_np = dets.detach().cpu().numpy()
6 changes: 4 additions & 2 deletions mmdet/ops/roi_align/functions/roi_align.py
@@ -1,12 +1,12 @@
 from torch.autograd import Function

-from .. import roi_align_cuda
-

 class RoIAlignFunction(Function):

     @staticmethod
     def forward(ctx, features, rois, out_size, spatial_scale, sample_num=0):
+        from .. import roi_align_cuda
+
         if isinstance(out_size, int):
             out_h = out_size
             out_w = out_size
@@ -37,6 +37,8 @@ def forward(ctx, features, rois, out_size, spatial_scale, sample_num=0):

     @staticmethod
     def backward(ctx, grad_output):
+        from .. import roi_align_cuda
+
         feature_size = ctx.feature_size
         spatial_scale = ctx.spatial_scale
         sample_num = ctx.sample_num
6 changes: 4 additions & 2 deletions mmdet/ops/roi_pool/functions/roi_pool.py
@@ -1,13 +1,13 @@
 import torch
 from torch.autograd import Function

-from .. import roi_pool_cuda
-

 class RoIPoolFunction(Function):

     @staticmethod
     def forward(ctx, features, rois, out_size, spatial_scale):
+        from .. import roi_pool_cuda
+
         if isinstance(out_size, int):
             out_h = out_size
             out_w = out_size
@@ -36,6 +36,8 @@ def forward(ctx, features, rois, out_size, spatial_scale):

     @staticmethod
     def backward(ctx, grad_output):
+        from .. import roi_pool_cuda
+
         assert grad_output.is_cuda
         spatial_scale = ctx.spatial_scale
         feature_size = ctx.feature_size
6 changes: 4 additions & 2 deletions mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py
@@ -1,13 +1,13 @@
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable

-from .. import sigmoid_focal_loss_cuda
-

 class SigmoidFocalLossFunction(Function):

     @staticmethod
     def forward(ctx, input, target, gamma=2.0, alpha=0.25):
+        from .. import sigmoid_focal_loss_cuda
+
         ctx.save_for_backward(input, target)
         num_classes = input.shape[1]
         ctx.num_classes = num_classes
@@ -21,6 +21,8 @@ def forward(ctx, input, target, gamma=2.0, alpha=0.25):
     @staticmethod
     @once_differentiable
     def backward(ctx, d_loss):
+        from .. import sigmoid_focal_loss_cuda
+
         input, target = ctx.saved_tensors
         num_classes = ctx.num_classes
         gamma = ctx.gamma
