[WIP] 5522 random crop port #5555

Open · wants to merge 5 commits into main
1 change: 1 addition & 0 deletions test/test_prototype_transforms.py
@@ -69,6 +69,7 @@ class TestSmoke:
        transforms.HorizontalFlip(),
        transforms.Resize([16, 16]),
        transforms.CenterCrop([16, 16]),
        transforms.RandomCrop([16, 16], pad_if_needed=True),
        transforms.ConvertImageDtype(),
    )
    def test_common(self, transform, input):
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -7,7 +7,7 @@
from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop, RandomCrop
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._presets import (
95 changes: 94 additions & 1 deletion torchvision/prototype/transforms/_geometry.py
@@ -1,7 +1,7 @@
import collections.abc
import math
import warnings
from typing import Any, Dict, List, Union, Sequence, Tuple, cast
from typing import Any, Dict, List, Union, Sequence, Tuple, cast, Literal, Optional

import PIL.Image
import torch
@@ -256,3 +256,96 @@ def apply_recursively(obj: Any) -> Any:
            return obj

    return apply_recursively(inputs if len(inputs) > 1 else inputs[0])


class RandomCrop(Transform):
    def __init__(
        self,
        size: Union[int, Sequence[int]],
        padding: Optional[Sequence[int]] = None,
        pad_if_needed: bool = False,
        fill: Union[int, str, Sequence[int]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def _get_crop_parameters(self, image: Any) -> Dict[str, Any]:
        """Get parameters for ``crop`` for a random crop.

        Args:
            image (PIL Image, Tensor or features.Image): Image to be cropped.

        Returns:
            dict: Dict containing 'top', 'left', 'height', and 'width'.
        """
        _, h, w = get_image_dimensions(image)
        th, tw = self.size

        if h + 1 < th or w + 1 < tw:
            raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")

        if w == tw and h == th:
            return dict(top=0, left=0, height=h, width=w)

        # Sample the top-left corner of the crop uniformly over all valid positions.
        i = torch.randint(0, h - th + 1, size=(1,)).item()
        j = torch.randint(0, w - tw + 1, size=(1,)).item()
        return dict(top=i, left=j, height=th, width=tw)

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        # Apply the static padding and the pad_if_needed padding up front, so
        # that the crop parameters below are computed on the padded image.
        if isinstance(input, features.Image):
            output = F.random_pad_image_tensor(
                input,
                output_size=self.size,
                image_size=get_image_dimensions(input),
                padding=self.padding,
                pad_if_needed=self.pad_if_needed,
                fill=self.fill,
                padding_mode=self.padding_mode,
            )
            input = features.Image.new_like(input, output)
        elif isinstance(input, PIL.Image.Image):
            input = F.random_pad_image_pil(
                input,
                output_size=self.size,
                image_size=get_image_dimensions(input),
                padding=self.padding,
                pad_if_needed=self.pad_if_needed,
                fill=self.fill,
                padding_mode=self.padding_mode,
            )
        elif is_simple_tensor(input):
            input = F.random_pad_image_tensor(
                input,
                output_size=self.size,
                image_size=get_image_dimensions(input),
                padding=self.padding,
                pad_if_needed=self.pad_if_needed,
                fill=self.fill,  # TODO: should be converted to number
                padding_mode=self.padding_mode,
            )

        params.update(self._get_crop_parameters(input))

        if isinstance(input, features.Image):
            output = F.crop_image_tensor(input, **params)
            return features.Image.new_like(input, output)
        elif isinstance(input, PIL.Image.Image):
            return F.crop_image_pil(input, **params)
        elif is_simple_tensor(input):
            return F.crop_image_tensor(input, **params)
        else:
            return input

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        if has_any(sample, features.BoundingBox, features.SegmentationMask):
            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")

        return super().forward(sample)
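For reference, the new transform is exercised like the other prototype transforms. A minimal sketch of the intended behavior (the input shape and the printed result are illustrative assumptions, not part of this diff):

import torch
from torchvision.prototype import transforms

# Hypothetical usage: pad_if_needed guarantees the 16x16 crop even for
# inputs smaller than the crop size; a plain tensor is routed through the
# *_image_tensor kernels.
transform = transforms.RandomCrop([16, 16], pad_if_needed=True)
image = torch.randint(0, 256, (3, 12, 20), dtype=torch.uint8)
output = transform(image)
print(output.shape)  # torch.Size([3, 16, 16])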
2 changes: 2 additions & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -48,6 +48,8 @@
    center_crop_image_pil,
    resized_crop_image_tensor,
    resized_crop_image_pil,
    random_pad_image_tensor,
    random_pad_image_pil,
    affine_image_tensor,
    affine_image_pil,
    rotate_image_tensor,
112 changes: 111 additions & 1 deletion torchvision/prototype/transforms/functional/_geometry.py
@@ -1,5 +1,5 @@
import numbers
from typing import Tuple, List, Optional, Sequence, Union
from typing import Tuple, List, Optional, Sequence, Union, Literal

import PIL.Image
import torch
@@ -390,3 +390,113 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool
    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size)

    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]


def random_crop_image_tensor(
    img: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    padding: Optional[List[int]] = None,
    pad_if_needed: bool = False,
    fill: int = 0,
    padding_mode: str = "constant",
) -> torch.Tensor:
    if padding is not None:
        img = pad_image_tensor(img, padding, fill, padding_mode)

    _, h, w = get_dimensions_image_tensor(img)

    # pad the width if needed
    if pad_if_needed and w < size[1]:
        padding = [size[1] - w, 0]
        img = pad_image_tensor(img, padding, fill, padding_mode)

    # pad the height if needed
    if pad_if_needed and h < size[0]:
        padding = [0, size[0] - h]
        img = pad_image_tensor(img, padding, fill, padding_mode)

    return crop_image_tensor(img, top, left, height, width)


def random_crop_image_pil(
    img: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    padding: Optional[List[int]] = None,
    pad_if_needed: bool = False,
    fill: int = 0,
    padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> PIL.Image.Image:
    if padding is not None:
        img = pad_image_pil(img, padding, fill, padding_mode)

    _, h, w = get_dimensions_image_pil(img)

    # pad the width if needed
    if pad_if_needed and w < size[1]:
        padding = [size[1] - w, 0]
        img = pad_image_pil(img, padding, fill, padding_mode)

    # pad the height if needed
    if pad_if_needed and h < size[0]:
        padding = [0, size[0] - h]
        img = pad_image_pil(img, padding, fill, padding_mode)

    return crop_image_pil(img, top, left, height, width)
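Note that these two kernels take pre-sampled crop coordinates and only compose the padding with the final crop; the randomness stays in the transform. A hypothetical call into the tensor kernel (shapes and coordinates chosen arbitrarily for illustration):

import torch

# The caller supplies the crop coordinates (here simply (0, 0)); the kernel
# pads the 3x8x8 input up to 24x24 via pad_if_needed, then takes the fixed
# 16x16 crop.
img = torch.rand(3, 8, 8)
out = random_crop_image_tensor(
    img, top=0, left=0, height=16, width=16, size=[16, 16], pad_if_needed=True
)
print(out.shape)  # torch.Size([3, 16, 16])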


def random_pad_image_tensor(
    img: torch.Tensor,
    output_size: List[int],
    image_size: Tuple[int, int, int],
    padding: Optional[Sequence[int]] = None,
    pad_if_needed: bool = False,
    fill: int = 0,
    padding_mode: str = "constant",
) -> torch.Tensor:
    _, height, width = image_size

    if padding is not None:
        img = pad_image_tensor(img, padding, fill, padding_mode)
    # pad the width if needed
    if pad_if_needed and width < output_size[1]:
        padding = [output_size[1] - width, 0]
        img = pad_image_tensor(img, padding, fill, padding_mode)
    # pad the height if needed
    if pad_if_needed and height < output_size[0]:
        padding = [0, output_size[0] - height]
        img = pad_image_tensor(img, padding, fill, padding_mode)
    return img
pmeier (Collaborator):

Why do we need this? Shouldn't pad_image_tensor be able to handle this? In general, we don't have kernels for random functions. All randomness should be handled in the transform.

Author (Contributor):

Hi @pmeier, sorry, I am a bit new to this repo. The reasons I created a new function are:

  1. pad_image_tensor does not have the same logic as this function: this one takes the required output shape and the current image shape into account. I wasn't sure if that logic should live in forward.
  2. Looking at the current code, I felt that transform-related code should be in _geometry and that there should be separate functions for tensor and PIL inputs.
  3. _get_params in RandomCrop requires the output of this function, so adding this logic inside _transform wouldn't work, as the params in _transform would not be valid.

Please share any other approach you have in mind; I can incorporate those changes.

pmeier (Collaborator):

Sorry, I misjudged the situation. I was not aware that the forward actually modifies the image:

if self.padding is not None:
    img = F.pad(img, self.padding, self.fill, self.padding_mode)

_, height, width = F.get_dimensions(img)
# pad the width if needed
if self.pad_if_needed and width < self.size[1]:
    padding = [self.size[1] - width, 0]
    img = F.pad(img, padding, self.fill, self.padding_mode)
# pad the height if needed
if self.pad_if_needed and height < self.size[0]:
    padding = [0, self.size[0] - height]
    img = F.pad(img, padding, self.fill, self.padding_mode)

This makes things more complicated. cc @datumbox for awareness.

I would move this code into _transform. Although the structure is the same for all possible types, we still need to call different pad kernels. That would be a lot easier if we had the Pad transform from #5521 first. This way we could simply substitute pad_image_*(...) with pad(...) where pad is

pad = functools.partial(
    lambda image, padding: Pad(
        padding,
        fill=self.fill,
        padding_mode=self.padding_mode,
    )(image)
)

and not worry about the dispatch. Thoughts?

Author (Contributor):

Hi @pmeier, I can keep this PR on hold and work on #5521 first, if it helps.
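To make the reviewer's suggestion concrete, here is a minimal sketch of what _transform could look like once a type-dispatching Pad transform from #5521 exists. Pad's exact signature is an assumption, and the snippet only makes sense inside the RandomCrop class alongside the helpers this PR already uses (get_image_dimensions, is_simple_tensor, the crop kernels):

import functools
from typing import Any, Dict

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
    # One padding helper instead of three per-type pad kernels; Pad is
    # assumed to dispatch over tensors, PIL images, and features.Image.
    pad = functools.partial(
        lambda image, padding: Pad(padding, fill=self.fill, padding_mode=self.padding_mode)(image)
    )

    if self.padding is not None:
        input = pad(input, self.padding)

    _, height, width = get_image_dimensions(input)
    # pad the width, then the height, if the image is smaller than the crop
    if self.pad_if_needed and width < self.size[1]:
        input = pad(input, [self.size[1] - width, 0])
    if self.pad_if_needed and height < self.size[0]:
        input = pad(input, [0, self.size[0] - height])

    # With the padding done, the crop parameters are valid for every type.
    params.update(self._get_crop_parameters(input))
    if isinstance(input, features.Image):
        return features.Image.new_like(input, F.crop_image_tensor(input, **params))
    elif isinstance(input, PIL.Image.Image):
        return F.crop_image_pil(input, **params)
    elif is_simple_tensor(input):
        return F.crop_image_tensor(input, **params)
    return input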



def random_pad_image_pil(
    img: PIL.Image.Image,
    output_size: List[int],
    image_size: Tuple[int, int, int],
    padding: Optional[Sequence[int]] = None,
    pad_if_needed: bool = False,
    fill: Union[int, str, Sequence[int]] = 0,
    padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> PIL.Image.Image:
    _, height, width = image_size

    if padding is not None:
        img = pad_image_pil(img, padding, fill, padding_mode)
    # pad the width if needed
    if pad_if_needed and width < output_size[1]:
        padding = [output_size[1] - width, 0]
        img = pad_image_pil(img, padding, fill, padding_mode)
    # pad the height if needed
    if pad_if_needed and height < output_size[0]:
        padding = [0, output_size[0] - height]
        img = pad_image_pil(img, padding, fill, padding_mode)
    return img