Skip to content

Commit

Permalink
When bounding boxes are cached in metadata, don’t crash on load_masks…
Browse files Browse the repository at this point in the history
…=False

Summary:
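We currently support caching bounding boxes in MaskAnnotation. If a cached box is present, it is not re-computed from the mask. However, the mask itself still has to be loaded for the bbox to be set on the frame data, so when load_masks (or load_blobs) is False the cached box is never used and bbox_xywh is left unset.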

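This diff fixes that: even if load_masks or load_blobs is False, the bounding box is picked up from the metadata whenever it is cached there. It is computed from the mask only when no cached box exists and the mask has been loaded.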

Reviewed By: bottler

Differential Revision: D45144918

fbshipit-source-id: 8a2e2c115e96070b6fcdc29cbe57e1cee606ddcd
  • Loading branch information
shapovalov authored and facebook-github-bot committed Apr 20, 2023
1 parent 0e3138e commit 7aeedd1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
35 changes: 16 additions & 19 deletions pytorch3d/implicitron/dataset/frame_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,12 +555,19 @@ def build(
else None,
)

if load_blobs and self.load_masks and frame_annotation.mask is not None:
(
frame_data.fg_probability,
frame_data.mask_path,
frame_data.bbox_xywh,
) = self._load_fg_probability(frame_annotation)
mask_annotation = frame_annotation.mask
if mask_annotation is not None:
fg_mask_np: Optional[np.ndarray] = None
if load_blobs and self.load_masks:
fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
frame_data.mask_path = mask_path
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)

bbox_xywh = mask_annotation.bounding_box_xywh
if bbox_xywh is None and fg_mask_np is not None:
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)

frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)

if frame_annotation.image is not None:
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
Expand Down Expand Up @@ -604,25 +611,15 @@ def build(

def _load_fg_probability(
self, entry: types.FrameAnnotation
) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]:

) -> Tuple[np.ndarray, str]:
full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore
fg_probability = load_mask(self._local_path(full_path))
# we can use provided bbox_xywh or calculate it based on mask
# saves time to skip bbox calculation
# pyre-ignore
bbox_xywh = entry.mask.bounding_box_xywh or get_bbox_from_mask(
fg_probability, self.box_crop_mask_thr
)
if fg_probability.shape[-2:] != entry.image.size:
raise ValueError(
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
)
return (
safe_as_tensor(fg_probability, torch.float),
full_path,
safe_as_tensor(bbox_xywh, torch.long),
)

return fg_probability, full_path

def _load_images(
self,
Expand Down
14 changes: 9 additions & 5 deletions tests/implicitron/test_frame_data_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
from pytorch3d.implicitron.dataset.utils import (
get_bbox_from_mask,
load_16big_png_depth,
load_1bit_png_mask,
load_depth,
Expand Down Expand Up @@ -107,11 +108,14 @@ def test_load_and_adjust_frame_data(self):
)
self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw

(
self.frame_data.fg_probability,
self.frame_data.mask_path,
self.frame_data.bbox_xywh,
) = self.frame_data_builder._load_fg_probability(self.frame_annotation)
fg_mask_np, mask_path = self.frame_data_builder._load_fg_probability(
self.frame_annotation
)
self.frame_data.mask_path = mask_path
self.frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
mask_thr = self.frame_data_builder.box_crop_mask_thr
bbox_xywh = get_bbox_from_mask(fg_mask_np, mask_thr)
self.frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)

self.assertIsNotNone(self.frame_data.mask_path)
self.assertTrue(torch.is_tensor(self.frame_data.fg_probability))
Expand Down

0 comments on commit 7aeedd1

Please sign in to comment.