RetinaNet object detection. #1697
Closed
Commits (31, by hgaiser)
50f822c  Add rough implementation of RetinaNet.
022f8e1  Move AnchorGenerator to a separate file.
8e0804d  Move box similarity to Matcher.
ad53194  Expose extra blocks in FPN.
2a5a5be  Expose retinanet in __init__.py.
49e990c  Use P6 and P7 in FPN for retinanet.
b5966eb  Use parameters from retinanet for anchor generation.
aab1b28  General fixes for retinanet model.
c078114  Implement loss for retinanet heads.
eae4ee5  Output reshaped outputs from retinanet heads.
3dac477  Add postprocessing of detections.
9981a3c  Small fixes.
5571dfe  Remove unused argument.
fc7751b  Remove python2 invocation of super.
b942648  Add postprocessing for additional outputs.
b619936  Add missing import of ImageList.
8c86588  Remove redundant import.
2934f0d  Simplify class correction.
32b8e77  Fix pylint warnings.
437bfe9  Remove the label adjustment for background class.
9e810d6  Set default score threshold to 0.05.
f7d8c2e  Add weight initialization for regression layer.
d86c437  Allow training on images with no annotations.
72e46f2  Use smooth_l1_loss with beta value.
41c90fa  Add more typehints for TorchScript conversions.
b9daa86  Fix linting issues.
97d63b6  Fix type hints in postprocess_detections.
eba7e16  Fix type annotations for TorchScript.
9545059  Fix inconsistency with matched_idxs.
4865952  Add retinanet model test.
6e065be  Add missing JIT annotations.
@@ -1,3 +1,4 @@
 from .faster_rcnn import *
 from .mask_rcnn import *
 from .keypoint_rcnn import *
+from .retinanet import *
@@ -0,0 +1,159 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn

from torch.jit.annotations import List, Optional, Dict
from .image_list import ImageList


class AnchorGenerator(nn.Module):
    """
    Module that generates anchors for a set of feature maps and
    image sizes.

    The module supports computing anchors at multiple sizes and aspect ratios
    per feature map. This module assumes aspect ratio = height / width for
    each anchor.

    sizes and aspect_ratios should have the same number of elements, and it should
    correspond to the number of feature maps.

    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
    and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
    per spatial location for feature map i.

    Arguments:
        sizes (Tuple[Tuple[int]]):
        aspect_ratios (Tuple[Tuple[float]]):
    """

    __annotations__ = {
        "cell_anchors": Optional[List[torch.Tensor]],
        "_cache": Dict[str, List[torch.Tensor]]
    }

    def __init__(
        self,
        sizes=((128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    ):
        super(AnchorGenerator, self).__init__()

        if not isinstance(sizes[0], (list, tuple)):
            # TODO change this
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(sizes)

        assert len(sizes) == len(aspect_ratios)

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        self.cell_anchors = None
        self._cache = {}

    # TODO: https://github.com/pytorch/pytorch/issues/26792
    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
    # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
    # This method assumes aspect ratio = height / width for an anchor.
    def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
        # type: (List[int], List[float], int, Device) -> Tensor  # noqa: F821
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1 / h_ratios

        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def set_cell_anchors(self, dtype, device):
        # type: (int, Device) -> None  # noqa: F821
        if self.cell_anchors is not None:
            cell_anchors = self.cell_anchors
            assert cell_anchors is not None
            # suppose that all anchors have the same device
            # which is a valid assumption in the current state of the codebase
            if cell_anchors[0].device == device:
                return

        cell_anchors = [
            self.generate_anchors(
                sizes,
                aspect_ratios,
                dtype,
                device
            )
            for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
        ]
        self.cell_anchors = cell_anchors

    def num_anchors_per_location(self):
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
    def grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
        anchors = []
        cell_anchors = self.cell_anchors
        assert cell_anchors is not None
        assert len(grid_sizes) == len(strides) == len(cell_anchors)

        for size, stride, base_anchors in zip(
            grid_sizes, strides, cell_anchors
        ):
            grid_height, grid_width = size
            stride_height, stride_width = stride
            device = base_anchors.device

            # For output anchor, compute [x_center, y_center, x_center, y_center]
            shifts_x = torch.arange(
                0, grid_width, dtype=torch.float32, device=device
            ) * stride_width
            shifts_y = torch.arange(
                0, grid_height, dtype=torch.float32, device=device
            ) * stride_height
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # For every (base anchor, output anchor) pair,
            # offset each zero-centered base anchor by the center of the output anchor.
            anchors.append(
                (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
            )

        return anchors

    def cached_grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
        key = str(grid_sizes) + str(strides)
        if key in self._cache:
            return self._cache[key]
        anchors = self.grid_anchors(grid_sizes, strides)
        self._cache[key] = anchors
        return anchors

    def forward(self, image_list, feature_maps):
        # type: (ImageList, List[Tensor]) -> List[Tensor]
        grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
                    torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
        anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
        for i, (image_height, image_width) in enumerate(image_list.image_sizes):
            anchors_in_image = []
            for anchors_per_feature_map in anchors_over_all_feature_maps:
                anchors_in_image.append(anchors_per_feature_map)
            anchors.append(anchors_in_image)
        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
        # Clear the cache in case that memory leaks.
        self._cache.clear()
        return anchors
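The tensor math in generate_anchors and grid_anchors above can be sketched in plain Python. The helper names generate_base_anchors and place_on_grid below are illustrative, not part of the PR; they assume the same conventions as the module: aspect ratio = height / width, zero-centered [x1, y1, x2, y2] boxes rounded to integers, and anchors tiled row-major across the grid.

```python
import math

def generate_base_anchors(scales, aspect_ratios):
    """Zero-centered [x1, y1, x2, y2] anchors; aspect ratio = height / width."""
    anchors = []
    for ar in aspect_ratios:          # outer loop over ratios, inner over scales,
        h_ratio = math.sqrt(ar)      # matching w_ratios[:, None] * scales[None, :]
        w_ratio = 1.0 / h_ratio
        for s in scales:
            w, h = w_ratio * s, h_ratio * s
            # Python's round() rounds half to even, like torch.round
            anchors.append([round(-w / 2), round(-h / 2), round(w / 2), round(h / 2)])
    return anchors

def place_on_grid(base_anchors, grid_height, grid_width, stride):
    """Copy every base anchor to each grid cell, offset by the cell's stride."""
    out = []
    for y in range(grid_height):      # row-major, matching meshgrid(shifts_y, shifts_x)
        for x in range(grid_width):
            sx, sy = x * stride, y * stride
            for x1, y1, x2, y2 in base_anchors:
                out.append([x1 + sx, y1 + sy, x2 + sx, y2 + sy])
    return out

base = generate_base_anchors([32], [0.5, 1.0, 2.0])
print(base)        # three anchors: wide, square, tall
grid = place_on_grid(base, 2, 2, 8)
print(len(grid))   # 2 * 2 grid cells * 3 base anchors = 12 anchors
```

This makes the counting in num_anchors_per_location concrete: each feature-map location gets len(scales) * len(aspect_ratios) anchors, and grid_anchors simply repeats that set at every stride-spaced cell center.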