Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SSDlite architecture with MobileNetV3 backbones #3757

Merged
merged 27 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6e87247
Partial implementation of SSDlite.
datumbox Apr 30, 2021
1f1381d
Add normal init and BN hyperparams.
datumbox May 1, 2021
41e107b
Refactor to keep JIT happy
datumbox May 1, 2021
cd25aef
Completed SSDlite.
datumbox May 1, 2021
e3680ad
Fix lint
datumbox May 1, 2021
1fafb8f
Update todos
datumbox May 1, 2021
415058b
Add expected file in repo.
datumbox May 1, 2021
d17eb6c
Use C4 expansion instead of C4 output.
datumbox May 1, 2021
f318332
Change scales formula for Default Boxes.
datumbox May 1, 2021
f4b907d
Add cosine annealing on trainer.
datumbox May 1, 2021
6d40406
Make T_max count epochs.
datumbox May 1, 2021
34c2769
Fix test and handle corner-case.
datumbox May 1, 2021
ea46bfc
Add support of support width_mult
datumbox May 2, 2021
0dca06c
Add ssdlite presets.
datumbox May 2, 2021
7cce538
Merge branch 'master' into models/ssdlite
datumbox May 3, 2021
8b9ca53
Change ReLU6, [-1,1] rescaling, backbone init & no pretraining.
datumbox May 4, 2021
f8cbe46
Merge branch 'master' into models/ssdlite
datumbox May 4, 2021
8aa3f58
Use _reduced_tail=True.
datumbox May 4, 2021
da81b69
Add sync BN support.
datumbox May 5, 2021
5fbc112
Merge branch 'master' into models/ssdlite
datumbox May 6, 2021
d4024cb
Merge branch 'master' into models/ssdlite
datumbox May 7, 2021
4ca472e
Adding the best config along with its weights and documentation.
datumbox May 10, 2021
d8d55b7
Merge branch 'master' into models/ssdlite
datumbox May 10, 2021
bad974a
Merge branch 'master' into models/ssdlite
datumbox May 11, 2021
4020705
Make mean/std configurable.
datumbox May 11, 2021
c0bfc51
Fix not implemented for half exception
datumbox May 11, 2021
d05ef49
Merge branch 'master' into models/ssdlite
datumbox May 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ Faster R-CNN MobileNetV3-Large FPN 32.8 - -
Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - -
RetinaNet ResNet-50 FPN 36.4 - -
SSD VGG16 25.1 - -
SSDlite MobileNetV3-Large 21.3 - -
Mask R-CNN ResNet-50 FPN 37.9 34.6 -
====================================== ======= ======== ===========

Expand Down Expand Up @@ -486,6 +487,7 @@ Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415
Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
SSD VGG16 0.2093 0.0744 1.5
SSDlite MobileNetV3-Large 0.1773 0.0906 1.5
Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
====================================== =================== ================== ===========
Expand All @@ -511,6 +513,12 @@ SSD
.. autofunction:: torchvision.models.detection.ssd300_vgg16


SSDlite
------------

.. autofunction:: torchvision.models.detection.ssdlite320_mobilenet_v3_large


Mask R-CNN
----------

Expand Down
8 changes: 8 additions & 0 deletions references/detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--weight-decay 0.0005 --data-augmentation ssd
```

### SSDlite MobileNetV3-Large
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
--aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\
--weight-decay 0.00004 --data-augmentation ssdlite
```


### Mask R-CNN
```
Expand Down
6 changes: 6 additions & 0 deletions references/detection/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)):
T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(),
])
elif data_augmentation == 'ssdlite':
self.transforms = T.Compose([
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(),
])
else:
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')

Expand Down
28 changes: 23 additions & 5 deletions references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,13 @@ def get_args_parser(add_help=True):
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int, help='decrease lr every step-size epochs')
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
parser.add_argument('--lr-scheduler', default="multisteplr", help='the lr scheduler (default: multisteplr)')
parser.add_argument('--lr-step-size', default=8, type=int,
help='decrease lr every step-size epochs (multisteplr scheduler only)')
parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int,
help='decrease lr every step-size epochs (multisteplr scheduler only)')
parser.add_argument('--lr-gamma', default=0.1, type=float,
help='decrease lr by a factor of lr-gamma (multisteplr scheduler only)')
parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
Expand All @@ -85,6 +89,12 @@ def get_args_parser(add_help=True):
parser.add_argument('--trainable-backbone-layers', default=None, type=int,
help='number of trainable layers of backbone')
parser.add_argument('--data-augmentation', default="hflip", help='data augmentation policy (default: hflip)')
parser.add_argument(
"--sync-bn",
dest="sync_bn",
help="Use sync batch norm",
action="store_true",
)
parser.add_argument(
"--test-only",
dest="test_only",
Expand Down Expand Up @@ -156,6 +166,8 @@ def main(args):
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
**kwargs)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

model_without_ddp = model
if args.distributed:
Expand All @@ -166,8 +178,14 @@ def main(args):
optimizer = torch.optim.SGD(
params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
args.lr_scheduler = args.lr_scheduler.lower()
if args.lr_scheduler == 'multisteplr':
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
elif args.lr_scheduler == 'cosineannealinglr':
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
else:
raise RuntimeError("Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR "
"are supported.".format(args.lr_scheduler))

if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
Expand Down
Binary file not shown.
1 change: 1 addition & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def get_available_video_models():
"keypointrcnn_resnet50_fpn": lambda x: x[1],
"retinanet_resnet50_fpn": lambda x: x[1],
"ssd300_vgg16": lambda x: x[1],
"ssdlite320_mobilenet_v3_large": lambda x: x[1],
}


Expand Down
1 change: 1 addition & 0 deletions torchvision/models/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .keypoint_rcnn import *
from .retinanet import *
from .ssd import *
from .ssdlite import *
4 changes: 2 additions & 2 deletions torchvision/models/detection/anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int]
else:
y_f_k, x_f_k = f_k

shifts_x = (torch.arange(0, f_k[1], dtype=dtype) + 0.5) / x_f_k
shifts_y = (torch.arange(0, f_k[0], dtype=dtype) + 0.5) / y_f_k
shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
Expand Down
228 changes: 228 additions & 0 deletions torchvision/models/detection/ssdlite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
import torch

from collections import OrderedDict
from functools import partial
from torch import nn, Tensor
from typing import Any, Callable, Dict, List, Optional, Tuple

from . import _utils as det_utils
from .ssd import SSD, SSDScoringHead
from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers
from .. import mobilenet
from ..mobilenetv3 import ConvBNActivation
from ..utils import load_state_dict_from_url


__all__ = ['ssdlite320_mobilenet_v3_large']

model_urls = {
'ssdlite320_mobilenet_v3_large_coco':
'https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth'
}


def _prediction_block(in_channels: int, out_channels: int, kernel_size: int,
norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
return nn.Sequential(
# 3x3 depthwise with stride 1 and padding 1
ConvBNActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
norm_layer=norm_layer, activation_layer=nn.ReLU6),

# 1x1 projetion to output channels
nn.Conv2d(in_channels, out_channels, 1)
)


def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
activation = nn.ReLU6
intermediate_channels = out_channels // 2
return nn.Sequential(
# 1x1 projection to half output channels
ConvBNActivation(in_channels, intermediate_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation),

# 3x3 depthwise with stride 2 and padding 1
ConvBNActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),

# 1x1 projetion to output channels
ConvBNActivation(intermediate_channels, out_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation),
)


def _normal_init(conv: nn.Module):
for layer in conv.modules():
if isinstance(layer, nn.Conv2d):
torch.nn.init.normal_(layer.weight, mean=0.0, std=0.03)
if layer.bias is not None:
torch.nn.init.constant_(layer.bias, 0.0)


class SSDLiteHead(nn.Module):
def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int,
norm_layer: Callable[..., nn.Module]):
super().__init__()
self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer)

def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
return {
'bbox_regression': self.regression_head(x),
'cls_logits': self.classification_head(x),
}


class SSDLiteClassificationHead(SSDScoringHead):
def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int,
norm_layer: Callable[..., nn.Module]):
cls_logits = nn.ModuleList()
for channels, anchors in zip(in_channels, num_anchors):
cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer))
_normal_init(cls_logits)
super().__init__(cls_logits, num_classes)


class SSDLiteRegressionHead(SSDScoringHead):
def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: Callable[..., nn.Module]):
bbox_reg = nn.ModuleList()
for channels, anchors in zip(in_channels, num_anchors):
bbox_reg.append(_prediction_block(channels, 4 * anchors, 3, norm_layer))
_normal_init(bbox_reg)
super().__init__(bbox_reg, 4)


class SSDLiteFeatureExtractorMobileNet(nn.Module):
def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], rescaling: bool,
**kwargs: Any):
super().__init__()
# non-public config parameters
min_depth = kwargs.pop('_min_depth', 16)
width_mult = kwargs.pop('_width_mult', 1.0)

assert not backbone[c4_pos].use_res_connect
self.features = nn.Sequential(
nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]), # from start until C4 expansion layer
nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1:]), # from C4 depthwise until end
)

get_depth = lambda d: max(min_depth, int(d * width_mult)) # noqa: E731
extra = nn.ModuleList([
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@datumbox could you please help me figure it out - I cannot find the info about these extra layers in the papers. Where did you get them from?
I'm trying to create a modification for this model and struggle to understand it - any help would be appreciated!
I want to reduce the number of encoder layers to make feature maps detect small objects.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@evekeen I have written a blogpost about the implementation details of this model. See here. The extra layers are described on section 6.3 of the paper though to get their exact values you need to dig in the original TF code. Hope that helps!

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@datumbox Thank you for the quick reply! It's very helpful

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@datumbox in 6.3 of MobileNet3 paper, I only see the info on connecting C4 and C5 layers to the SSD head. There is nothing on these extra layers there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you checked the reference code I sent? This comes from their official repo.

Copy link

@evekeen evekeen Nov 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I see that in the TensorFlow implementation.
I'm trying to understand if I'm reducing the depth of C4 (and thus the output stride for targeting super small objects) - how should I change the rest of the layers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, it's been quite sometime since I wrote the implementation. I think you will need to dig into the original research repo to get the details.

_extra_block(backbone[-1].out_channels, get_depth(512), norm_layer),
_extra_block(get_depth(512), get_depth(256), norm_layer),
_extra_block(get_depth(256), get_depth(256), norm_layer),
_extra_block(get_depth(256), get_depth(128), norm_layer),
])
_normal_init(extra)

self.extra = extra
self.rescaling = rescaling

def forward(self, x: Tensor) -> Dict[str, Tensor]:
# Rescale from [0, 1] to [-1, -1]
if self.rescaling:
x = 2.0 * x - 1.0

# Get feature maps from backbone and extra. Can't be refactored due to JIT limitations.
output = []
for block in self.features:
x = block(x)
output.append(x)

for block in self.extra:
x = block(x)
output.append(x)

return OrderedDict([(str(i), v) for i, v in enumerate(output)])


def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int,
norm_layer: Callable[..., nn.Module], rescaling: bool, **kwargs: Any):
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress,
norm_layer=norm_layer, **kwargs).features
if not pretrained:
# Change the default initialization scheme if not pretrained
_normal_init(backbone)

# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
num_stages = len(stage_indices)

# find the index of the layer from which we wont freeze
assert 0 <= trainable_layers <= num_stages
freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

for b in backbone[:freeze_before]:
for parameter in b.parameters():
parameter.requires_grad_(False)

return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, rescaling, **kwargs)


def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
pretrained_backbone: bool = False, trainable_backbone_layers: Optional[int] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any):
"""
Constructs an SSDlite model with a MobileNetV3 Large backbone. See `SSD` for more details.

Example:

>>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)]
>>> predictions = model(x)

Args:
norm_layer:
**kwargs:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
norm_layer (callable, optional): Module specifying the normalization layer to use.
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6)

if pretrained:
pretrained_backbone = False

# Enable [-1, 1] rescaling and reduced tail if no pretrained backbone is selected
rescaling = reduce_tail = not pretrained_backbone
Comment on lines +195 to +196
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit confusing, but I assume the [-1, 1] rescaling is necessary to get best results given the current settings?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is correct. Rescaling was part of the changes needed to boost the accuracy by 1mAP.


if norm_layer is None:
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers,
norm_layer, rescaling, _reduced_tail=reduce_tail, _width_mult=1.0)

size = (320, 320)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means that the size is hard-coded and even if the user passes a different size **kwargs in the constructor it won't be used?

What about doing something like

size = kwargs.get("size", (320, 320))

instead, so that the users can potentially customize the input size if they wish?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I chose to hardcode it because this is the ssdlite320 model which uses a fixed 320x320 size. The input size is much less flexible on SSD models comparing to FasterRCNN because they make a few strong assumptions about the input.

If someone wants to use a different size, it would be simpler to just create the backbone, configure the DefaultBoxGenerator and then initialize directly the SSD with the config of their choice. Overall I felt that this approach would be simpler than trying to offer an API that tries to cover all user needs.

anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
out_channels = det_utils.retrieve_out_channels(backbone, size)
num_anchors = anchor_generator.num_anchors_per_location()
assert len(out_channels) == len(anchor_generator.aspect_ratios)

defaults = {
"score_thresh": 0.001,
"nms_thresh": 0.55,
"detections_per_img": 300,
"topk_candidates": 300,
"image_mean": [0., 0., 0.],
"image_std": [1., 1., 1.],
}
kwargs = {**defaults, **kwargs}
model = SSD(backbone, anchor_generator, size, num_classes,
head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), **kwargs)

if pretrained:
weights_name = 'ssdlite320_mobilenet_v3_large_coco'
if model_urls.get(weights_name, None) is None:
raise ValueError("No checkpoint is available for model {}".format(weights_name))
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
model.load_state_dict(state_dict)
return model