Skip to content

Commit 2ecf4a4

Browse files
committed
[Fix] Fix cropping polygon mask.
1 parent 61dd8d5 commit 2ecf4a4

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

mmdet/structures/mask/structures.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import mmcv
88
import numpy as np
99
import pycocotools.mask as maskUtils
10+
import shapely.geometry as geometry
1011
import torch
1112
from mmcv.ops.roi_align import roi_align
1213

@@ -753,15 +754,40 @@ def crop(self, bbox):
753754
if len(self.masks) == 0:
754755
cropped_masks = PolygonMasks([], h, w)
755756
else:
757+
# reference: https://github.com/facebookresearch/fvcore/blob/main/fvcore/transforms/transform.py # noqa
758+
crop_box = geometry.box(x1, y1, x2, y2).buffer(0.0)
756759
cropped_masks = []
757760
for poly_per_obj in self.masks:
758761
cropped_poly_per_obj = []
759762
for p in poly_per_obj:
760-
# pycocotools will clip the boundary
761763
p = p.copy()
762-
p[0::2] = p[0::2] - bbox[0]
763-
p[1::2] = p[1::2] - bbox[1]
764-
cropped_poly_per_obj.append(p)
764+
p = geometry.Polygon(p.reshape(-1, 2)).buffer(0.0)
765+
# polygon must be valid to perform intersection.
766+
if not p.is_valid:
767+
continue
768+
cropped = p.intersection(crop_box)
769+
if cropped.is_empty:
770+
continue
771+
if isinstance(cropped,
772+
geometry.collection.BaseMultipartGeometry):
773+
cropped = cropped.geoms
774+
else:
775+
cropped = [cropped]
776+
# one polygon may be cropped to multiple ones
777+
for poly in cropped:
778+
# ignore lines or points
779+
if not isinstance(
780+
poly, geometry.Polygon) or not poly.is_valid:
781+
continue
782+
coords = np.asarray(poly.exterior.coords)
783+
# remove an extra identical vertex at the end
784+
coords = coords[:-1]
785+
coords[:, 0] -= x1
786+
coords[:, 1] -= y1
787+
cropped_poly_per_obj.append(coords.reshape(-1))
788+
789+
if len(cropped_poly_per_obj) == 0:
790+
cropped_poly_per_obj = [np.array([0, 0, 0, 0, 0, 0])]
765791
cropped_masks.append(cropped_poly_per_obj)
766792
cropped_masks = PolygonMasks(cropped_masks, h, w)
767793
return cropped_masks

requirements/runtime.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ matplotlib
22
numpy
33
pycocotools
44
scipy
5+
shapely
56
six
67
terminaltables

0 commit comments

Comments
 (0)