@@ -131,6 +131,24 @@ def test_value_3d_mask(self):
131131 assert_allclose (data_back ["boxes" ], data ["boxes" ], type_test = False , device_test = False , atol = 1e-3 )
132132 assert_allclose (data_back ["labels" ], data ["labels" ], type_test = False , device_test = False , atol = 1e-3 )
133133
134+ def test_shape_assertion (self ):
135+ test_dtype = torch .float32
136+ image = np .zeros ((1 , 10 , 10 , 10 ))
137+ boxes = np .array ([[7 , 8 , 9 , 10 , 12 , 13 ]])
138+ data = {"image" : image , "boxes" : boxes , "labels" : np .array ((1 ,))}
139+ data = CastToTyped (keys = ["image" , "boxes" ], dtype = test_dtype )(data )
140+ transform_to_mask = BoxToMaskd (
141+ box_keys = "boxes" ,
142+ box_mask_keys = "box_mask" ,
143+ box_ref_image_keys = "image" ,
144+ label_keys = "labels" ,
145+ min_fg_label = 0 ,
146+ ellipse_mask = False ,
147+ )
148+ with self .assertRaises (ValueError ) as context :
149+ transform_to_mask (data )
150+ self .assertTrue ("Some boxes are larger than the image." in str (context .exception ))
151+
134152 @parameterized .expand (TESTS_3D )
135153 def test_value_3d (
136154 self ,
0 commit comments