Skip to content

Commit 8e8a208

Browse files
[Cherry-pick for 0.20] Expose transforms.v2 utils for writing custom transforms (#8673)
Co-authored-by: venkatram-dev <45727389+venkatram-dev@users.noreply.github.com>
1 parent 2d8a288 commit 8e8a208

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

test/test_transforms_v2.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6169,3 +6169,50 @@ def test_transform_sequence_len_error(self, quality):
61696169
def test_transform_invalid_quality_error(self, quality):
61706170
with pytest.raises(ValueError, match="quality must be an integer from 1 to 100"):
61716171
transforms.JPEG(quality=quality)
6172+
6173+
6174+
class TestUtils:
6175+
# TODO: Still need to test has_all, has_any, check_type and get_bouding_boxes
6176+
@pytest.mark.parametrize(
6177+
"make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
6178+
)
6179+
@pytest.mark.parametrize(
6180+
"make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
6181+
)
6182+
@pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
6183+
def test_query_size_and_query_chw(self, make_input1, make_input2, query):
6184+
size = (32, 64)
6185+
input1 = make_input1(size)
6186+
input2 = make_input2(size)
6187+
6188+
if query is transforms.query_chw and not any(
6189+
transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
6190+
for inpt in (input1, input2)
6191+
):
6192+
return
6193+
6194+
expected = size if query is transforms.query_size else ((3,) + size)
6195+
assert query([input1, input2]) == expected
6196+
6197+
@pytest.mark.parametrize(
6198+
"make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
6199+
)
6200+
@pytest.mark.parametrize(
6201+
"make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
6202+
)
6203+
@pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
6204+
def test_different_sizes(self, make_input1, make_input2, query):
6205+
input1 = make_input1((10, 10))
6206+
input2 = make_input2((20, 20))
6207+
if query is transforms.query_chw and not all(
6208+
transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
6209+
for inpt in (input1, input2)
6210+
):
6211+
return
6212+
with pytest.raises(ValueError, match="Found multiple"):
6213+
query([input1, input2])
6214+
6215+
@pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
6216+
def test_no_valid_input(self, query):
6217+
with pytest.raises(TypeError, match="No image"):
6218+
query(["blah"])

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,6 @@
5555
)
5656
from ._temporal import UniformTemporalSubsample
5757
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
58+
from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size
5859

5960
from ._deprecated import ToTensor # usort: skip

0 commit comments

Comments
 (0)