Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions test/prototype_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,21 +628,34 @@ def mark_framework_limitation(test_id, reason):


class InfoBase:
def __init__(self, *, id, test_marks=None, closeness_kwargs=None):
def __init__(
self,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change just moves the comments to the signature as it is done in all other related classes.

*,
# Identifier if the info that shows up the parametrization.
self.id = id
id,
# Test markers that will be (conditionally) applied to an `ArgsKwargs` parametrization.
# See the `TestMark` class for details
self.test_marks = test_marks or []
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
self.closeness_kwargs = closeness_kwargs or dict()
test_marks=None,
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`. Keys are a 3-tuple of `test_id` (see
# `TestMark`), the dtype, and the device.
closeness_kwargs=None,
):
self.id = id

self.test_marks = test_marks or []
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)

self.closeness_kwargs = closeness_kwargs or dict()

def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]

def get_closeness_kwargs(self, test_id, *, dtype, device):
if isinstance(device, torch.device):
device = device.type
return self.closeness_kwargs.get((test_id, dtype, device), dict())
83 changes: 42 additions & 41 deletions test/prototype_transforms_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,10 @@ def __init__(
self.reference_inputs_fn = reference_inputs_fn


DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
atol=1e-5,
rtol=0,
agg_method="mean",
)
DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS = {
(("TestKernels", "test_against_reference"), torch.float32, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
(("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
}
Comment on lines +64 to +67
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two lessons here:

  1. We only need the tolerances for the reference tests, but they were applied everywhere. There was no harm here, but this could have hidden bugs before.
  2. Some of our kernels like resize produce quite large differences for some pixels, which are hidden by aggregating them. Maybe we should review and use stricter "default" tolerances and just increase it for the few that need more.



def pil_reference_wrapper(pil_kernel):
Expand Down Expand Up @@ -176,7 +175,7 @@ def reference_inputs_flip_bounding_box():
sample_inputs_fn=sample_inputs_horizontal_flip_image_tensor,
reference_fn=pil_reference_wrapper(F.horizontal_flip_image_pil),
reference_inputs_fn=reference_inputs_horizontal_flip_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.horizontal_flip_bounding_box,
Expand Down Expand Up @@ -320,7 +319,7 @@ def reference_inputs_resize_bounding_box():
sample_inputs_fn=sample_inputs_resize_image_tensor,
reference_fn=reference_resize_image_tensor,
reference_inputs_fn=reference_inputs_resize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("size"),
],
Expand All @@ -339,7 +338,7 @@ def reference_inputs_resize_bounding_box():
sample_inputs_fn=sample_inputs_resize_mask,
reference_fn=reference_resize_mask,
reference_inputs_fn=reference_inputs_resize_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("size"),
],
Expand Down Expand Up @@ -556,7 +555,7 @@ def sample_inputs_affine_video():
sample_inputs_fn=sample_inputs_affine_image_tensor,
reference_fn=pil_reference_wrapper(F.affine_image_pil),
reference_inputs_fn=reference_inputs_affine_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("shear"),
xfail_jit_tuple_instead_of_list("fill"),
Expand All @@ -569,7 +568,9 @@ def sample_inputs_affine_video():
sample_inputs_fn=sample_inputs_affine_bounding_box,
reference_fn=reference_affine_bounding_box,
reference_inputs_fn=reference_inputs_affine_bounding_box,
closeness_kwargs=dict(atol=1, rtol=0),
closeness_kwargs={
(("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
},
test_marks=[
xfail_jit_python_scalar_arg("shear"),
],
Expand All @@ -579,7 +580,7 @@ def sample_inputs_affine_video():
sample_inputs_fn=sample_inputs_affine_mask,
reference_fn=reference_affine_mask,
reference_inputs_fn=reference_inputs_resize_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("shear"),
],
Expand Down Expand Up @@ -668,7 +669,7 @@ def sample_inputs_convert_color_space_video():
sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
reference_fn=reference_convert_color_space_image_tensor,
reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.convert_color_space_video,
Expand Down Expand Up @@ -729,7 +730,7 @@ def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
sample_inputs_fn=sample_inputs_vertical_flip_image_tensor,
reference_fn=pil_reference_wrapper(F.vertical_flip_image_pil),
reference_inputs_fn=reference_inputs_vertical_flip_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.vertical_flip_bounding_box,
Expand Down Expand Up @@ -820,7 +821,7 @@ def sample_inputs_rotate_video():
sample_inputs_fn=sample_inputs_rotate_image_tensor,
reference_fn=pil_reference_wrapper(F.rotate_image_pil),
reference_inputs_fn=reference_inputs_rotate_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
Expand All @@ -836,7 +837,7 @@ def sample_inputs_rotate_video():
sample_inputs_fn=sample_inputs_rotate_mask,
reference_fn=reference_rotate_mask,
reference_inputs_fn=reference_inputs_rotate_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.rotate_video,
Expand Down Expand Up @@ -918,7 +919,7 @@ def reference_inputs_crop_bounding_box():
sample_inputs_fn=sample_inputs_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.crop_image_pil),
reference_inputs_fn=reference_inputs_crop_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.crop_bounding_box,
Expand All @@ -931,7 +932,7 @@ def reference_inputs_crop_bounding_box():
sample_inputs_fn=sample_inputs_crop_mask,
reference_fn=pil_reference_wrapper(F.crop_image_pil),
reference_inputs_fn=reference_inputs_crop_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.crop_video,
Expand Down Expand Up @@ -1010,7 +1011,7 @@ def sample_inputs_resized_crop_video():
sample_inputs_fn=sample_inputs_resized_crop_image_tensor,
reference_fn=reference_resized_crop_image_tensor,
reference_inputs_fn=reference_inputs_resized_crop_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.resized_crop_bounding_box,
Expand All @@ -1021,7 +1022,7 @@ def sample_inputs_resized_crop_video():
sample_inputs_fn=sample_inputs_resized_crop_mask,
reference_fn=pil_reference_wrapper(F.resized_crop_image_pil),
reference_inputs_fn=reference_inputs_resized_crop_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.resized_crop_video,
Expand Down Expand Up @@ -1144,7 +1145,7 @@ def reference_inputs_pad_bounding_box():
sample_inputs_fn=sample_inputs_pad_image_tensor,
reference_fn=pil_reference_wrapper(F.pad_image_pil),
reference_inputs_fn=reference_inputs_pad_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"),
Expand All @@ -1166,7 +1167,7 @@ def reference_inputs_pad_bounding_box():
sample_inputs_fn=sample_inputs_pad_mask,
reference_fn=pil_reference_wrapper(F.pad_image_pil),
reference_inputs_fn=reference_inputs_pad_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.pad_video,
Expand Down Expand Up @@ -1225,7 +1226,7 @@ def sample_inputs_perspective_video():
sample_inputs_fn=sample_inputs_perspective_image_tensor,
reference_fn=pil_reference_wrapper(F.perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.perspective_bounding_box,
Expand All @@ -1236,7 +1237,7 @@ def sample_inputs_perspective_video():
sample_inputs_fn=sample_inputs_perspective_mask,
reference_fn=pil_reference_wrapper(F.perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.perspective_video,
Expand Down Expand Up @@ -1306,7 +1307,7 @@ def sample_inputs_elastic_video():
sample_inputs_fn=sample_inputs_elastic_image_tensor,
reference_fn=pil_reference_wrapper(F.elastic_image_pil),
reference_inputs_fn=reference_inputs_elastic_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.elastic_bounding_box,
Expand All @@ -1317,7 +1318,7 @@ def sample_inputs_elastic_video():
sample_inputs_fn=sample_inputs_elastic_mask,
reference_fn=pil_reference_wrapper(F.elastic_image_pil),
reference_inputs_fn=reference_inputs_elastic_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.elastic_video,
Expand Down Expand Up @@ -1387,7 +1388,7 @@ def sample_inputs_center_crop_video():
sample_inputs_fn=sample_inputs_center_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("output_size"),
],
Expand All @@ -1404,7 +1405,7 @@ def sample_inputs_center_crop_video():
sample_inputs_fn=sample_inputs_center_crop_mask,
reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("output_size"),
],
Expand Down Expand Up @@ -1441,7 +1442,7 @@ def sample_inputs_gaussian_blur_video():
KernelInfo(
F.gaussian_blur_image_tensor,
sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"),
Expand Down Expand Up @@ -1529,7 +1530,7 @@ def sample_inputs_equalize_video():
sample_inputs_fn=sample_inputs_equalize_image_tensor,
reference_fn=pil_reference_wrapper(F.equalize_image_pil),
reference_inputs_fn=reference_inputs_equalize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.equalize_video,
Expand Down Expand Up @@ -1566,7 +1567,7 @@ def sample_inputs_invert_video():
sample_inputs_fn=sample_inputs_invert_image_tensor,
reference_fn=pil_reference_wrapper(F.invert_image_pil),
reference_inputs_fn=reference_inputs_invert_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.invert_video,
Expand Down Expand Up @@ -1607,7 +1608,7 @@ def sample_inputs_posterize_video():
sample_inputs_fn=sample_inputs_posterize_image_tensor,
reference_fn=pil_reference_wrapper(F.posterize_image_pil),
reference_inputs_fn=reference_inputs_posterize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.posterize_video,
Expand Down Expand Up @@ -1651,7 +1652,7 @@ def sample_inputs_solarize_video():
sample_inputs_fn=sample_inputs_solarize_image_tensor,
reference_fn=pil_reference_wrapper(F.solarize_image_pil),
reference_inputs_fn=reference_inputs_solarize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.solarize_video,
Expand Down Expand Up @@ -1688,7 +1689,7 @@ def sample_inputs_autocontrast_video():
sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
reference_fn=pil_reference_wrapper(F.autocontrast_image_pil),
reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.autocontrast_video,
Expand Down Expand Up @@ -1729,7 +1730,7 @@ def sample_inputs_adjust_sharpness_video():
sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil),
reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.adjust_sharpness_video,
Expand Down Expand Up @@ -1800,7 +1801,7 @@ def sample_inputs_adjust_brightness_video():
sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil),
reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.adjust_brightness_video,
Expand Down Expand Up @@ -1841,7 +1842,7 @@ def sample_inputs_adjust_contrast_video():
sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil),
reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.adjust_contrast_video,
Expand Down Expand Up @@ -1886,7 +1887,7 @@ def sample_inputs_adjust_gamma_video():
sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil),
reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.adjust_gamma_video,
Expand Down Expand Up @@ -1927,7 +1928,7 @@ def sample_inputs_adjust_hue_video():
sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil),
reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.adjust_hue_video,
Expand Down Expand Up @@ -1967,7 +1968,7 @@ def sample_inputs_adjust_saturation_video():
sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil),
reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.adjust_saturation_video,
Expand Down Expand Up @@ -2061,7 +2062,7 @@ def sample_inputs_ten_crop_video():
reference_fn=pil_reference_wrapper(F.five_crop_image_pil),
reference_inputs_fn=reference_inputs_five_crop_image_tensor,
test_marks=_common_five_ten_crop_marks,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.five_crop_video,
Expand All @@ -2074,7 +2075,7 @@ def sample_inputs_ten_crop_video():
reference_fn=pil_reference_wrapper(F.ten_crop_image_pil),
reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
test_marks=_common_five_ten_crop_marks,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.ten_crop_video,
Expand Down
Loading