
remove unnecessary checks from pad_image_tensor #6894

Merged: 8 commits merged into pytorch:main on Nov 3, 2022
Conversation

pmeier (Collaborator) commented on Nov 2, 2022:

Closes #6882. As discussed offline, since our implementation already only converts the dtype for non-constant padding:

if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
    # Here we temporarily cast the input tensor to float
    # until the pytorch issue is resolved:
    # https://github.com/pytorch/pytorch/issues/40763
    need_cast = True
    img = img.to(torch.float32)
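
For context, the cast above is paired with a cast back to the original dtype after padding. A minimal sketch of the full pattern, assuming torch-style padding mode names and the variable names from the snippet (illustrative, not the exact kernel):

from typing import List

import torch
import torch.nn.functional as F

def _pad_non_constant(img: torch.Tensor, torch_padding: List[int], padding_mode: str) -> torch.Tensor:
    # Non-constant modes in PyTorch core only support floating point inputs
    # (https://github.com/pytorch/pytorch/issues/40763), so everything else
    # is temporarily cast to float32 and the dtype is restored afterwards.
    need_cast = img.dtype not in (torch.float32, torch.float64)
    out_dtype = img.dtype
    if need_cast:
        img = img.to(torch.float32)

    img = F.pad(img, torch_padding, mode=padding_mode)

    if need_cast:
        img = img.to(out_dtype)
    return img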

there is not much performance to gain on our side. As explained in #6818 (comment), the two possible optimizations concern edge cases and would have to happen in PyTorch core, and thus would affect v1 as well as v2.

Thus, this PR is mostly refactoring.

[--------------------------------- pad_image_tensor refactor ----------------------------------]
                                      |     v2 (main)     |   v2 (perf-pad)   |         v1      
1 threads: -------------------------------------------------------------------------------------
      (3, 512, 512), uint8, cpu       |   308 (+-  9) us  |   309 (+-  8) us  |   305 (+-  2) us
      (3, 512, 512), float32, cpu     |   115 (+-  1) us  |   117 (+-  2) us  |   118 (+-  2) us
      (5, 3, 512, 512), uint8, cpu    |  1469 (+- 38) us  |  1476 (+-  7) us  |  1475 (+-  9) us
      (5, 3, 512, 512), float32, cpu  |   910 (+- 89) us  |   893 (+- 64) us  |   919 (+- 60) us

Times are in microseconds (us).
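
For reference, tables like this are typically produced with torch.utils.benchmark. A minimal sketch for the first row (the padding arguments and the use of plain torch_pad are assumptions, not the actual benchmark script):

import torch
from torch.utils import benchmark

# One configuration from the table above; the other rows vary shape and dtype.
image = torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8)

timer = benchmark.Timer(
    stmt="F.pad(image, [1, 1, 1, 1], mode='constant', value=0)",
    setup="import torch.nn.functional as F",
    globals={"image": image},
    num_threads=1,
)
print(timer.blocked_autorange())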

cc @vfdev-5 @datumbox @bjuncek

@@ -645,7 +645,32 @@ def rotate(
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


-pad_image_pil = _FP.pad
+def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
pmeier (Collaborator, Author) commented:
I've merged the checks for invalid types as well as wrong lengths here, since this function is also used by pad_bounding_box, which currently doesn't have these checks.
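
Roughly, such a merged helper can be sketched as follows (based on the description above, not the exact diff): it validates the type and length of the padding spec once and returns the values in the [left, right, top, bottom] order that torch_pad expects.

from typing import List, Union

def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    # Validate here so that every caller, including pad_bounding_box,
    # gets the same checks.
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif isinstance(padding, (list, tuple)):
        if len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        elif len(padding) == 4:
            pad_left, pad_top, pad_right, pad_bottom = padding
        else:
            raise ValueError("Padding must be an int or a 1, 2, or 4 element sequence")
    else:
        raise TypeError(f"padding should be an int or a sequence, got {type(padding)}")

    # torch.nn.functional.pad takes [left, right, top, bottom] for 2D padding.
    return [pad_left, pad_right, pad_top, pad_bottom]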

new_height = height + top + bottom
new_width = width + left + right

return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
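
The snippet above is the tail of the usual v2 kernel pattern: flatten arbitrary leading batch dimensions into a single one, call torch_pad on a 4D tensor, and then restore the original leading dimensions. A rough sketch for the constant-padding case (illustrative):

from typing import List

import torch
import torch.nn.functional as F

def _pad_constant(image: torch.Tensor, torch_padding: List[int], fill: float) -> torch.Tensor:
    left, right, top, bottom = torch_padding
    shape = image.shape
    num_channels, height, width = shape[-3:]

    # Collapse all leading batch dims so torch_pad always sees (N, C, H, W).
    image = image.reshape(-1, num_channels, height, width)
    image = F.pad(image, torch_padding, mode="constant", value=fill)

    new_height = height + top + bottom
    new_width = width + left + right
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))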


-# TODO: This should be removed once pytorch pad supports non-scalar padding values
+# TODO: This should be removed once torch_pad supports non-scalar padding values
pmeier (Collaborator, Author) commented:

@vfdev-5 Is there an issue for that?

@@ -711,6 +771,9 @@ def _pad_with_vector_fill(
return output


+pad_image_pil = _FP.pad
pmeier (Collaborator, Author) commented:
Minor cleanup, since we normally define the PIL kernel below the tensor one.

elif isinstance(fill, (int, float)) or len(fill) == 1:
    fill_number = fill[0] if isinstance(fill, list) else fill
    return _pad_with_scalar_fill(image, padding, fill=fill_number, padding_mode=padding_mode)
return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode=padding_mode)
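
For context, the vector-fill path mentioned in the TODO above works around torch_pad only accepting scalar fill values: pad with a placeholder scalar first, then overwrite the padded border per channel. A rough sketch (the function name and details are illustrative):

from typing import List

import torch
import torch.nn.functional as F

def _pad_with_vector_fill_sketch(image: torch.Tensor, torch_padding: List[int], fill: List[float]) -> torch.Tensor:
    # torch_pad only supports a scalar fill, so pad with zeros first ...
    left, right, top, bottom = torch_padding
    output = F.pad(image, torch_padding, mode="constant", value=0)

    # ... then overwrite the padded border with the per-channel fill value.
    fill_t = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)
    if top > 0:
        output[..., :top, :] = fill_t
    if bottom > 0:
        output[..., -bottom:, :] = fill_t
    if left > 0:
        output[..., :, :left] = fill_t
    if right > 0:
        output[..., :, -right:] = fill_t
    return output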
A collaborator commented:
Why fill=0 instead of fill=None as originally? Do we still need this workaround?

pmeier (Collaborator, Author) replied:

Previously, this was handled by _FT.pad:

if fill is None:
    fill = 0

Since we are no longer calling it, we need to handle it on our own. For some reason, the fill overwrite above did not work for me while developing, so I went this way. Rechecking today, it seems to work; the failure was probably caused by something else. I've fixed this in 1622cd6.

More generally though, why are we allowing None in the first place, and even using it as the default, if all we do with it is map it to 0? I faintly remember there were discussions about it, but I think I was out of the loop on them. Is there a public discussion you can point me to, or did this happen offline? Intuitively, I would remove None as a valid value and just use 0 as the default.

The collaborator replied:

> More generally though, why are we allowing None in the first place, and even using it as the default, if all we do with it is map it to 0?

The fill=None behavior originates from your PR #1760 on the rotate op. See #6623 for later discussions.

datumbox (Contributor) left a comment:

LGTM on green. Just a question:

(Two review threads on torchvision/prototype/transforms/functional/_geometry.py, outdated and resolved.)
datumbox (Contributor) left a comment:

Still LGTM on green CI.

@pmeier pmeier merged commit 4d085f2 into pytorch:main Nov 3, 2022
@pmeier pmeier deleted the perf/pad branch November 3, 2022 15:54
facebook-github-bot pushed a commit that referenced this pull request Nov 4, 2022
Summary:
* remove unnecessary changes from pad_image_tensor

* cleanup

* fix fill=None workaround

* address review comments

* remove more xfails

Reviewed By: datumbox

Differential Revision: D41020544

fbshipit-source-id: d677ea0dd79f8e8055ed7c36a65a0bb980e3b578