@@ -6169,3 +6169,50 @@ def test_transform_sequence_len_error(self, quality):
6169
6169
def test_transform_invalid_quality_error (self , quality ):
6170
6170
with pytest .raises (ValueError , match = "quality must be an integer from 1 to 100" ):
6171
6171
transforms .JPEG (quality = quality )
6172
+
6173
+
6174
+ class TestUtils :
6175
+ # TODO: Still need to test has_all, has_any, check_type and get_bouding_boxes
6176
+ @pytest .mark .parametrize (
6177
+ "make_input1" , [make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask ]
6178
+ )
6179
+ @pytest .mark .parametrize (
6180
+ "make_input2" , [make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask ]
6181
+ )
6182
+ @pytest .mark .parametrize ("query" , [transforms .query_size , transforms .query_chw ])
6183
+ def test_query_size_and_query_chw (self , make_input1 , make_input2 , query ):
6184
+ size = (32 , 64 )
6185
+ input1 = make_input1 (size )
6186
+ input2 = make_input2 (size )
6187
+
6188
+ if query is transforms .query_chw and not any (
6189
+ transforms .check_type (inpt , (is_pure_tensor , tv_tensors .Image , PIL .Image .Image , tv_tensors .Video ))
6190
+ for inpt in (input1 , input2 )
6191
+ ):
6192
+ return
6193
+
6194
+ expected = size if query is transforms .query_size else ((3 ,) + size )
6195
+ assert query ([input1 , input2 ]) == expected
6196
+
6197
+ @pytest .mark .parametrize (
6198
+ "make_input1" , [make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask ]
6199
+ )
6200
+ @pytest .mark .parametrize (
6201
+ "make_input2" , [make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask ]
6202
+ )
6203
+ @pytest .mark .parametrize ("query" , [transforms .query_size , transforms .query_chw ])
6204
+ def test_different_sizes (self , make_input1 , make_input2 , query ):
6205
+ input1 = make_input1 ((10 , 10 ))
6206
+ input2 = make_input2 ((20 , 20 ))
6207
+ if query is transforms .query_chw and not all (
6208
+ transforms .check_type (inpt , (is_pure_tensor , tv_tensors .Image , PIL .Image .Image , tv_tensors .Video ))
6209
+ for inpt in (input1 , input2 )
6210
+ ):
6211
+ return
6212
+ with pytest .raises (ValueError , match = "Found multiple" ):
6213
+ query ([input1 , input2 ])
6214
+
6215
+ @pytest .mark .parametrize ("query" , [transforms .query_size , transforms .query_chw ])
6216
+ def test_no_valid_input (self , query ):
6217
+ with pytest .raises (TypeError , match = "No image" ):
6218
+ query (["blah" ])
0 commit comments