# fix prototype transforms tests with set agg_method #6934
```diff
@@ -12,17 +12,9 @@
 import torch.testing
 from datasets_utils import combinations_grid
 from torch.nn.functional import one_hot
-from torch.testing._comparison import (
-    assert_equal as _assert_equal,
-    BooleanPair,
-    ErrorMeta,
-    NonePair,
-    NumberPair,
-    TensorLikePair,
-    UnsupportedInputs,
-)
+from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair
 from torchvision.prototype import features
-from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
+from torchvision.prototype.transforms.functional import to_image_tensor
 from torchvision.transforms.functional_tensor import _max_value as get_max_value

 __all__ = [
@@ -54,7 +46,7 @@
 ]


-class PILImagePair(TensorLikePair):
+class ImagePair(TensorLikePair):
     def __init__(
         self,
         actual,
@@ -64,44 +56,13 @@ def __init__(
        allowed_percentage_diff=None,
        **other_parameters,
    ):
-        if not any(isinstance(input, PIL.Image.Image) for input in (actual, expected)):
-            raise UnsupportedInputs()
-
-        # This parameter is ignored to enable checking PIL images to tensor images no on the CPU
-        other_parameters["check_device"] = False
+        if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
+            actual, expected = [to_image_tensor(input) for input in [actual, expected]]

         super().__init__(actual, expected, **other_parameters)
         self.agg_method = getattr(torch, agg_method) if isinstance(agg_method, str) else agg_method
         self.allowed_percentage_diff = allowed_percentage_diff

-    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
-        actual, expected = [
-            to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
-            for input in [actual, expected]
-        ]
-        # This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
-        # image to a tensor adds a singleton leading dimension.
-        # Although it looks like this belongs in `self._equalize_attributes`, it has to happen here.
-        # `self._equalize_attributes` is called after `super()._compare_attributes` and that has an unconditional
-        # shape check that will fail if we don't broadcast before.
-        try:
-            actual, expected = torch.broadcast_tensors(actual, expected)
-        except RuntimeError:
-            raise ErrorMeta(
-                AssertionError,
-                f"The image shapes are not broadcastable: {actual.shape} != {expected.shape}.",
-                id=id,
-            ) from None
-        return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)
-
-    def _equalize_attributes(self, actual, expected):
-        if actual.dtype != expected.dtype:
-            dtype = torch.promote_types(actual.dtype, expected.dtype)
-            actual = convert_dtype_image_tensor(actual, dtype)
-            expected = convert_dtype_image_tensor(expected, dtype)
-
-        return super()._equalize_attributes(actual, expected)
-
     def compare(self) -> None:
         actual, expected = self.actual, self.expected
@@ -111,16 +72,24 @@ def compare(self) -> None:
         abs_diff = torch.abs(actual - expected)

         if self.allowed_percentage_diff is not None:
-            percentage_diff = (abs_diff != 0).to(torch.float).mean()
+            percentage_diff = float((abs_diff.ne(0).to(torch.float64).mean()))
             if percentage_diff > self.allowed_percentage_diff:
-                self._make_error_meta(AssertionError, "percentage mismatch")
+                raise self._make_error_meta(
+                    AssertionError,
+                    f"{percentage_diff:.1%} elements differ, "
+                    f"but only {self.allowed_percentage_diff:.1%} is allowed",
+                )

         if self.agg_method is None:
             super()._compare_values(actual, expected)
         else:
-            err = self.agg_method(abs_diff.to(torch.float64))
-            if err > self.atol:
-                self._make_error_meta(AssertionError, "aggregated mismatch")
+            agg_abs_diff = float(self.agg_method(abs_diff.to(torch.float64)))
+            if agg_abs_diff > self.atol:
+                raise self._make_error_meta(
+                    AssertionError,
+                    f"The '{self.agg_method.__name__}' of the absolute difference is {agg_abs_diff}, "
+                    f"but only {self.atol} is allowed.",
+                )


 def assert_close(
```

> **Review comment** on the added `raise`: This was the source of the actual bug. Without the `raise` we just created an exception, but never did anything with it. Thus, all tests that set `agg_method` passed regardless of the actual mismatch.
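To make the failure mode concrete, here is a minimal sketch of the pattern (an illustrative stand-in, not the real `torch.testing._comparison` internals): `_make_error_meta` only builds the error object, and raising it is the caller's job.

```python
class PairSketch:
    """Illustrative stand-in for a torch.testing comparison pair."""

    def _make_error_meta(self, exc_type, msg):
        # Only constructs the error; raising it is the caller's responsibility.
        return exc_type(msg)

    def compare_buggy(self, err, atol):
        if err > atol:
            self._make_error_meta(AssertionError, "aggregated mismatch")  # built, then silently dropped

    def compare_fixed(self, err, atol):
        if err > atol:
            raise self._make_error_meta(AssertionError, "aggregated mismatch")


PairSketch().compare_buggy(err=1.0, atol=0.0)  # no error: the test "passes"
try:
    PairSketch().compare_fixed(err=1.0, atol=0.0)
except AssertionError as e:
    print(f"caught: {e}")  # caught: aggregated mismatch
```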
```diff
@@ -148,7 +117,7 @@ def assert_close(
             NonePair,
             BooleanPair,
             NumberPair,
-            PILImagePair,
+            ImagePair,
             TensorLikePair,
         ),
         allow_subclasses=allow_subclasses,
```
```diff
@@ -4,6 +4,7 @@
 import math

 import numpy as np
+import PIL.Image
 import pytest
 import torch.testing
 import torchvision.ops
@@ -62,8 +63,8 @@ def __init__(


 DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS = {
-    (("TestKernels", "test_against_reference"), torch.float32, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
-    (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
+    (("TestKernels", "test_against_reference"), torch.float32, "cpu"): dict(atol=0.9, rtol=0, agg_method="mean"),
+    (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=255 * 0.9, rtol=0, agg_method="mean"),
 }

 CUDA_VS_CPU_SINGLE_PIXEL_DIFFERENCE = {
```
```diff
@@ -74,14 +75,26 @@ def __init__(


 def pil_reference_wrapper(pil_kernel):
     @functools.wraps(pil_kernel)
-    def wrapper(image_tensor, *other_args, **kwargs):
-        if image_tensor.ndim > 3:
+    def wrapper(input_tensor, *other_args, **kwargs):
+        if input_tensor.ndim > 3:
             raise pytest.UsageError(
-                f"Can only test single tensor images against PIL, but input has shape {image_tensor.shape}"
+                f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}"
             )

-        # We don't need to convert back to tensor here, since `assert_close` does that automatically.
-        return pil_kernel(F.to_image_pil(image_tensor), *other_args, **kwargs)
+        input_pil = F.to_image_pil(input_tensor)
+        output_pil = pil_kernel(input_pil, *other_args, **kwargs)
+        if not isinstance(output_pil, PIL.Image.Image):
+            return output_pil
+
+        output_tensor = F.convert_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype)
+
+        # 2D mask shenanigans
+        if output_tensor.ndim == 2 and input_tensor.ndim == 3:
+            output_tensor = output_tensor.unsqueeze(0)
+        elif output_tensor.ndim == 3 and input_tensor.ndim == 2:
+            output_tensor = output_tensor.squeeze(0)
+
+        return output_tensor

     return wrapper
```

> **Review comment** on lines +121 to +125 (the 2D mask handling): For all other shape mismatches we can let the comparison logic handle the error.
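A small reproduction of why the "2D mask shenanigans" are needed, using the stable torchvision API (the 16×16 shape is arbitrary; the prototype `features.Mask` can be plain 2D):

```python
import PIL.Image
import torch
from torchvision.transforms.functional import pil_to_tensor

mask = torch.zeros(16, 16, dtype=torch.uint8)  # 2D mask, no channel dimension
mask_pil = PIL.Image.fromarray(mask.numpy())   # mode "L" PIL image

roundtripped = pil_to_tensor(mask_pil)
print(mask.shape)          # torch.Size([16, 16])
print(roundtripped.shape)  # torch.Size([1, 16, 16]) -- singleton channel dim added
```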
```diff
@@ -400,6 +413,23 @@ def _full_affine_params(**partial_params):
 ]


+def _get_fills(*, num_channels, dtype, vector=True):
+    yield None
+
+    max_value = get_max_value(dtype)
+    # This intentionally gives us a float and an int scalar fill value
+    yield max_value / 2
+    yield max_value
+
+    if not vector:
+        return
+
+    if dtype.is_floating_point:
+        yield [0.1 + c / 10 for c in range(num_channels)]
+    else:
+        yield [12.0 + c for c in range(num_channels)]
+
+
 def sample_inputs_affine_image_tensor():
     make_affine_image_loaders = functools.partial(
         make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
```
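Roughly what the new helper produces, as a self-contained sketch (the inline max-value lookup stands in for `_max_value` from `torchvision.transforms.functional_tensor`, which returns 255 for `torch.uint8` and 1.0 for floating point dtypes):

```python
import torch


def _get_fills(*, num_channels, dtype, vector=True):
    # Simplified stand-in for the helper above.
    yield None

    max_value = 255 if dtype == torch.uint8 else 1.0
    yield max_value / 2  # float scalar fill
    yield max_value      # int scalar fill for integer dtypes

    if not vector:
        return

    if dtype.is_floating_point:
        yield [0.1 + c / 10 for c in range(num_channels)]
    else:
        yield [12.0 + c for c in range(num_channels)]


print(list(_get_fills(num_channels=3, dtype=torch.uint8)))
# [None, 127.5, 255, [12.0, 13.0, 14.0]]
print(list(_get_fills(num_channels=3, dtype=torch.float32, vector=False)))
# [None, 0.5, 1.0]
```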
```diff
@@ -409,10 +439,7 @@ def sample_inputs_affine_image_tensor():
         yield ArgsKwargs(image_loader, **affine_params)

     for image_loader in make_affine_image_loaders():
-        fills = [None, 0.5]
-        if image_loader.num_channels > 1:
-            fills.extend(vector_fill * image_loader.num_channels for vector_fill in [(0.5,), (1,), [0.5], [1]])
-        for fill in fills:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, **_full_affine_params(), fill=fill)

     for image_loader, interpolation in itertools.product(
@@ -631,7 +658,9 @@ def reference_inputs_convert_format_bounding_box():


 def sample_inputs_convert_color_space_image_tensor():
-    color_spaces = list(set(features.ColorSpace) - {features.ColorSpace.OTHER})
+    color_spaces = sorted(
+        set(features.ColorSpace) - {features.ColorSpace.OTHER}, key=lambda color_space: color_space.value
+    )

     for old_color_space, new_color_space in cycle_over(color_spaces):
         for image_loader in make_image_loaders(sizes=["random"], color_spaces=[old_color_space], constant_alpha=True):
```

> **Review comment** on lines +713 to +715 (the `sorted(...)` call): Drive-by that I found while debugging. Without this, individual failing tests are not reproducible because the order is not stable.
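A sketch of the reproducibility issue (hypothetical enum standing in for `features.ColorSpace`; member names and values are illustrative):

```python
import enum


class ColorSpace(enum.Enum):  # hypothetical stand-in
    OTHER = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4


# Enum members hash via their (string) names, and string hashing is randomized
# per interpreter run unless PYTHONHASHSEED is fixed, so this order can change
# from run to run -- and with it the pytest parametrization IDs:
unstable = list(set(ColorSpace) - {ColorSpace.OTHER})

# Sorting by value pins the order, so an individual failing test can be re-run:
stable = sorted(set(ColorSpace) - {ColorSpace.OTHER}, key=lambda cs: cs.value)
print(stable)  # always [GRAY, GRAY_ALPHA, RGB, RGB_ALPHA]
```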
```diff
@@ -678,7 +707,10 @@ def sample_inputs_convert_color_space_video():
         sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
         reference_fn=reference_convert_color_space_image_tensor,
         reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
-        closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
+        closeness_kwargs={
+            (("TestKernels", "test_against_reference"), torch.float32, "cpu"): dict(atol=2 / 255, rtol=0),
+            (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=1, rtol=0),
+        },
     ),
     KernelInfo(
         F.convert_color_space_video,
```
```diff
@@ -775,10 +807,7 @@ def sample_inputs_rotate_image_tensor():
         yield ArgsKwargs(image_loader, angle=15.0, center=center)

     for image_loader in make_rotate_image_loaders():
-        fills = [None, 0.5]
-        if image_loader.num_channels > 1:
-            fills.extend(vector_fill * image_loader.num_channels for vector_fill in [(0.5,), (1,), [0.5], [1]])
-        for fill in fills:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, angle=15.0, fill=fill)

     for image_loader, interpolation in itertools.product(
@@ -1062,10 +1091,7 @@ def sample_inputs_pad_image_tensor():
         yield ArgsKwargs(image_loader, padding=padding)

     for image_loader in make_pad_image_loaders():
-        fills = [None, 0.5]
-        if image_loader.num_channels > 1:
-            fills.extend(vector_fill * image_loader.num_channels for vector_fill in [(0.5,), (1,), [0.5], [1]])
-        for fill in fills:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, padding=[1], fill=fill)

     for image_loader, padding_mode in itertools.product(
```
```diff
@@ -1084,10 +1110,11 @@ def sample_inputs_pad_image_tensor():
 def reference_inputs_pad_image_tensor():
     for image_loader, params in itertools.product(make_image_loaders(extra_dims=[()]), _PAD_PARAMS):
         # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
-        fills = [None, 128.0, 128]
-        if params["padding_mode"] == "constant":
-            fills.append([12.0 + c for c in range(image_loader.num_channels)])
-        for fill in fills:
+        for fill in _get_fills(
+            num_channels=image_loader.num_channels,
+            dtype=image_loader.dtype,
+            vector=params["padding_mode"] == "constant",
+        ):
             yield ArgsKwargs(image_loader, fill=fill, **params)


@@ -1110,8 +1137,10 @@ def sample_inputs_pad_mask():


 def reference_inputs_pad_mask():
-    for image_loader, fill, params in itertools.product(make_image_loaders(extra_dims=[()]), [None, 127], _PAD_PARAMS):
-        yield ArgsKwargs(image_loader, fill=fill, **params)
+    for mask_loader, fill, params in itertools.product(
+        make_mask_loaders(num_objects=[1], extra_dims=[()]), [None, 127], _PAD_PARAMS
+    ):
+        yield ArgsKwargs(mask_loader, fill=fill, **params)


 def sample_inputs_pad_video():
@@ -1197,14 +1226,14 @@ def reference_inputs_pad_bounding_box():

 def sample_inputs_perspective_image_tensor():
     for image_loader in make_image_loaders(sizes=["random"]):
-        for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, None, None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0])


 def reference_inputs_perspective_image_tensor():
     for image_loader, coefficients in itertools.product(make_image_loaders(extra_dims=[()]), _PERSPECTIVE_COEFFS):
         # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
-        for fill in [None, 128.0, 128, [12.0 + c for c in range(image_loader.num_channels)]]:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, None, None, fill=fill, coefficients=coefficients)


@@ -1271,7 +1300,7 @@ def _get_elastic_displacement(spatial_size):
 def sample_inputs_elastic_image_tensor():
     for image_loader in make_image_loaders(sizes=["random"]):
         displacement = _get_elastic_displacement(image_loader.spatial_size)
-        for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)


@@ -1285,7 +1314,7 @@ def reference_inputs_elastic_image_tensor():
         ],
     ):
         displacement = _get_elastic_displacement(image_loader.spatial_size)
-        for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
+        for fill in _get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
             yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)
```
```diff
@@ -2070,6 +2099,17 @@ def sample_inputs_ten_crop_video():
         yield ArgsKwargs(video_loader, size=size)


+def multi_crop_pil_reference_wrapper(pil_kernel):
+    def wrapper(input_tensor, *other_args, **kwargs):
+        output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
+        return type(output)(
+            F.convert_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype)
+            for output_pil in output
+        )
+
+    return wrapper
+
+
 _common_five_ten_crop_marks = [
     xfail_jit_python_scalar_arg("size"),
     mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
```

> **Review comment** on `multi_crop_pil_reference_wrapper`: Small helper, since the regular `pil_reference_wrapper` only converts a single output image, while the five / ten crop kernels return several.
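The shape of that helper as a toy sketch (the doubling stands in for the per-element PIL → tensor → dtype conversion, and `toy_five_crop` stands in for kernels like `five_crop_image_pil` that return a tuple):

```python
def multi_output_wrapper(kernel):
    # Apply the same post-processing to every element and keep the container
    # type, mirroring `multi_crop_pil_reference_wrapper` above.
    def wrapper(x, *args, **kwargs):
        output = kernel(x, *args, **kwargs)
        return type(output)(o * 2 for o in output)  # toy "conversion"

    return wrapper


def toy_five_crop(x):
    # Multi-output kernel: returns a tuple, like `five_crop_image_pil`.
    return (x, x + 1, x + 2, x + 3, x + 4)


print(multi_output_wrapper(toy_five_crop)(10))  # (20, 22, 24, 26, 28)
```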
```diff
@@ -2080,7 +2120,7 @@ def sample_inputs_ten_crop_video():
     KernelInfo(
         F.five_crop_image_tensor,
         sample_inputs_fn=sample_inputs_five_crop_image_tensor,
-        reference_fn=pil_reference_wrapper(F.five_crop_image_pil),
+        reference_fn=multi_crop_pil_reference_wrapper(F.five_crop_image_pil),
         reference_inputs_fn=reference_inputs_five_crop_image_tensor,
         test_marks=_common_five_ten_crop_marks,
         closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
@@ -2093,7 +2133,7 @@ def sample_inputs_ten_crop_video():
     KernelInfo(
         F.ten_crop_image_tensor,
         sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
-        reference_fn=pil_reference_wrapper(F.ten_crop_image_pil),
+        reference_fn=multi_crop_pil_reference_wrapper(F.ten_crop_image_pil),
         reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
         test_marks=_common_five_ten_crop_marks,
         closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
```
```diff
@@ -1005,7 +1005,7 @@ def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):

         dp = (conv_fn(feature_image), feature_mask)
         dp_ref = (
-            to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image),
+            to_image_pil(feature_image) if supports_pil else feature_image.as_subclass(torch.Tensor),
             to_image_pil(feature_mask),
         )

@@ -1019,12 +1019,16 @@ def check(self, t, t_ref, data_kwargs=None):
         for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):

             self.set_seed()
-            output = t(dp)
+            actual = actual_image, actual_mask = t(dp)

             self.set_seed()
-            expected_output = t_ref(*dp_ref)
+            expected_image, expected_mask = t_ref(*dp_ref)
+            if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
+                expected_image = legacy_F.pil_to_tensor(expected_image)
+            expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
+            expected = (expected_image, expected_mask)

-            assert_equal(output, expected_output)
+            assert_equal(actual, expected)

     @pytest.mark.parametrize(
         ("t_ref", "t", "data_kwargs"),
```

> **Review comment** on lines +1030 to +1034 (the manual conversion): This was relying on the mixed PIL / tensor image comparison that was removed above. Thus, we do it manually here.
```diff
@@ -237,7 +237,6 @@ def test_against_reference(self, test_id, info, args_kwargs):
         assert_close(
             actual,
             expected,
-            check_dtype=False,
             **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
         )
```

> **Review comment** on the removed `check_dtype=False`: We can be strict here now, since we perform the conversion correctly.
> **Review comment** (overall): I've refactored this to not handle mixed PIL / tensor image pairs. It was a problem if the tolerance was set for floating point images, i.e. in the range [0.0, 1.0], but the comparison converted to uint8, which needs higher tolerances.
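A worked example of that scale mismatch (numbers chosen for illustration): a one-gray-level difference is roughly 0.004 on a float image in [0.0, 1.0] but exactly 1 after conversion to uint8, so a single `atol` cannot fit both representations — which is also why the uint8 tolerances above are scaled by 255.

```python
import torch

float_a = torch.tensor([0.4])
float_b = torch.tensor([0.4 + 1 / 255])  # one gray level apart

uint8_a = (float_a * 255).round().to(torch.uint8)  # tensor([102])
uint8_b = (float_b * 255).round().to(torch.uint8)  # tensor([103])

print((float_b - float_a).abs().item())              # ~0.0039
print((uint8_b.int() - uint8_a.int()).abs().item())  # 1

# An atol tuned for float images (e.g. 1e-5) can never be met after a uint8
# round trip, where the smallest nonzero difference is already 1.
```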