test/prototype_common_utils.py (69 changes: 19 additions & 50 deletions)
@@ -12,17 +12,9 @@
 import torch.testing
 from datasets_utils import combinations_grid
 from torch.nn.functional import one_hot
-from torch.testing._comparison import (
-    assert_equal as _assert_equal,
-    BooleanPair,
-    ErrorMeta,
-    NonePair,
-    NumberPair,
-    TensorLikePair,
-    UnsupportedInputs,
-)
+from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair
 from torchvision.prototype import features
-from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
+from torchvision.prototype.transforms.functional import to_image_tensor
 from torchvision.transforms.functional_tensor import _max_value as get_max_value

 __all__ = [
@@ -54,7 +46,7 @@
 ]


-class PILImagePair(TensorLikePair):
+class ImagePair(TensorLikePair):
Comment from the collaborator author: I've refactored this to not handle mixed PIL / tensor image pairs. Mixed handling was a problem if the tolerance was set for floating-point images, i.e. for values in the range [0.0, 1.0], but the comparison converted them to uint8, which needs correspondingly higher tolerances.
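To make the tolerance problem concrete, here is a minimal sketch (the shapes and the atol value are illustrative, not taken from this PR): a difference that is comfortably within a tolerance chosen for the [0.0, 1.0] float range becomes a whole integer step once both sides are converted to uint8, so the unscaled tolerance fails.

import torch
import torch.testing

atol = 1e-2  # tolerance chosen for float images in the range [0.0, 1.0]

actual = torch.full((3, 4, 4), 0.500)
expected = torch.full((3, 4, 4), 0.505)

# In float space the difference of 0.005 is within atol, so this passes.
torch.testing.assert_close(actual, expected, atol=atol, rtol=0)

# After converting both images to uint8, the same difference becomes a whole
# integer step (0.005 * 255 ~ 1.3), while atol still means "0.01 uint8 steps",
# so the identical comparison now fails unless the tolerance is rescaled by 255.
actual_u8 = (actual * 255).round().to(torch.uint8)
expected_u8 = (expected * 255).round().to(torch.uint8)
try:
    torch.testing.assert_close(actual_u8, expected_u8, atol=atol, rtol=0)
except AssertionError:
    print("uint8 comparison failed with the float-range tolerance")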

     def __init__(
         self,
         actual,
@@ -64,44 +56,13 @@ def __init__(
         allowed_percentage_diff=None,
         **other_parameters,
     ):
-        if not any(isinstance(input, PIL.Image.Image) for input in (actual, expected)):
-            raise UnsupportedInputs()
-
-        # This parameter is ignored to enable checking PIL images to tensor images no on the CPU
-        other_parameters["check_device"] = False
+        if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
+            actual, expected = [to_image_tensor(input) for input in [actual, expected]]

         super().__init__(actual, expected, **other_parameters)
         self.agg_method = getattr(torch, agg_method) if isinstance(agg_method, str) else agg_method
         self.allowed_percentage_diff = allowed_percentage_diff

-    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
-        actual, expected = [
-            to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
-            for input in [actual, expected]
-        ]
-        # This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
-        # image to a tensor adds a singleton leading dimension.
-        # Although it looks like this belongs in `self._equalize_attributes`, it has to happen here.
-        # `self._equalize_attributes` is called after `super()._compare_attributes` and that has an unconditional
-        # shape check that will fail if we don't broadcast before.
-        try:
-            actual, expected = torch.broadcast_tensors(actual, expected)
-        except RuntimeError:
-            raise ErrorMeta(
-                AssertionError,
-                f"The image shapes are not broadcastable: {actual.shape} != {expected.shape}.",
-                id=id,
-            ) from None
-        return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)
-
-    def _equalize_attributes(self, actual, expected):
-        if actual.dtype != expected.dtype:
-            dtype = torch.promote_types(actual.dtype, expected.dtype)
-            actual = convert_dtype_image_tensor(actual, dtype)
-            expected = convert_dtype_image_tensor(expected, dtype)
-
-        return super()._equalize_attributes(actual, expected)

     def compare(self) -> None:
         actual, expected = self.actual, self.expected

@@ -111,16 +72,24 @@ def compare(self) -> None:
         abs_diff = torch.abs(actual - expected)

         if self.allowed_percentage_diff is not None:
-            percentage_diff = (abs_diff != 0).to(torch.float).mean()
+            percentage_diff = float((abs_diff.ne(0).to(torch.float64).mean()))
             if percentage_diff > self.allowed_percentage_diff:
-                self._make_error_meta(AssertionError, "percentage mismatch")
+                raise self._make_error_meta(
+                    AssertionError,
+                    f"{percentage_diff:.1%} elements differ, "
+                    f"but only {self.allowed_percentage_diff:.1%} is allowed",
+                )

         if self.agg_method is None:
             super()._compare_values(actual, expected)
         else:
-            err = self.agg_method(abs_diff.to(torch.float64))
-            if err > self.atol:
-                self._make_error_meta(AssertionError, "aggregated mismatch")
+            agg_abs_diff = float(self.agg_method(abs_diff.to(torch.float64)))
+            if agg_abs_diff > self.atol:
+                raise self._make_error_meta(
+                    AssertionError,
+                    f"The '{self.agg_method.__name__}' of the absolute difference is {agg_abs_diff}, "
+                    f"but only {self.atol} is allowed.",
+                )

Comment from the collaborator author: This was the source of the actual bug. Without the raise we just created an exception but never did anything with it. As a result, all tests that set agg_method in their closeness_kwargs passed without any value check.
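A stripped-down illustration of this bug class, using generic names rather than the torch.testing internals: a helper that merely constructs an exception object is a silent no-op unless the caller raises its result.

def make_error_meta(msg):
    # stand-in for self._make_error_meta, which builds and returns an exception
    return AssertionError(msg)

def buggy_check(err, atol):
    if err > atol:
        make_error_meta("aggregated mismatch")  # object created, never raised

def fixed_check(err, atol):
    if err > atol:
        raise make_error_meta(f"aggregated error {err} exceeds atol {atol}")

buggy_check(1.0, 0.1)  # returns normally, so a failing test still passes
fixed_check(1.0, 0.1)  # raises AssertionError, as the fixed diff above does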


 def assert_close(
@@ -148,7 +117,7 @@ def assert_close(
             NonePair,
             BooleanPair,
             NumberPair,
-            PILImagePair,
+            ImagePair,
             TensorLikePair,
         ),
         allow_subclasses=allow_subclasses,
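For context, a hypothetical test-side call (the call sites are outside this diff, and actual_image / expected_image are placeholders): the closeness_kwargs mentioned above are forwarded through assert_close, so agg_method, atol, and rtol end up on ImagePair, where the restored raise makes them effective.

# Hypothetical usage; actual_image / expected_image stand in for real test outputs.
closeness_kwargs = dict(agg_method="mean", atol=1e-2, rtol=0)
assert_close(actual_image, expected_image, **closeness_kwargs)
# With the fix, a mean absolute difference above 1e-2 now raises instead of
# passing silently.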