Add SSDlite architecture with MobileNetV3 backbones #3757

torchvision/models/detection/__init__.py
@@ -3,3 +3,4 @@
 from .keypoint_rcnn import *
 from .retinanet import *
 from .ssd import *
+from .ssdlite import *

torchvision/models/detection/ssdlite.py (new file)
@@ -0,0 +1,228 @@
import torch

from collections import OrderedDict
from functools import partial
from torch import nn, Tensor
from typing import Any, Callable, Dict, List, Optional, Tuple

from . import _utils as det_utils
from .ssd import SSD, SSDScoringHead
from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers
from .. import mobilenet
from ..mobilenetv3 import ConvBNActivation
from ..utils import load_state_dict_from_url


__all__ = ['ssdlite320_mobilenet_v3_large']

model_urls = {
    'ssdlite320_mobilenet_v3_large_coco':
        'https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth'
}


def _prediction_block(in_channels: int, out_channels: int, kernel_size: int,
                      norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
    return nn.Sequential(
        # 3x3 depthwise with stride 1 and padding 1
        ConvBNActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
                         norm_layer=norm_layer, activation_layer=nn.ReLU6),

        # 1x1 projection to output channels
        nn.Conv2d(in_channels, out_channels, 1)
    )

def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
    activation = nn.ReLU6
    intermediate_channels = out_channels // 2
    return nn.Sequential(
        # 1x1 projection to half output channels
        ConvBNActivation(in_channels, intermediate_channels, kernel_size=1,
                         norm_layer=norm_layer, activation_layer=activation),

        # 3x3 depthwise with stride 2 and padding 1
        ConvBNActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
                         groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),

        # 1x1 projection to output channels
        ConvBNActivation(intermediate_channels, out_channels, kernel_size=1,
                         norm_layer=norm_layer, activation_layer=activation),
    )


def _normal_init(conv: nn.Module):
    for layer in conv.modules():
        if isinstance(layer, nn.Conv2d):
            torch.nn.init.normal_(layer.weight, mean=0.0, std=0.03)
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, 0.0)

class SSDLiteHead(nn.Module):
    def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int,
                 norm_layer: Callable[..., nn.Module]):
        super().__init__()
        self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
        self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer)

    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
        return {
            'bbox_regression': self.regression_head(x),
            'cls_logits': self.classification_head(x),
        }


class SSDLiteClassificationHead(SSDScoringHead):
    def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int,
                 norm_layer: Callable[..., nn.Module]):
        cls_logits = nn.ModuleList()
        for channels, anchors in zip(in_channels, num_anchors):
            cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer))
        _normal_init(cls_logits)
        super().__init__(cls_logits, num_classes)


class SSDLiteRegressionHead(SSDScoringHead):
    def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: Callable[..., nn.Module]):
        bbox_reg = nn.ModuleList()
        for channels, anchors in zip(in_channels, num_anchors):
            bbox_reg.append(_prediction_block(channels, 4 * anchors, 3, norm_layer))
        _normal_init(bbox_reg)
        super().__init__(bbox_reg, 4)

class SSDLiteFeatureExtractorMobileNet(nn.Module):
    def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], rescaling: bool,
                 **kwargs: Any):
        super().__init__()
        # non-public config parameters
        min_depth = kwargs.pop('_min_depth', 16)
        width_mult = kwargs.pop('_width_mult', 1.0)

        assert not backbone[c4_pos].use_res_connect
        self.features = nn.Sequential(
            nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]),  # from start until C4 expansion layer
            nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1:]),  # from C4 depthwise until end
        )

        get_depth = lambda d: max(min_depth, int(d * width_mult))  # noqa: E731
        extra = nn.ModuleList([
            _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer),
            _extra_block(get_depth(512), get_depth(256), norm_layer),
            _extra_block(get_depth(256), get_depth(256), norm_layer),
            _extra_block(get_depth(256), get_depth(128), norm_layer),
        ])
        _normal_init(extra)

        self.extra = extra
        self.rescaling = rescaling

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        # Rescale from [0, 1] to [-1, 1]
        if self.rescaling:
            x = 2.0 * x - 1.0

        # Get feature maps from backbone and extra. Can't be refactored due to JIT limitations.
        output = []
        for block in self.features:
            x = block(x)
            output.append(x)

        for block in self.extra:
            x = block(x)
            output.append(x)

        return OrderedDict([(str(i), v) for i, v in enumerate(output)])

def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int,
                         norm_layer: Callable[..., nn.Module], rescaling: bool, **kwargs: Any):
    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress,
                                                 norm_layer=norm_layer, **kwargs).features
    if not pretrained:
        # Change the default initialization scheme if not pretrained
        _normal_init(backbone)

    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    num_stages = len(stage_indices)

    # find the index of the layer from which we won't freeze
    assert 0 <= trainable_layers <= num_stages
    freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

    for b in backbone[:freeze_before]:
        for parameter in b.parameters():
            parameter.requires_grad_(False)

    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, rescaling, **kwargs)

def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
                                  pretrained_backbone: bool = False, trainable_backbone_layers: Optional[int] = None,
                                  norm_layer: Optional[Callable[..., nn.Module]] = None,
                                  **kwargs: Any):
    """
    Constructs an SSDlite model with a MobileNetV3 Large backbone. See `SSD` for more details.

    Example:

        >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on ImageNet
        trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from final block.
            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
        norm_layer (callable, optional): Module specifying the normalization layer to use.
    """
    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6)

    if pretrained:
        pretrained_backbone = False

    # Enable [-1, 1] rescaling and reduced tail if no pretrained backbone is selected
    rescaling = reduce_tail = not pretrained_backbone

Review thread on lines +195 to +196:

Reviewer: This is a bit confusing, but I assume the rescaling is only wanted when no pretrained backbone is used?

datumbox: That is correct. Rescaling was part of the changes needed to boost the accuracy by 1 mAP.
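For context, rescaling from [0, 1] to [-1, 1] is equivalent to channel-wise normalization with mean 0.5 and std 0.5, which is consistent with the identity image_mean/image_std defaults set further down. A minimal check (not part of the diff):

import torch

x = torch.rand(1, 3, 320, 320)   # batch of images scaled to [0, 1]
rescaled = 2.0 * x - 1.0         # what SSDLiteFeatureExtractorMobileNet.forward applies
normalized = (x - 0.5) / 0.5     # equivalent mean/std normalization
assert torch.allclose(rescaled, normalized)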

    if norm_layer is None:
        norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

    backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers,
                                    norm_layer, rescaling, _reduced_tail=reduce_tail, _width_mult=1.0)

    size = (320, 320)

Review thread on the hard-coded size:

Reviewer: This means that the size is hard-coded, and even if the user passes a different size it will be ignored. What about doing something like size = kwargs.get("size", (320, 320)) instead, so that users can potentially customize the input size if they wish?

datumbox: I chose to hardcode it because this is the 320x320 variant of the model. If someone wants to use a different size, it would be simpler to just create the backbone, configure the DefaultBoxGenerator and then initialize the SSD class directly.
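A rough sketch of the alternative the author describes, assuming the private helpers of this file (_mobilenet_extractor, SSDLiteHead) are imported directly; the 512x512 size is a hypothetical choice, not a value from the PR:

from functools import partial

from torch import nn
from torchvision.models.detection import _utils as det_utils
from torchvision.models.detection.anchor_utils import DefaultBoxGenerator
from torchvision.models.detection.ssd import SSD
from torchvision.models.detection.ssdlite import SSDLiteHead, _mobilenet_extractor

size = (512, 512)  # hypothetical custom input size
num_classes = 91
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

# Build the feature extractor the same way the builder below does
# (progress=True, pretrained=False, trainable_layers=6, rescaling=True).
backbone = _mobilenet_extractor("mobilenet_v3_large", True, False, 6,
                                norm_layer, True, _reduced_tail=True, _width_mult=1.0)

# Configure the default boxes for the custom size and wire everything into SSD.
anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
out_channels = det_utils.retrieve_out_channels(backbone, size)
num_anchors = anchor_generator.num_anchors_per_location()
model = SSD(backbone, anchor_generator, size, num_classes,
            head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer))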

    anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
    out_channels = det_utils.retrieve_out_channels(backbone, size)
    num_anchors = anchor_generator.num_anchors_per_location()
    assert len(out_channels) == len(anchor_generator.aspect_ratios)

    defaults = {
        "score_thresh": 0.001,
        "nms_thresh": 0.55,
        "detections_per_img": 300,
        "topk_candidates": 300,
        "image_mean": [0., 0., 0.],
        "image_std": [1., 1., 1.],
    }
    kwargs = {**defaults, **kwargs}
    model = SSD(backbone, anchor_generator, size, num_classes,
                head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), **kwargs)

    if pretrained:
        weights_name = 'ssdlite320_mobilenet_v3_large_coco'
        if model_urls.get(weights_name, None) is None:
            raise ValueError("No checkpoint is available for model {}".format(weights_name))
        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
        model.load_state_dict(state_dict)
    return model

evekeen: @datumbox could you please help me figure something out - I cannot find the info about these extra layers in the papers. Where did you get them from? I'm trying to create a modification of this model and struggle to understand it - any help would be appreciated! I want to reduce the number of encoder layers so that the feature maps can detect small objects.

datumbox: @evekeen I have written a blogpost about the implementation details of this model. See here. The extra layers are described in section 6.3 of the paper, though to get their exact values you need to dig into the original TF code. Hope that helps!

evekeen: @datumbox Thank you for the quick reply! It's very helpful.

evekeen: @datumbox in section 6.3 of the MobileNetV3 paper, I only see the info on connecting the C4 and C5 layers to the SSD head. There is nothing on these extra layers there.

datumbox: Have you checked the reference code I sent? This comes from their official repo.

evekeen: Yes, I see that in the TensorFlow implementation. I'm trying to understand: if I reduce the depth of C4 (and thus the output stride, to target very small objects), how should I change the rest of the layers?

datumbox: Sorry, it's been quite some time since I wrote the implementation. I think you will need to dig into the original research repo to get the details.
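For anyone exploring such modifications, inspecting the six feature maps the extractor produces makes the output strides concrete; a sketch, assuming the extractor is built with the same arguments as in this PR:

from functools import partial

import torch
from torch import nn
from torchvision.models.detection.ssdlite import _mobilenet_extractor

norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
extractor = _mobilenet_extractor("mobilenet_v3_large", True, False, 6,
                                 norm_layer, True, _reduced_tail=True, _width_mult=1.0)
extractor.eval()  # BatchNorm cannot compute batch stats on the final 1x1 map in train mode

with torch.no_grad():
    features = extractor(torch.rand(1, 3, 320, 320))

for name, feature_map in features.items():
    print(name, tuple(feature_map.shape))

# For a 320x320 input the spatial sizes should be 20, 10, 5, 3, 2 and 1:
# C4 (stride 16), C5 (stride 32), then the four extra blocks, each halving the resolution.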