From 7aeedd17a4140eef139987e946a7017df7a97433 Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Thu, 20 Apr 2023 07:28:45 -0700 Subject: [PATCH] =?UTF-8?q?When=20bounding=20boxes=20are=20cached=20in=20m?= =?UTF-8?q?etadata,=20don=E2=80=99t=20crash=20on=20load=5Fmasks=3DFalse?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: We currently support caching bounding boxes in MaskAnnotation. If present, they are not re-computed from the mask. However, the masks need to be loaded for the bbox to be set. This diff fixes that. Even if load_masks / load_blobs are unset, the bounding box can be picked up from the metadata. Reviewed By: bottler Differential Revision: D45144918 fbshipit-source-id: 8a2e2c115e96070b6fcdc29cbe57e1cee606ddcd --- pytorch3d/implicitron/dataset/frame_data.py | 35 +++++++++----------- tests/implicitron/test_frame_data_builder.py | 14 +++++--- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/pytorch3d/implicitron/dataset/frame_data.py b/pytorch3d/implicitron/dataset/frame_data.py index 1a4e1b5c6..e8e88b707 100644 --- a/pytorch3d/implicitron/dataset/frame_data.py +++ b/pytorch3d/implicitron/dataset/frame_data.py @@ -555,12 +555,19 @@ def build( else None, ) - if load_blobs and self.load_masks and frame_annotation.mask is not None: - ( - frame_data.fg_probability, - frame_data.mask_path, - frame_data.bbox_xywh, - ) = self._load_fg_probability(frame_annotation) + mask_annotation = frame_annotation.mask + if mask_annotation is not None: + fg_mask_np: Optional[np.ndarray] = None + if load_blobs and self.load_masks: + fg_mask_np, mask_path = self._load_fg_probability(frame_annotation) + frame_data.mask_path = mask_path + frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float) + + bbox_xywh = mask_annotation.bounding_box_xywh + if bbox_xywh is None and fg_mask_np is not None: + bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr) + + frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long) if frame_annotation.image is not None: image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long) @@ -604,25 +611,15 @@ def build( def _load_fg_probability( self, entry: types.FrameAnnotation - ) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]: - + ) -> Tuple[np.ndarray, str]: full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore fg_probability = load_mask(self._local_path(full_path)) - # we can use provided bbox_xywh or calculate it based on mask - # saves time to skip bbox calculation - # pyre-ignore - bbox_xywh = entry.mask.bounding_box_xywh or get_bbox_from_mask( - fg_probability, self.box_crop_mask_thr - ) if fg_probability.shape[-2:] != entry.image.size: raise ValueError( f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!" ) - return ( - safe_as_tensor(fg_probability, torch.float), - full_path, - safe_as_tensor(bbox_xywh, torch.long), - ) + + return fg_probability, full_path def _load_images( self, diff --git a/tests/implicitron/test_frame_data_builder.py b/tests/implicitron/test_frame_data_builder.py index f150081b2..e66d67dfc 100644 --- a/tests/implicitron/test_frame_data_builder.py +++ b/tests/implicitron/test_frame_data_builder.py @@ -17,6 +17,7 @@ from pytorch3d.implicitron.dataset.dataset_base import FrameData from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder from pytorch3d.implicitron.dataset.utils import ( + get_bbox_from_mask, load_16big_png_depth, load_1bit_png_mask, load_depth, @@ -107,11 +108,14 @@ def test_load_and_adjust_frame_data(self): ) self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw - ( - self.frame_data.fg_probability, - self.frame_data.mask_path, - self.frame_data.bbox_xywh, - ) = self.frame_data_builder._load_fg_probability(self.frame_annotation) + fg_mask_np, mask_path = self.frame_data_builder._load_fg_probability( + self.frame_annotation + ) + self.frame_data.mask_path = mask_path + self.frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float) + mask_thr = self.frame_data_builder.box_crop_mask_thr + bbox_xywh = get_bbox_from_mask(fg_mask_np, mask_thr) + self.frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long) self.assertIsNotNone(self.frame_data.mask_path) self.assertTrue(torch.is_tensor(self.frame_data.fg_probability))