69 changes: 19 additions & 50 deletions test/prototype_common_utils.py
@@ -12,17 +12,9 @@
import torch.testing
from datasets_utils import combinations_grid
from torch.nn.functional import one_hot
-from torch.testing._comparison import (
-    assert_equal as _assert_equal,
-    BooleanPair,
-    ErrorMeta,
-    NonePair,
-    NumberPair,
-    TensorLikePair,
-    UnsupportedInputs,
-)
+from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair
from torchvision.prototype import features
-from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
+from torchvision.prototype.transforms.functional import to_image_tensor
from torchvision.transforms.functional_tensor import _max_value as get_max_value

__all__ = [
@@ -54,7 +46,7 @@
]


-class PILImagePair(TensorLikePair):
+class ImagePair(TensorLikePair):
Comment (Collaborator Author): I've refactored this to not handle mixed PIL / tensor image pairs. It was a problem if the tolerance is set for floating point images, i.e. in the range [0.0, 1.0], but the comparison converted to uint8, which needs higher tolerances.
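For illustration (not part of the diff), a minimal sketch of the failure mode, assuming a float image in [0.0, 1.0]: the quantization error of a uint8 round trip alone is about half a quantization step, far above a float-level tolerance such as atol=1e-5.

import torch

image = torch.rand(3, 4, 4)               # float image in [0, 1]
round_trip = (image * 255).round() / 255  # simulate a detour through uint8

# Max error is ~0.5 / 255, i.e. about 2e-3, so a tolerance picked for float
# images can never hold once the comparison quantizes to uint8 and back.
print((image - round_trip).abs().max())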

def __init__(
self,
actual,
@@ -64,44 +56,13 @@ def __init__(
allowed_percentage_diff=None,
**other_parameters,
):
-        if not any(isinstance(input, PIL.Image.Image) for input in (actual, expected)):
-            raise UnsupportedInputs()
-
-        # This parameter is ignored to enable checking PIL images to tensor images no on the CPU
-        other_parameters["check_device"] = False
+        if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
+            actual, expected = [to_image_tensor(input) for input in [actual, expected]]

super().__init__(actual, expected, **other_parameters)
self.agg_method = getattr(torch, agg_method) if isinstance(agg_method, str) else agg_method
self.allowed_percentage_diff = allowed_percentage_diff

-    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
-        actual, expected = [
-            to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
-            for input in [actual, expected]
-        ]
-        # This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
-        # image to a tensor adds a singleton leading dimension.
-        # Although it looks like this belongs in `self._equalize_attributes`, it has to happen here.
-        # `self._equalize_attributes` is called after `super()._compare_attributes` and that has an unconditional
-        # shape check that will fail if we don't broadcast before.
-        try:
-            actual, expected = torch.broadcast_tensors(actual, expected)
-        except RuntimeError:
-            raise ErrorMeta(
-                AssertionError,
-                f"The image shapes are not broadcastable: {actual.shape} != {expected.shape}.",
-                id=id,
-            ) from None
-        return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)
-
-    def _equalize_attributes(self, actual, expected):
-        if actual.dtype != expected.dtype:
-            dtype = torch.promote_types(actual.dtype, expected.dtype)
-            actual = convert_dtype_image_tensor(actual, dtype)
-            expected = convert_dtype_image_tensor(expected, dtype)
-
-        return super()._equalize_attributes(actual, expected)

def compare(self) -> None:
actual, expected = self.actual, self.expected

@@ -111,16 +72,24 @@ def compare(self) -> None:
abs_diff = torch.abs(actual - expected)

if self.allowed_percentage_diff is not None:
-            percentage_diff = (abs_diff != 0).to(torch.float).mean()
+            percentage_diff = float((abs_diff.ne(0).to(torch.float64).mean()))
if percentage_diff > self.allowed_percentage_diff:
self._make_error_meta(AssertionError, "percentage mismatch")
raise self._make_error_meta(
AssertionError,
f"{percentage_diff:.1%} elements differ, "
f"but only {self.allowed_percentage_diff:.1%} is allowed",
)

if self.agg_method is None:
super()._compare_values(actual, expected)
else:
-            err = self.agg_method(abs_diff.to(torch.float64))
-            if err > self.atol:
-                self._make_error_meta(AssertionError, "aggregated mismatch")
+            agg_abs_diff = float(self.agg_method(abs_diff.to(torch.float64)))
+            if agg_abs_diff > self.atol:
+                raise self._make_error_meta(
Comment (Collaborator Author): This was the source of the actual bug. Without the raise we just created an exception, but never did anything with it. Thus, all tests that set agg_method in their closeness_kwargs passed without a value check.

+                    AssertionError,
+                    f"The '{self.agg_method.__name__}' of the absolute difference is {agg_abs_diff}, "
+                    f"but only {self.atol} is allowed.",
+                )
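For illustration (not part of the diff), the bug pattern in its simplest form, using hypothetical stand-in functions:

# Buggy: the exception is constructed and immediately discarded, so the
# check never fails, no matter how large the error is.
def check_buggy(err: float, atol: float) -> None:
    if err > atol:
        AssertionError(f"error {err} exceeds tolerance {atol}")  # no `raise`!

# Fixed: actually raise what was constructed.
def check_fixed(err: float, atol: float) -> None:
    if err > atol:
        raise AssertionError(f"error {err} exceeds tolerance {atol}")

check_buggy(1.0, 0.1)  # returns silently
check_fixed(1.0, 0.1)  # raises AssertionError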


def assert_close(
@@ -148,7 +117,7 @@ def assert_close(
NonePair,
BooleanPair,
NumberPair,
-            PILImagePair,
+            ImagePair,
TensorLikePair,
),
allow_subclasses=allow_subclasses,
106 changes: 73 additions & 33 deletions test/prototype_transforms_kernel_infos.py
@@ -4,6 +4,7 @@
import math

import numpy as np
+import PIL.Image
import pytest
import torch.testing
import torchvision.ops
@@ -62,8 +63,8 @@ def __init__(


DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS = {
(("TestKernels", "test_against_reference"), torch.float32, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
(("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
(("TestKernels", "test_against_reference"), torch.float32, "cpu"): dict(atol=0.9, rtol=0, agg_method="mean"),
(("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=255 * 0.9, rtol=0, agg_method="mean"),
Comment (Collaborator Author): These are crazy tolerances and render the test effectively useless. This PR just lays the foundation to fix these in the near future. I'll open an issue ASAP with a game plan to review all the operators again that need these tolerances for some reason.

Comment (Contributor): These are crazy tolerances indeed. From your earlier comment, I understand that the tests were effectively not throwing exceptions, so here you adjust the values to make them pass and then revisit all kernels that fail. Is my understanding correct?

Comment (pmeier, Collaborator Author, Nov 9, 2022): Yes. See #6937.
}
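For illustration (not part of the diff), a back-of-the-envelope check with made-up tensors that shows how permissive atol=255 * 0.9 with agg_method="mean" is:

import torch

actual = torch.zeros(3, 32, 32, dtype=torch.uint8)
expected = torch.full((3, 32, 32), 229, dtype=torch.uint8)  # off by ~90% of the value range

agg_abs_diff = float(torch.abs(actual.float() - expected.float()).to(torch.float64).mean())
print(agg_abs_diff)              # 229.0
print(agg_abs_diff > 255 * 0.9)  # False -> such images would still "compare equal"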

CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE = {
@@ -74,14 +75,26 @@ def pil_reference_wrapper(pil_kernel):

def pil_reference_wrapper(pil_kernel):
@functools.wraps(pil_kernel)
-    def wrapper(image_tensor, *other_args, **kwargs):
-        if image_tensor.ndim > 3:
+    def wrapper(input_tensor, *other_args, **kwargs):
+        if input_tensor.ndim > 3:
            raise pytest.UsageError(
-                f"Can only test single tensor images against PIL, but input has shape {image_tensor.shape}"
+                f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}"
            )

-        # We don't need to convert back to tensor here, since `assert_close` does that automatically.
-        return pil_kernel(F.to_image_pil(image_tensor), *other_args, **kwargs)
+        input_pil = F.to_image_pil(input_tensor)
+        output_pil = pil_kernel(input_pil, *other_args, **kwargs)
+        if not isinstance(output_pil, PIL.Image.Image):
+            return output_pil
+
+        output_tensor = F.convert_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype)
+
+        # 2D mask shenanigans
+        if output_tensor.ndim == 2 and input_tensor.ndim == 3:
+            output_tensor = output_tensor.unsqueeze(0)
+        elif output_tensor.ndim == 3 and input_tensor.ndim == 2:
+            output_tensor = output_tensor.squeeze(0)
Comment on lines +121 to +125 (Collaborator Author): For all other shape mismatches we can let the comparison logic handle the error.


+        return output_tensor

return wrapper
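For illustration (not part of the diff), the "2D mask shenanigans" branch compensates for the following asymmetry; this sketch assumes the prototype functional module imported as F, as in this file:

import torch
import torchvision.prototype.transforms.functional as F

mask = torch.randint(0, 2, (16, 16), dtype=torch.uint8)  # 2D mask, ndim == 2
round_trip = F.to_image_tensor(F.to_image_pil(mask))

# The PIL round trip adds a singleton channel dimension, so the output's rank
# no longer matches the input's without the squeeze/unsqueeze above.
print(round_trip.shape)  # torch.Size([1, 16, 16])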

@@ -400,6 +413,23 @@ def _full_affine_params(**partial_params):
]


+def _get_fills(*, num_channels, dtype, vector=True):
Comment (Collaborator Author): fill was creating issues because we set something like 128.0 for floating point images. This is not an issue in general, but affects stuff like aggregated diff. This just generalizes the fill generation for all tests below.

+    yield None
+
+    max_value = get_max_value(dtype)
+    # This intentionally gives us a float and an int scalar fill value
+    yield max_value / 2
+    yield max_value
+
+    if not vector:
+        return
+
+    if dtype.is_floating_point:
+        yield [0.1 + c / 10 for c in range(num_channels)]
+    else:
+        yield [12.0 + c for c in range(num_channels)]
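
For illustration (not part of the diff), what the generator above yields, assuming get_max_value returns 255 for uint8 and 1 for float dtypes:

import torch

print(list(_get_fills(num_channels=3, dtype=torch.uint8)))
# [None, 127.5, 255, [12.0, 13.0, 14.0]]

print(list(_get_fills(num_channels=3, dtype=torch.float32)))
# [None, 0.5, 1, [0.1, 0.2, 0.3]] (vector entries up to float rounding)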


def sample_inputs_affine_image_tensor():
make_affine_image_loaders = functools.partial(
make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
@@ -409,10 +439,7 @@ def sample_inputs_affine_image_tensor():
yield ArgsKwargs(image_loader, **affine_params)

for image_loader in make_affine_image_loaders():
-        fills = [None, 0.5]
-        if image_loader.num_channels > 1:
-            fills.extend(vector_fill * image_loader.num_channels for vector_fill in [(0.5,), (1,), [0.5], [1]])
-        for fill in fills:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, **_full_affine_params(), fill=fill)

for image_loader, interpolation in itertools.product(
@@ -631,7 +658,9 @@ def reference_inputs_convert_format_bounding_box():


def sample_inputs_convert_color_space_image_tensor():
-    color_spaces = list(set(features.ColorSpace) - {features.ColorSpace.OTHER})
+    color_spaces = sorted(
+        set(features.ColorSpace) - {features.ColorSpace.OTHER}, key=lambda color_space: color_space.value
+    )
Comment on lines +713 to +715 (Collaborator Author): Drive-by that I found while debugging. Without this, individual failing tests are not reproducible because the order is not stable.
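For illustration (not part of the diff), why iterating over a set of enum members is not reproducible, shown with a standalone enum:

from enum import Enum

class Color(Enum):
    RED = 1
    GREEN = 2
    BLUE = 3

# Set order depends on the members' id-based hashes and can differ between
# interpreter runs, so test IDs derived from it are unstable:
print(list({Color.RED, Color.GREEN, Color.BLUE}))

# Sorting by value pins the order and makes individual failures reproducible:
print(sorted({Color.RED, Color.GREEN, Color.BLUE}, key=lambda c: c.value))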


for old_color_space, new_color_space in cycle_over(color_spaces):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=[old_color_space], constant_alpha=True):
@@ -678,7 +707,10 @@ def sample_inputs_convert_color_space_video():
sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
reference_fn=reference_convert_color_space_image_tensor,
reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        closeness_kwargs={
+            (("TestKernels", "test_against_reference"), torch.float32, "cpu"): dict(atol=2 / 255, rtol=0),
+            (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=1, rtol=0),
+        },
),
KernelInfo(
F.convert_color_space_video,
@@ -775,10 +807,7 @@ def sample_inputs_rotate_image_tensor():
yield ArgsKwargs(image_loader, angle=15.0, center=center)

for image_loader in make_rotate_image_loaders():
-        fills = [None, 0.5]
-        if image_loader.num_channels > 1:
-            fills.extend(vector_fill * image_loader.num_channels for vector_fill in [(0.5,), (1,), [0.5], [1]])
-        for fill in fills:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, angle=15.0, fill=fill)

for image_loader, interpolation in itertools.product(
@@ -1062,10 +1091,7 @@ def sample_inputs_pad_image_tensor():
yield ArgsKwargs(image_loader, padding=padding)

for image_loader in make_pad_image_loaders():
-        fills = [None, 0.5]
-        if image_loader.num_channels > 1:
-            fills.extend(vector_fill * image_loader.num_channels for vector_fill in [(0.5,), (1,), [0.5], [1]])
-        for fill in fills:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, padding=[1], fill=fill)

for image_loader, padding_mode in itertools.product(
@@ -1084,10 +1110,11 @@ def sample_inputs_pad_image_tensor():
def reference_inputs_pad_image_tensor():
for image_loader, params in itertools.product(make_image_loaders(extra_dims=[()]), _PAD_PARAMS):
# FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
-        fills = [None, 128.0, 128]
-        if params["padding_mode"] == "constant":
-            fills.append([12.0 + c for c in range(image_loader.num_channels)])
-        for fill in fills:
+        for fill in _get_fills(
+            num_channels=image_loader.num_channels,
+            dtype=image_loader.dtype,
+            vector=params["padding_mode"] == "constant",
+        ):
yield ArgsKwargs(image_loader, fill=fill, **params)


@@ -1110,8 +1137,10 @@ def sample_inputs_pad_mask():


def reference_inputs_pad_mask():
-    for image_loader, fill, params in itertools.product(make_image_loaders(extra_dims=[()]), [None, 127], _PAD_PARAMS):
-        yield ArgsKwargs(image_loader, fill=fill, **params)
+    for mask_loader, fill, params in itertools.product(
+        make_mask_loaders(num_objects=[1], extra_dims=[()]), [None, 127], _PAD_PARAMS
+    ):
+        yield ArgsKwargs(mask_loader, fill=fill, **params)


def sample_inputs_pad_video():
@@ -1197,14 +1226,14 @@ def reference_inputs_pad_bounding_box():

def sample_inputs_perspective_image_tensor():
for image_loader in make_image_loaders(sizes=["random"]):
-        for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, None, None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0])


def reference_inputs_perspective_image_tensor():
for image_loader, coefficients in itertools.product(make_image_loaders(extra_dims=[()]), _PERSPECTIVE_COEFFS):
# FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
-        for fill in [None, 128.0, 128, [12.0 + c for c in range(image_loader.num_channels)]]:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, None, None, fill=fill, coefficients=coefficients)


@@ -1271,7 +1300,7 @@ def _get_elastic_displacement(spatial_size):
def sample_inputs_elastic_image_tensor():
for image_loader in make_image_loaders(sizes=["random"]):
displacement = _get_elastic_displacement(image_loader.spatial_size)
-        for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)


Expand All @@ -1285,7 +1314,7 @@ def reference_inputs_elastic_image_tensor():
],
):
displacement = _get_elastic_displacement(image_loader.spatial_size)
-        for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)


@@ -2070,6 +2099,17 @@ def sample_inputs_ten_crop_video():
yield ArgsKwargs(video_loader, size=size)


+def multi_crop_pil_reference_wrapper(pil_kernel):
Comment (Collaborator Author): Small helper since the regular pil_reference_wrapper cannot handle conversion of these tuple or list outputs.

+    def wrapper(input_tensor, *other_args, **kwargs):
+        output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
+        return type(output)(
+            F.convert_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype)
+            for output_pil in output
+        )
+
+    return wrapper


_common_five_ten_crop_marks = [
xfail_jit_python_scalar_arg("size"),
mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
@@ -2080,7 +2120,7 @@
KernelInfo(
F.five_crop_image_tensor,
sample_inputs_fn=sample_inputs_five_crop_image_tensor,
-        reference_fn=pil_reference_wrapper(F.five_crop_image_pil),
+        reference_fn=multi_crop_pil_reference_wrapper(F.five_crop_image_pil),
reference_inputs_fn=reference_inputs_five_crop_image_tensor,
test_marks=_common_five_ten_crop_marks,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
@@ -2093,7 +2133,7 @@
KernelInfo(
F.ten_crop_image_tensor,
sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
-        reference_fn=pil_reference_wrapper(F.ten_crop_image_pil),
+        reference_fn=multi_crop_pil_reference_wrapper(F.ten_crop_image_pil),
reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
test_marks=_common_five_ten_crop_marks,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
12 changes: 8 additions & 4 deletions test/test_prototype_transforms_consistency.py
@@ -1005,7 +1005,7 @@ def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):

dp = (conv_fn(feature_image), feature_mask)
dp_ref = (
-            to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image),
+            to_image_pil(feature_image) if supports_pil else feature_image.as_subclass(torch.Tensor),
to_image_pil(feature_mask),
)

@@ -1019,12 +1019,16 @@ def check(self, t, t_ref, data_kwargs=None):
for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):

self.set_seed()
-            output = t(dp)
+            actual = actual_image, actual_mask = t(dp)

self.set_seed()
-            expected_output = t_ref(*dp_ref)
+            expected_image, expected_mask = t_ref(*dp_ref)
+            if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
+                expected_image = legacy_F.pil_to_tensor(expected_image)
+                expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
+            expected = (expected_image, expected_mask)
Comment on lines +1030 to +1034 (Collaborator Author): This was relying on the mixed PIL / tensor image comparison that was removed above. Thus, we do it manually here.


-            assert_equal(output, expected_output)
+            assert_equal(actual, expected)

@pytest.mark.parametrize(
("t_ref", "t", "data_kwargs"),
1 change: 0 additions & 1 deletion test/test_prototype_transforms_functional.py
@@ -237,7 +237,6 @@ def test_against_reference(self, test_id, info, args_kwargs):
assert_close(
actual,
expected,
-            check_dtype=False,
Comment (Collaborator Author): We can be strict here now, since we perform the conversion correctly.

**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
)
