[Fix] Fix use_depthwise in RTMDet. (#9624)
RangiLyu authored Jan 13, 2023
1 parent b83200a commit 92d03df
Showing 2 changed files with 12 additions and 4 deletions.
mmdet/models/dense_heads/rtmdet_head.py (9 additions, 3 deletions)

@@ -3,7 +3,7 @@
 
 import torch
 import torch.nn as nn
-from mmcv.cnn import ConvModule, Scale, is_norm
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule, Scale, is_norm
 from mmengine.model import bias_init_with_prob, constant_init, normal_init
 from mmengine.structures import InstanceData
 from torch import Tensor
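As background for the import change above: DepthwiseSeparableConvModule is a drop-in replacement for ConvModule in mmcv, accepting the same constructor arguments used in this head, which is what lets _init_layers below switch between the two classes through a single conv alias. A minimal sketch (not part of this commit; channel sizes are assumed):

import torch
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule

for cls in (ConvModule, DepthwiseSeparableConvModule):
    # Same call pattern as in _init_layers below: 3x3 conv, BN, SiLU.
    layer = cls(
        96, 96, 3, stride=1, padding=1,
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='SiLU'))
    out = layer(torch.randn(1, 96, 32, 32))
    # Identical output shape; far fewer parameters in the depthwise case.
    print(cls.__name__, tuple(out.shape),
          sum(p.numel() for p in layer.parameters()))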
@@ -536,6 +536,8 @@ class RTMDetSepBNHead(RTMDetHead):
         in_channels (int): Number of channels in the input feature map.
         share_conv (bool): Whether to share conv layers between stages.
             Defaults to True.
+        use_depthwise (bool): Whether to use depthwise separable convolution in
+            head. Defaults to False.
         norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
             layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
         act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
@@ -547,6 +549,7 @@ def __init__(self,
                  num_classes: int,
                  in_channels: int,
                  share_conv: bool = True,
+                 use_depthwise: bool = False,
                  norm_cfg: ConfigType = dict(
                      type='BN', momentum=0.03, eps=0.001),
                  act_cfg: ConfigType = dict(type='SiLU'),
@@ -555,6 +558,7 @@
                  **kwargs) -> None:
         self.share_conv = share_conv
         self.exp_on_reg = exp_on_reg
+        self.use_depthwise = use_depthwise
         super().__init__(
             num_classes,
             in_channels,
@@ -565,6 +569,8 @@
 
     def _init_layers(self) -> None:
         """Initialize layers of the head."""
+        conv = DepthwiseSeparableConvModule \
+            if self.use_depthwise else ConvModule
         self.cls_convs = nn.ModuleList()
         self.reg_convs = nn.ModuleList()
 
@@ -578,7 +584,7 @@ def _init_layers(self) -> None:
         for i in range(self.stacked_convs):
             chn = self.in_channels if i == 0 else self.feat_channels
             cls_convs.append(
-                ConvModule(
+                conv(
                     chn,
                     self.feat_channels,
                     3,
@@ -588,7 +594,7 @@ def _init_layers(self) -> None:
                     norm_cfg=self.norm_cfg,
                     act_cfg=self.act_cfg))
             reg_convs.append(
-                ConvModule(
+                conv(
                     chn,
                     self.feat_channels,
                     3,
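With this patch, the head accepts a use_depthwise flag that _init_layers actually honors when building its conv stacks. A hypothetical config fragment (values are illustrative, not from this commit) that exercises the fixed path:

bbox_head = dict(
    type='RTMDetSepBNHead',
    num_classes=80,
    in_channels=96,
    feat_channels=96,
    stacked_convs=2,
    use_depthwise=True,  # now builds DepthwiseSeparableConvModule stacks
    norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
    act_cfg=dict(type='SiLU'))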
mmdet/models/necks/cspnext_pafpn.py (3 additions, 1 deletion)

@@ -82,6 +82,7 @@ def __init__(
                     in_channels[idx - 1],
                     num_blocks=num_csp_blocks,
                     add_identity=False,
+                    use_depthwise=use_depthwise,
                     use_cspnext_block=True,
                     expand_ratio=expand_ratio,
                     conv_cfg=conv_cfg,
@@ -108,6 +109,7 @@ def __init__(
                     in_channels[idx + 1],
                     num_blocks=num_csp_blocks,
                     add_identity=False,
+                    use_depthwise=use_depthwise,
                     use_cspnext_block=True,
                     expand_ratio=expand_ratio,
                     conv_cfg=conv_cfg,
@@ -117,7 +119,7 @@ def __init__(
         self.out_convs = nn.ModuleList()
         for i in range(len(in_channels)):
             self.out_convs.append(
-                ConvModule(
+                conv(
                     in_channels[i],
                     out_channels,
                     3,
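Before this commit, CSPNeXtPAFPN accepted use_depthwise but only applied it partially: the flag was not forwarded to the CSPLayer blocks, and the output convs were always plain ConvModules. A hypothetical neck config (channel values assumed) where the fix now takes full effect:

neck = dict(
    type='CSPNeXtPAFPN',
    in_channels=[96, 192, 384],
    out_channels=96,
    num_csp_blocks=1,
    use_depthwise=True,  # now reaches the CSP blocks and out_convs too
    expand_ratio=0.5,
    norm_cfg=dict(type='BN'),
    act_cfg=dict(type='SiLU'))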
