Skip to content

Commit

Permalink
added detectron2 detection
Browse files Browse the repository at this point in the history
  • Loading branch information
iKrishneel committed May 7, 2021
1 parent 574200b commit 8f557bc
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 17 deletions.
50 changes: 50 additions & 0 deletions efficient_net_v2/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/env python

from detectron2.config import CfgNode, get_cfg
from detectron2.engine import (
DefaultTrainer,
default_argument_parser,
default_setup,
launch
)

from efficient_net_v2.model.backbone import build_effnet_backbone


def setup(args) -> CfgNode:
    """Build the frozen config used for this run.

    Starts from ``get_cfg()``, merges the YAML file named by
    ``args.config_file`` and then applies the optional command-line
    overrides before freezing.

    Args:
        args: Parsed command-line namespace. ``config_file`` must be present;
            ``output_dir`` and ``weights`` are optional overrides.

    Returns:
        The frozen config, after ``default_setup(cfg, args)`` has been called.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)

    # Apply each override independently and only when a value was actually
    # supplied. Previously a missing ``output_dir`` silently skipped the
    # weights override as well, and an unset ``--weights`` (None) replaced
    # whatever MODEL.WEIGHTS the config file had specified.
    output_dir = getattr(args, 'output_dir', None)
    if output_dir is not None:
        cfg.OUTPUT_DIR = output_dir

    weights = getattr(args, 'weights', None)
    if weights is not None:
        cfg.MODEL.WEIGHTS = weights

    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    """Create a ``DefaultTrainer`` from the run config and train with it.

    Args:
        args: Parsed command-line namespace; passed to ``setup`` and read
            for the ``resume`` flag.

    Returns:
        Whatever ``DefaultTrainer.train()`` returns.
    """
    trainer = DefaultTrainer(setup(args))
    # Either resume from the last checkpoint or load the configured weights,
    # depending on ``args.resume``.
    trainer.resume_or_load(resume=args.resume)
    outcome = trainer.train()
    return outcome


if __name__ == '__main__':
    # Extend detectron2's stock parser with the options this script needs.
    parser = default_argument_parser()
    parser.add_argument('--output_dir', required=True, type=str)
    parser.add_argument('--num_gpus', required=False, type=int, default=1)
    parser.add_argument('--weights', required=False, type=str, default=None)
    args = parser.parse_args()
    print(args)

    # Developer toggle: set to False to call ``main`` directly instead of
    # going through detectron2's launcher.
    train = True
    if train:
        # Forward the distributed options as well, so that multi-process
        # runs get a rendezvous URL and the machine layout the user asked
        # for. These attributes are expected to come from
        # ``default_argument_parser`` (TODO confirm for the installed
        # detectron2 version); the fallbacks mirror calling ``launch``
        # without them, as the script did before.
        launch(
            main,
            args.num_gpus,
            num_machines=getattr(args, 'num_machines', 1),
            machine_rank=getattr(args, 'machine_rank', 0),
            dist_url=getattr(args, 'dist_url', None),
            args=(args,),
        )
    else:
        main(args)

82 changes: 82 additions & 0 deletions efficient_net_v2/config/detectron2_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env python

from yacs.config import CfgNode as CN


# Default backbone description: one config node per stage (S0 .. S6), each
# built from a plain mapping of its settings.
_C = CN()
_C.BACKBONE = CN()

# stage 0
_C.BACKBONE.S0 = CN(dict(
    OPS='conv',
    KERNEL=3,
    STRIDE=2,
    CHANNELS=24,
    LAYERS=1,
    PADDING=1,
))

# stages 1-3 use 'fused_mbconv'; no SE entry is set for them.
_C.BACKBONE.S1 = CN(dict(
    OPS='fused_mbconv',
    KERNEL=3,
    STRIDE=1,
    EXPANSION=1,
    CHANNELS=24,
    LAYERS=2,
))

_C.BACKBONE.S2 = CN(dict(
    OPS='fused_mbconv',
    KERNEL=3,
    STRIDE=2,
    EXPANSION=4,
    CHANNELS=48,
    LAYERS=4,
))

_C.BACKBONE.S3 = CN(dict(
    OPS='fused_mbconv',
    KERNEL=3,
    STRIDE=2,
    EXPANSION=4,
    CHANNELS=64,
    LAYERS=4,
))

# stages 4-6 use 'mbconv' and additionally carry an SE setting.
_C.BACKBONE.S4 = CN(dict(
    OPS='mbconv',
    KERNEL=3,
    STRIDE=2,
    EXPANSION=4,
    SE=4,
    CHANNELS=128,
    LAYERS=6,
))

_C.BACKBONE.S5 = CN(dict(
    OPS='mbconv',
    KERNEL=3,
    STRIDE=1,
    EXPANSION=6,
    SE=4,
    CHANNELS=160,
    LAYERS=9,
))

_C.BACKBONE.S6 = CN(dict(
    OPS='mbconv',
    KERNEL=3,
    STRIDE=2,
    EXPANSION=6,
    SE=4,
    CHANNELS=272,
    LAYERS=15,
))


def get_cfg():
    """Return a clone of the module-level default config ``_C``."""
    defaults = _C.clone()
    return defaults
41 changes: 41 additions & 0 deletions efficient_net_v2/config/effnet_coco.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# detectron2 config: GeneralizedRCNN on COCO 2017 using the
# "build_effnet_backbone" backbone.
VERSION: 2
# NOTE(review): INPUT appears twice in this file (here and again after
# SOLVER). If both are top-level keys, a YAML loader will typically keep only
# one of them, which could drop MASK_FORMAT -- confirm and merge into a
# single INPUT section.
INPUT:
MASK_FORMAT: "bitmask"
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
BACKBONE:
NAME: "build_effnet_backbone"
WEIGHTS: ""
MASK_ON: True
LOAD_PROPOSALS: False
RETINANET:
# Anchor: the class count is defined once here and reused via *num_classes.
NUM_CLASSES: &num_classes 80
ROI_HEADS:
# "stage7" is the feature name the backbone's forward() returns.
IN_FEATURES: ["stage7", ]
NUM_CLASSES: *num_classes
NAME: "StandardROIHeads"
ROI_BOX_HEAD:
NAME: "FastRCNNConvFCHead"
NUM_FC: 1
NUM_CONV: 2
NORM: "BN"
RPN:
IN_FEATURES: ["stage7",]
PRE_NMS_TOPK_TRAIN: 6000
POST_NMS_TOPK_TRAIN: 1000
PROPOSAL_GENERATOR:
NAME: "RPN"
SEM_SEG_HEAD:
NUM_CLASSES: *num_classes
SOLVER:
IMS_PER_BATCH: 2
BASE_LR: 0.1
# NOTE(review): both STEPS values are larger than MAX_ITER, so presumably no
# learning-rate step is ever reached within a run -- confirm this is intended.
STEPS: (60000, 800000)
MAX_ITER: 20000
# NOTE(review): second INPUT section -- see the note on the first one.
INPUT:
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
DATALOADER:
NUM_WORKERS: 2
DATASETS:
TRAIN: ("coco_2017_train", )
TEST: ("coco_2017_val", )
7 changes: 4 additions & 3 deletions efficient_net_v2/layers/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ class ConvBNA(nn.Module):

def __init__(
self, in_channels: int, out_channels: int,
activation=nn.ReLU(inplace=True), use_bn: bool = True,
**kwargs: dict
use_bn: bool = True, **kwargs: dict
):
super(ConvBNA, self).__init__()

momentum = kwargs.pop('momentum', 0.1)
eps = kwargs.pop('eps', 1e-5)
self.activation = kwargs.pop('activation', nn.ReLU(inplace=True))
self.stride = kwargs.get('stride', 1)
self.out_channels = out_channels

self.conv = nn.Conv2d(
in_channels=in_channels,
Expand All @@ -25,7 +27,6 @@ def __init__(
momentum=momentum,
eps=eps
) if use_bn else None
self.activation = activation

def forward(self, inp):
x = self.conv(inp)
Expand Down
12 changes: 6 additions & 6 deletions efficient_net_v2/layers/mbconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ def __init__(
):
super(MBConv, self).__init__()

out_channels = max(in_channels, out_channels)
self.out_channels = max(in_channels, out_channels)
hidden_channels = in_channels * max(expansion, 1)
reduction = kwargs.get('reduction', 4)
knxn = kwargs.get('knxn', 3)
stride = kwargs.get('stride', 1)
bias = kwargs.get('bias', False)
self.stride = kwargs.get('stride', 1)

self.conv1 = ConvBNA(
in_channels=in_channels,
Expand All @@ -39,7 +39,7 @@ def __init__(
groups=hidden_channels,
kernel_size=knxn,
padding=1,
stride=stride,
stride=self.stride,
bias=bias
)

Expand Down Expand Up @@ -76,19 +76,19 @@ def __init__(
):
super(FusedMBConv, self).__init__()

out_channels = max(in_channels, out_channels)
self.out_channels = max(in_channels, out_channels)
hidden_channels = in_channels * max(expansion, 1)
reduction = kwargs.get('reduction', 4)
knxn = kwargs.get('knxn', 3)
stride = kwargs.get('stride', 1)
bias = kwargs.get('bias', False)
self.stride = kwargs.get('stride', 1)

self.conv1 = ConvBNA(
in_channels=in_channels,
out_channels=hidden_channels,
kernel_size=knxn,
padding=1,
stride=stride,
stride=self.stride,
bias=bias
)
self.se = SE(
Expand Down
1 change: 1 addition & 0 deletions efficient_net_v2/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python

from .efficient_net_v2 import EfficientNetV2 # NOQA
from .backbone import build_effnet_backbone # NOQA
50 changes: 50 additions & 0 deletions efficient_net_v2/model/backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/env python

from typing import List

from detectron2.modeling import (
BACKBONE_REGISTRY, Backbone, ShapeSpec
)

from .efficient_net_v2 import EfficientNetV2
from ..layers import ConvBNA, FusedMBConv, MBConv
from ..config.detectron2_config import get_cfg


class EfficientNet(EfficientNetV2, Backbone):
    """EfficientNetV2 wrapped as a detectron2 ``Backbone``.

    The network yields a single feature map. It is reported by
    ``output_shape`` and returned by ``forward`` under every name listed in
    ``out_features`` (``['stage7']`` by default), so the two always agree.
    """

    def __init__(self, cfg, out_features: List[str] = None):
        """Build the network and derive the output stride and channel count.

        Args:
            cfg: Network config forwarded to ``EfficientNetV2.__init__``.
            out_features: Names under which the feature map is published;
                defaults to ``['stage7']``.
        """
        super(EfficientNet, self).__init__(cfg)

        self.out_features = (
            ['stage7'] if out_features is None else out_features
        )

        # Total stride is the product of the strides of the recognised
        # layers in ``self.backbone``; the channel count is that of the last
        # such layer. Children of any other type are ignored.
        # NOTE(review): only ``self.backbone`` is inspected. If the config
        # also defines a head, the values below may not describe the tensor
        # actually returned by ``forward`` -- confirm.
        stride, channels = 1, 0
        for child in self.backbone.children():
            if isinstance(child, (ConvBNA, FusedMBConv, MBConv)):
                stride *= child.stride
                channels = child.out_channels
        self._stride, self._channels = stride, channels

        assert self._channels > 0, (
            'backbone contains no ConvBNA/FusedMBConv/MBConv layer'
        )

    def output_shape(self):
        """Return a ``ShapeSpec`` (channels, stride) per output feature name."""
        return {
            name: ShapeSpec(
                channels=self._channels, stride=self._stride
            )
            for name in self.out_features
        }

    def forward(self, x):
        """Run the network and return ``{feature_name: feature_map}``.

        The keys are taken from ``self.out_features`` rather than being
        hard-coded, so they match the keys of ``output_shape``.
        """
        # Zero-argument super() cannot be used inside the comprehension,
        # so the parent forward is evaluated first.
        features = super().forward(x)
        return {name: features for name in self.out_features}


@BACKBONE_REGISTRY.register()
def build_effnet_backbone(cfg=None, input_shape=None):
    """Registry entry point that builds the ``EfficientNet`` backbone.

    Args:
        cfg: Accepted for the registry call signature; not used here.
        input_shape: Accepted for the registry call signature; not used here.

    Returns:
        An ``EfficientNet`` built from this package's own default config
        (``get_cfg()``), with the default output feature name.
    """
    backbone_cfg = get_cfg()
    return EfficientNet(backbone_cfg)
22 changes: 14 additions & 8 deletions efficient_net_v2/model/efficient_net_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,25 @@

class EfficientNetV2(nn.Module):

def __init__(self, cfg: CN):
def __init__(self, cfg: CN, in_channels: int = 3):
super(EfficientNetV2, self).__init__()

input_shape = cfg.get('INPUTS').get('SHAPE')
# input_shape = cfg.get('INPUTS').get('SHAPE')
backbone = cfg['BACKBONE']
head = cfg['HEAD']
assert len(input_shape) == 3
# assert len(input_shape) == 3
# in_channels = input_shape[0]

in_channels = input_shape[0]
layers, in_channels = self.build(backbone, in_channels)
self.backbone = nn.Sequential(*layers)

layers, _ = self.build(head, in_channels)
self.head = nn.Sequential(*layers)
try:
head = cfg['HEAD']
layers, in_channels = self.build(head, in_channels)
self.head = nn.Sequential(*layers)
except KeyError:
self.head = None

self.out_channels = in_channels

def build(self, nodes, in_channels):
layers = []
Expand Down Expand Up @@ -75,5 +80,6 @@ def create_layer(

def forward(self, x):
x = self.backbone(x)
x = self.head(x)
if self.head is not None:
x = self.head(x)
return x
4 changes: 4 additions & 0 deletions efficient_net_v2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,10 @@ def validation(self):
self.model.eval()
with torch.no_grad():
pass

def load_state_dict(self, path: str, strict: bool = True):
    """Load weights from a checkpoint file into ``self.model``.

    Args:
        path: Location of a file written with ``torch.save`` that holds a
            state dict.
        strict: Forwarded to ``self.model.load_state_dict``. It was
            previously referenced without being defined, which raised
            ``NameError`` on every call.
    """
    # NOTE(review): ``torch.load`` unpickles the file; only use it on
    # checkpoints from a trusted source.
    state_dict = torch.load(path, map_location=self.device)
    self.model.load_state_dict(state_dict, strict=strict)


def main(args):
Expand Down

0 comments on commit 8f557bc

Please sign in to comment.