Skip to content

Commit 2ecd048

Browse files
committed
explicit check for box extending over image shape
Signed-off-by: Tomasz Bartczak <kretesenator@gmail.com>
1 parent 2cbed6c commit 2ecd048

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

monai/apps/detection/transforms/box_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,9 @@ def convert_box_to_mask(
242242
boxes_mask_np = np.ones((labels.shape[0],) + spatial_size, dtype=np.int16) * np.int16(bg_label)
243243

244244
boxes_np: np.ndarray = convert_data_type(boxes, np.ndarray, dtype=np.int32)[0]
245+
if np.any(boxes_np[:, spatial_dims:] > np.array(spatial_size)):
246+
raise ValueError("Some boxes are larger than the image.")
247+
245248
labels_np, *_ = convert_to_dst_type(src=labels, dst=boxes_np)
246249
for b in range(boxes_np.shape[0]):
247250
# generate a foreground mask

tests/test_box_transform.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,24 @@ def test_value_3d_mask(self):
131131
assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
132132
assert_allclose(data_back["labels"], data["labels"], type_test=False, device_test=False, atol=1e-3)
133133

134+
def test_shape_assertion(self):
135+
test_dtype = torch.float32
136+
image = np.zeros((1, 10, 10, 10))
137+
boxes = np.array([[7, 8, 9, 10, 12, 13]])
138+
data = {"image": image, "boxes": boxes, "labels": np.array((1,))}
139+
data = CastToTyped(keys=["image", "boxes"], dtype=test_dtype)(data)
140+
transform_to_mask = BoxToMaskd(
141+
box_keys="boxes",
142+
box_mask_keys="box_mask",
143+
box_ref_image_keys="image",
144+
label_keys="labels",
145+
min_fg_label=0,
146+
ellipse_mask=False,
147+
)
148+
with self.assertRaises(ValueError) as context:
149+
transform_to_mask(data)
150+
self.assertTrue("Some boxes are larger than the image." in str(context.exception))
151+
134152
@parameterized.expand(TESTS_3D)
135153
def test_value_3d(
136154
self,

0 commit comments

Comments
 (0)