expose some prototype transforms utils #6989

Merged: 3 commits, Nov 30, 2022
6 changes: 3 additions & 3 deletions test/test_prototype_transforms.py
@@ -23,7 +23,7 @@
)
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features, transforms
from torchvision.prototype.transforms._utils import _isinstance
from torchvision.prototype.transforms.utils import check_type
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -1860,7 +1860,7 @@ def test_permute_dimensions(dims, inverse_dims):
value_type = type(value)
transformed_value = transformed_sample[key]

if _isinstance(value, (features.Image, features.is_simple_tensor, features.Video)):
if check_type(value, (features.Image, features.is_simple_tensor, features.Video)):
if transform.dims.get(value_type) is not None:
assert transformed_value.permute(inverse_dims[value_type]).equal(value)
assert type(transformed_value) == torch.Tensor
@@ -1893,7 +1893,7 @@ def test_transpose_dimensions(dims):
transformed_value = transformed_sample[key]

transposed_dims = transform.dims.get(value_type)
if _isinstance(value, (features.Image, features.is_simple_tensor, features.Video)):
if check_type(value, (features.Image, features.is_simple_tensor, features.Video)):
if transposed_dims is not None:
assert transformed_value.transpose(*transposed_dims).equal(value)
assert type(transformed_value) == torch.Tensor
2 changes: 1 addition & 1 deletion test/test_prototype_transforms_consistency.py
@@ -26,8 +26,8 @@
from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms import functional as prototype_F
from torchvision.prototype.transforms._utils import query_spatial_size
from torchvision.prototype.transforms.functional import to_image_pil
from torchvision.prototype.transforms.utils import query_spatial_size
from torchvision.transforms import functional as legacy_F

DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
2 changes: 1 addition & 1 deletion test/test_prototype_transforms_utils.py
@@ -6,8 +6,8 @@
from prototype_common_utils import make_bounding_box, make_detection_mask, make_image

from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_all, has_any
from torchvision.prototype.transforms.functional import to_image_pil
from torchvision.prototype.transforms.utils import has_all, has_any


IMAGE = make_image(color_space=features.ColorSpace.RGB)
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -1,6 +1,6 @@
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip

from . import functional # usort: skip
from . import functional, utils # usort: skip

from ._transform import Transform # usort: skip
from ._presets import StereoMatching # usort: skip
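With `utils` added to the top-level imports, the helper functions become part of the public prototype API instead of living in the private `_utils` module. A minimal sketch of the import surface after this change (the sample contents below are illustrative assumptions, not part of the PR):

```python
import torch

from torchvision.prototype import features
from torchvision.prototype.transforms import utils

# A flattened sample, as the transforms see it after tree_flatten.
sample = [features.Image(torch.rand(3, 32, 32)), "a plain string label"]

print(utils.has_any(sample, features.Image, features.Video))  # True
print(utils.query_spatial_size(sample))  # (32, 32)
```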
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_augment.py
@@ -11,7 +11,7 @@
from torchvision.prototype.transforms import functional as F, InterpolationMode

from ._transform import _RandomApplyTransform
from ._utils import has_any, query_chw, query_spatial_size
from .utils import has_any, query_chw, query_spatial_size


class RandomErasing(_RandomApplyTransform):
5 changes: 3 additions & 2 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -10,7 +10,8 @@
from torchvision.prototype.transforms.functional._meta import get_spatial_size
from torchvision.transforms import functional_tensor as _FT

from ._utils import _isinstance, _setup_fill_arg
from ._utils import _setup_fill_arg
from .utils import check_type


class _AutoAugmentBase(Transform):
@@ -38,7 +39,7 @@ def _flatten_and_extract_image_or_video(

image_or_videos = []
for idx, inpt in enumerate(flat_inputs):
if _isinstance(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)):
if check_type(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)):
image_or_videos.append((idx, inpt))
elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_color.py
@@ -7,7 +7,7 @@
from torchvision.prototype.transforms import functional as F, Transform

from ._transform import _RandomApplyTransform
from ._utils import query_chw
from .utils import query_chw


class ColorJitter(Transform):
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_deprecated.py
@@ -11,7 +11,7 @@
from typing_extensions import Literal

from ._transform import _RandomApplyTransform
from ._utils import query_chw
from .utils import query_chw


class ToTensor(Transform):
5 changes: 1 addition & 4 deletions torchvision/prototype/transforms/_geometry.py
@@ -21,11 +21,8 @@
_setup_fill_arg,
_setup_float_or_seq,
_setup_size,
has_all,
has_any,
query_bounding_box,
query_spatial_size,
)
from .utils import has_all, has_any, query_bounding_box, query_spatial_size


class RandomHorizontalFlip(_RandomApplyTransform):
3 changes: 2 additions & 1 deletion torchvision/prototype/transforms/_misc.py
@@ -7,7 +7,8 @@
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform

from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size, has_any, query_bounding_box
from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size
from .utils import has_any, query_bounding_box


class Identity(Transform):
8 changes: 3 additions & 5 deletions torchvision/prototype/transforms/_transform.py
@@ -5,7 +5,7 @@
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype.transforms._utils import _isinstance
from torchvision.prototype.transforms.utils import check_type
from torchvision.utils import _log_api_usage_once


@@ -36,8 +36,7 @@ def forward(self, *inputs: Any) -> Any:
params = self._get_params(flat_inputs)

flat_outputs = [
self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
for inpt in flat_inputs
self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs
]

return tree_unflatten(flat_outputs, spec)
@@ -80,8 +79,7 @@ def forward(self, *inputs: Any) -> Any:
params = self._get_params(flat_inputs)

flat_outputs = [
self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
for inpt in flat_inputs
self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs
]

return tree_unflatten(flat_outputs, spec)
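`check_type` is the public rename of the private `_isinstance` helper: each entry in the tuple is either a class, checked with `isinstance`, or a predicate callable, called on the object. That is what lets `forward` mix concrete feature classes with checks such as `features.is_simple_tensor` in `self._transformed_types`. A small hedged illustration, assuming the public import path introduced by this PR:

```python
import torch

from torchvision.prototype import features
from torchvision.prototype.transforms.utils import check_type

plain = torch.rand(3, 8, 8)                  # a "simple" tensor, not a feature subclass
image = features.Image(torch.rand(3, 8, 8))  # a feature subclass of torch.Tensor

# One tuple can mix a class and a predicate callable.
checks = (features.Image, features.is_simple_tensor)

print(check_type(image, checks))           # True: isinstance() match on features.Image
print(check_type(plain, checks))           # True: the is_simple_tensor predicate returns True
print(check_type("not a tensor", checks))  # False: neither entry matches
```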
68 changes: 1 addition & 67 deletions torchvision/prototype/transforms/_utils.py
@@ -1,15 +1,11 @@
import functools
import numbers
from collections import defaultdict
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Sequence, Type, TypeVar, Union

import PIL.Image

from torchvision._utils import sequence_to_str
from torchvision.prototype import features
from torchvision.prototype.features._feature import FillType, FillTypeJIT

from torchvision.prototype.transforms.functional._meta import get_dimensions, get_spatial_size
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401

from typing_extensions import Literal
@@ -100,65 +96,3 @@ def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")


def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox:
bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)]
if not bounding_boxes:
raise TypeError("No bounding box was found in the sample")
elif len(bounding_boxes) > 1:
raise ValueError("Found multiple bounding boxes in the sample")
return bounding_boxes.pop()


def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if isinstance(inpt, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(inpt)
}
if not chws:
raise TypeError("No image or video was found in the sample")
elif len(chws) > 1:
raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
c, h, w = chws.pop()
return c, h, w


def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]:
sizes = {
tuple(get_spatial_size(inpt))
for inpt in flat_inputs
if isinstance(inpt, (features.Image, PIL.Image.Image, features.Video, features.Mask, features.BoundingBox))
or features.is_simple_tensor(inpt)
}
if not sizes:
raise TypeError("No image, video, mask or bounding box was found in the sample")
elif len(sizes) > 1:
raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
h, w = sizes.pop()
return h, w


def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
for type_or_check in types_or_checks:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
return True
return False


def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
for inpt in flat_inputs:
if _isinstance(inpt, types_or_checks):
return True
return False


def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
for type_or_check in types_or_checks:
for inpt in flat_inputs:
if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt):
break
else:
return False
return True
69 changes: 69 additions & 0 deletions torchvision/prototype/transforms/utils.py
@@ -0,0 +1,69 @@
from typing import Any, Callable, List, Tuple, Type, Union

import PIL.Image

from torchvision._utils import sequence_to_str
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size


def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox:
bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)]
if not bounding_boxes:
raise TypeError("No bounding box was found in the sample")
elif len(bounding_boxes) > 1:
raise ValueError("Found multiple bounding boxes in the sample")
return bounding_boxes.pop()


def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if isinstance(inpt, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(inpt)
}
if not chws:
raise TypeError("No image or video was found in the sample")
elif len(chws) > 1:
raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
c, h, w = chws.pop()
return c, h, w


def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]:
sizes = {
tuple(get_spatial_size(inpt))
for inpt in flat_inputs
if isinstance(inpt, (features.Image, PIL.Image.Image, features.Video, features.Mask, features.BoundingBox))
or features.is_simple_tensor(inpt)
}
if not sizes:
raise TypeError("No image, video, mask or bounding box was found in the sample")
elif len(sizes) > 1:
raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
h, w = sizes.pop()
return h, w


def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
for type_or_check in types_or_checks:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
return True
return False


def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
for inpt in flat_inputs:
if check_type(inpt, types_or_checks):
return True
return False


def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
for type_or_check in types_or_checks:
for inpt in flat_inputs:
if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt):
break
else:
return False
return True
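Since the helpers are now public, downstream code can build custom transforms on the same utilities the built-in prototype transforms use. A hedged sketch of a hypothetical transform (`RandomNoise` is invented for illustration; the `_get_params`/`_transform` hooks follow the `Transform.forward` implementation shown above):

```python
from typing import Any, Dict, List

import torch

from torchvision.prototype import features
from torchvision.prototype.transforms import Transform
from torchvision.prototype.transforms.utils import has_any, query_spatial_size


class RandomNoise(Transform):
    """Hypothetical transform sketch built on the newly exposed helpers."""

    # Only tensor-like images and videos are transformed; everything else passes through.
    _transformed_types = (features.Image, features.is_simple_tensor, features.Video)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        if not has_any(flat_inputs, features.Image, features.is_simple_tensor, features.Video):
            raise TypeError(f"{type(self).__name__}() requires an image or video in the sample")
        height, width = query_spatial_size(flat_inputs)
        # One noise map shared by all images/videos in the sample.
        return dict(noise=0.1 * torch.randn(1, height, width))

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Note: depending on the feature wrapping rules, the result may be a plain tensor.
        return inpt + params["noise"]
```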