This repository was archived by the owner on Oct 31, 2023. It is now read-only.

ConvNeXt port for Detectron2 framework. #43

Open
wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion object_detection/README.md
@@ -46,4 +46,4 @@ tools/dist_test.sh <CONFIG_FILE> <DET_CHECKPOINT_FILE> <GPU_NUM> --eval bbox segm

## Acknowledgment

This code is built using [mmdetection](https://github.com/open-mmlab/mmdetection), [timm](https://github.com/rwightman/pytorch-image-models) libraries, and [BeiT](https://github.com/microsoft/unilm/tree/f8f3df80c65eb5e5fc6d6d3c9bd3137621795d1e/beit), [Swin Transformer](https://github.com/microsoft/Swin-Transformer) repositories.
This code is built using the [mmdetection](https://github.com/open-mmlab/mmdetection) and [timm](https://github.com/rwightman/pytorch-image-models) libraries, and the [BeiT](https://github.com/microsoft/unilm/tree/f8f3df80c65eb5e5fc6d6d3c9bd3137621795d1e/beit) and [Swin Transformer](https://github.com/microsoft/Swin-Transformer) repositories.
2 changes: 2 additions & 0 deletions object_detection/detectron2/models/backbones/__init__.py
@@ -0,0 +1,2 @@
from .config import add_convnext_config
from .convnext import build_convnext_fpn_backbone
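These re-exports make add_convnext_config and build_convnext_fpn_backbone importable directly from the backbones package; the usage sketches below assume this layout.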
13 changes: 13 additions & 0 deletions object_detection/detectron2/models/backbones/config.py
@@ -0,0 +1,13 @@
from detectron2.config import CfgNode as CN


def add_convnext_config(cfg):
    # extra configs for convnext
    cfg.MODEL.CONVNEXT = CN()
    cfg.MODEL.CONVNEXT.DEPTHS = [3, 3, 9, 3]
    cfg.MODEL.CONVNEXT.DIMS = [96, 192, 384, 768]
    cfg.MODEL.CONVNEXT.DROP_PATH_RATE = 0.2
    cfg.MODEL.CONVNEXT.LAYER_SCALE_INIT_VALUE = 1e-6
    cfg.MODEL.CONVNEXT.OUT_FEATURES = [0, 1, 2, 3]
    cfg.SOLVER.WEIGHT_DECAY_RATE = 0.95
    cfg.SOLVER.OPTIMIZER = "ADAMW"  # set by the yaml configs in this PR; not a stock detectron2 key, so it must exist before merge_from_file
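A hedged usage sketch (not part of the diff; the import path assumes this PR's backbones package is importable as laid out in __init__.py above):

    from detectron2.config import get_cfg
    from detectron2.models.backbones import add_convnext_config

    cfg = get_cfg()           # start from detectron2's defaults
    add_convnext_config(cfg)  # registers the MODEL.CONVNEXT node defined above
    print(cfg.MODEL.CONVNEXT.DEPTHS)  # [3, 3, 9, 3]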

229 changes: 229 additions & 0 deletions object_detection/detectron2/models/backbones/convnext.py
@@ -0,0 +1,229 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath

from detectron2.modeling.backbone import Backbone
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
from detectron2.modeling.backbone.fpn import FPN, LastLevelMaxPool
from detectron2.layers import ShapeSpec

class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x
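# Illustrative sketch (not in the original diff): a Block preserves both the
# spatial resolution (7x7 depthwise conv with padding 3) and the channel count
# (dim -> 4*dim -> dim), so input and output shapes match:
#   blk = Block(dim=96)
#   y = blk(torch.randn(2, 96, 56, 56))  # y.shape == (2, 96, 56, 56)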

class ConvNeXt(Backbone):
    r""" ConvNeXt
    A PyTorch impl of : `A ConvNet for the 2020s` -
    https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        out_features (tuple(int)): Stage indices of the outputs given to the neck.
    """
    def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
                 drop_path_rate=0., layer_scale_init_value=1e-6, out_features=None):
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.num_layers = len(depths)
        self.num_features = list(dims)  # dims already holds the per-stage channel counts
        self._out_features = out_features

        self._out_feature_strides = {}
        self._out_feature_channels = {}

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

            self._out_feature_channels[i] = dims[i]
            self._out_feature_strides[i] = 4 * 2 ** i  # stem stride 4, then x2 per downsampling: 4, 8, 16, 32

        norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
        for i_layer in range(4):
            layer = norm_layer(dims[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    def forward_features(self, x):
        outs = {}
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i in self._out_features:
                norm_layer = getattr(self, f'norm{i}')
                outs[i] = norm_layer(x).contiguous()  # keyed by stage index, matching OUT_FEATURES and FPN.IN_FEATURES
        return outs

    def forward(self, x):
        x = self.forward_features(x)
        return x

class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
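# Illustrative sketch (not in the original diff): the channels_first branch
# above is equivalent to F.layer_norm over the channel dim after a permute:
#   ln = LayerNorm(64, data_format="channels_first")
#   x = torch.randn(2, 64, 8, 8)
#   ref = F.layer_norm(x.permute(0, 2, 3, 1), (64,), ln.weight, ln.bias, ln.eps).permute(0, 3, 1, 2)
#   assert torch.allclose(ln(x), ref, atol=1e-5)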

@BACKBONE_REGISTRY.register()
def build_convnext_backbone(cfg, input_shape):
    """
    Create a ConvNeXt instance from config.

    Returns:
        ConvNeXt: a :class:`ConvNeXt` instance.
    """
    return ConvNeXt(
        in_chans=input_shape.channels,
        depths=cfg.MODEL.CONVNEXT.DEPTHS,
        dims=cfg.MODEL.CONVNEXT.DIMS,
        drop_path_rate=cfg.MODEL.CONVNEXT.DROP_PATH_RATE,
        layer_scale_init_value=cfg.MODEL.CONVNEXT.LAYER_SCALE_INIT_VALUE,
        out_features=cfg.MODEL.CONVNEXT.OUT_FEATURES,
    )

@BACKBONE_REGISTRY.register()
def build_convnext_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    Args:
        cfg: a detectron2 CfgNode

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    bottom_up = build_convnext_backbone(cfg, input_shape)
    in_features = cfg.MODEL.FPN.IN_FEATURES
    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=cfg.MODEL.FPN.NORM,
        top_block=LastLevelMaxPool(),
        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
    )
    return backbone
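A hedged end-to-end sketch (not part of the diff; import path assumes this PR's layout) that builds the registered FPN backbone and inspects its output pyramid:

    import torch
    from detectron2.config import get_cfg
    from detectron2.layers import ShapeSpec
    from detectron2.models.backbones import add_convnext_config, build_convnext_fpn_backbone

    cfg = get_cfg()
    add_convnext_config(cfg)
    cfg.MODEL.FPN.IN_FEATURES = [0, 1, 2, 3]  # ConvNeXt stages are keyed by integer index
    backbone = build_convnext_fpn_backbone(cfg, ShapeSpec(channels=3))
    feats = backbone(torch.randn(1, 3, 224, 224))
    # feats maps "p2".."p6" to feature maps with strides 4, 8, 16, 32, 64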
42 changes: 42 additions & 0 deletions object_detection/detectron2/models/configs/Base-RCNN-FPN.yaml
@@ -0,0 +1,42 @@
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  BACKBONE:
    NAME: "build_resnet_fpn_backbone"
  RESNETS:
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
  FPN:
    IN_FEATURES: ["res2", "res3", "res4", "res5"]  # must match the backbone's output names; the ConvNeXt config below overrides this with integer stage indices
  ANCHOR_GENERATOR:
    SIZES: [[32], [64], [128], [256], [512]]  # One size for each in feature map
    ASPECT_RATIOS: [[0.5, 1.0, 2.0]]  # Three aspect ratios (same for all in feature maps)
  RPN:
    IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
    PRE_NMS_TOPK_TRAIN: 2000  # Per FPN level
    PRE_NMS_TOPK_TEST: 1000  # Per FPN level
    # Detectron1 uses 2000 proposals per-batch,
    # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
    # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
    POST_NMS_TOPK_TRAIN: 1000
    POST_NMS_TOPK_TEST: 1000
  ROI_HEADS:
    NAME: "StandardROIHeads"
    IN_FEATURES: ["p2", "p3", "p4", "p5"]
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
DATASETS:
  TRAIN: ("coco_2017_train",)
  TEST: ("coco_2017_val",)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.02
  STEPS: (60000, 80000)
  MAX_ITER: 90000
INPUT:
  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
VERSION: 2
@@ -0,0 +1,48 @@
_BASE_: "Base-RCNN-FPN.yaml"
MODEL:
  MASK_ON: True
  BACKBONE:
    NAME: "build_convnext_fpn_backbone"
  CONVNEXT:
    DEPTHS: [3, 3, 9, 3]
    DIMS: [96, 192, 384, 768]
    DROP_PATH_RATE: 0.2
    LAYER_SCALE_INIT_VALUE: 1e-6
    OUT_FEATURES: [0, 1, 2, 3]
  FPN:
    IN_FEATURES: [0, 1, 2, 3]
  ANCHOR_GENERATOR:
    SIZES: [[64], [128], [256], [512], [1024]]  # One size for each in feature map
    ASPECT_RATIOS: [[0.5, 1.0, 2.0]]  # Three aspect ratios (same for all in feature maps)
  RPN:
    IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
    PRE_NMS_TOPK_TRAIN: 2000  # Per FPN level
    PRE_NMS_TOPK_TEST: 1000  # Per FPN level
    # Detectron1 uses 2000 proposals per-batch,
    # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
    # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
    POST_NMS_TOPK_TRAIN: 1000
    POST_NMS_TOPK_TEST: 1000
  ROI_HEADS:
    NAME: "StandardROIHeads"
    IN_FEATURES: ["p2", "p3", "p4", "p5"]
    NUM_CLASSES: 13
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
INPUT:
  FORMAT: "RGB"
SOLVER:
  WEIGHT_DECAY: 0.05
  WEIGHT_DECAY_RATE: 0.95
  OPTIMIZER: "ADAMW"
  AMP:
    ENABLED: True
DATASETS:
  TRAIN: ("coco_2017_train",)
  TEST: ("coco_2017_val",)