
[prototype] Add support of inplace on convert_format_bounding_box #6858


Merged (12 commits) on Oct 31, 2022

Conversation

@datumbox (Contributor) commented Oct 28, 2022

Related to #6818

The current implementations for converting coordinates in convert_format_bounding_box produce lots of intermediate results. We copy a lot and often have to cast from ints to floats and back, which leads to slow kernels. The kernels can be optimized by avoiding unnecessary casts and implementing safe in-place operations on the intermediate tensors.

Moreover, convert_format_bounding_box is often used internally in multiple kernels to bring the bboxes into the correct format. To avoid modifying the data in place, the transform always clones. This is wasteful and can be avoided by adding an inplace option to the specific hot kernel. The plan of this PR is not to expose inplace to the Transform classes, where it can be dangerous; we only use it in selected implementations where it is safe.

This PR:

  • Rewrites the kernels to speed them up via safe in-place ops on the intermediate results (similar to what we do in other operators).
  • Adds an inplace flag to convert_format_bounding_box so that the entire operation can be performed in place.
  • Sets inplace=True in all usages where this is safe to do.
  • Adds a couple of as_subclass calls that were missing for BBoxes.

Overall the new changes improve the speed by 30-50%, depending on the conversion, input type and in-place configuration.
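
As a rough illustration of the inplace semantics on the low-level kernels (a sketch based on the description above, not the exact public dispatcher; the private helper names come from the verification script further down):

import torch
from torchvision.prototype.transforms.functional._meta import _xyxy_to_cxcywh

xyxy = torch.tensor([[40.0, 35.0, 60.0, 45.0]])

# inplace=False: the kernel works on a copy, so 'xyxy' stays intact.
cxcywh = _xyxy_to_cxcywh(xyxy, False)

# inplace=True: the kernel may rewrite the input buffer directly, skipping the clone.
# Only safe when the caller owns the tensor, which is why the flag is not exposed
# on the Transform classes.
cxcywh = _xyxy_to_cxcywh(xyxy.clone(), True)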

cc @vfdev-5 @bjuncek @pmeier

@vfdev-5 (Collaborator) left a comment

Looks OK to me, thanks @datumbox

@datumbox datumbox marked this pull request as draft October 28, 2022 12:49
@datumbox datumbox requested a review from vfdev-5 October 28, 2022 12:49
@datumbox (Contributor, Author) commented Oct 28, 2022

The new implementation seems to produce identical results to the normal ops of TorchVision:

import torch
from torchvision.ops._box_convert import _box_xyxy_to_cxcywh, _box_cxcywh_to_xyxy
from torchvision.prototype.transforms.functional._meta import _xyxy_to_cxcywh, _cxcywh_to_xyxy


def make_data(device, n=10, scale=200):
    x1y1 = torch.randn((n, 2)).mul(scale).abs() + 1
    wh = torch.randn((n, 2)).mul(scale).abs()
    x2y2 = x1y1 + wh

    xyxy = torch.cat([x1y1, x2y2], dim=-1)

    cxcy = (x2y2 + x1y1) / 2
    cxcywh = torch.cat([cxcy, wh], dim=-1)
    return xyxy.to(device), cxcywh.to(device)


torch.manual_seed(0)
for i in range(1000):
    _xyxy, _cxcywh = make_data("cpu")
    for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]:
        xyxy = _xyxy.to(dtype)
        cxcywh = _cxcywh.to(dtype)
        try:
            torch.testing.assert_close(_box_xyxy_to_cxcywh(xyxy).to(dtype), _xyxy_to_cxcywh(xyxy, False), msg=f"Fail: xyxy_to_cxcywh - {i} - {dtype}")
            torch.testing.assert_close(_box_cxcywh_to_xyxy(cxcywh).to(dtype), _cxcywh_to_xyxy(cxcywh, False), msg=f"Fail: cxcywh_to_xyxy - {i} - {dtype}")
        except AssertionError as e:
            print(e)
            exit()

print("OK")

@datumbox (Contributor, Author) commented Oct 28, 2022

The new implementation is faster than the existing one (here I'm using torchvision.ops equivalents which are identical to our previous kernels):

[------------------ xyxy_to_cxcywh cpu -----------------]
                     |       old        |       new      
1 threads: ----------------------------------------------
      torch.float32  |   29 (+-  0) us  |   24 (+-  0) us
      torch.float64  |   29 (+-  0) us  |   24 (+-  0) us
      torch.int32    |   41 (+-  0) us  |   24 (+-  0) us
      torch.int64    |   41 (+-  0) us  |   20 (+-  0) us
6 threads: ----------------------------------------------
      torch.float32  |   29 (+-  0) us  |   24 (+-  0) us
      torch.float64  |   29 (+-  0) us  |   24 (+-  0) us
      torch.int32    |   41 (+-  0) us  |   24 (+-  0) us
      torch.int64    |   41 (+-  0) us  |   20 (+-  0) us

Times are in microseconds (us).

[----------------- xyxy_to_cxcywh cuda -----------------]
                     |       old        |       new      
1 threads: ----------------------------------------------
      torch.float32  |   60 (+-  0) us  |   47 (+-  0) us
      torch.float64  |   60 (+-  0) us  |   47 (+-  0) us
      torch.int32    |   87 (+-  1) us  |   47 (+-  0) us
      torch.int64    |   86 (+-  0) us  |   47 (+-  0) us
6 threads: ----------------------------------------------
      torch.float32  |   60 (+-  1) us  |   47 (+-  1) us
      torch.float64  |   60 (+-  1) us  |   47 (+-  0) us
      torch.int32    |   87 (+-  2) us  |   47 (+-  0) us
      torch.int64    |   87 (+-  2) us  |   47 (+-  0) us

Times are in microseconds (us).

[------------------ cxcywh_to_xyxy cpu -----------------]
                     |       old        |       new      
1 threads: ----------------------------------------------
      torch.float32  |   38 (+-  0) us  |   21 (+-  0) us
      torch.float64  |   30 (+-  0) us  |   21 (+-  0) us
      torch.int32    |   54 (+-  0) us  |   21 (+-  0) us
      torch.int64    |   54 (+-  0) us  |   19 (+-  0) us
6 threads: ----------------------------------------------
      torch.float32  |   38 (+-  0) us  |   21 (+-  0) us
      torch.float64  |   30 (+-  0) us  |   21 (+-  0) us
      torch.int32    |   54 (+-  1) us  |   21 (+-  0) us
      torch.int64    |   54 (+-  0) us  |   19 (+-  0) us

Times are in microseconds (us).

[----------------- cxcywh_to_xyxy cuda -----------------]
                     |       old        |       new      
1 threads: ----------------------------------------------
      torch.float32  |   77 (+-  0) us  |   47 (+-  0) us
      torch.float64  |   76 (+-  0) us  |   47 (+-  0) us
      torch.int32    |   87 (+-  1) us  |   47 (+-  0) us
      torch.int64    |   86 (+-  1) us  |   47 (+-  0) us
6 threads: ----------------------------------------------
      torch.float32  |   77 (+-  2) us  |   47 (+-  0) us
      torch.float64  |   76 (+-  2) us  |   47 (+-  1) us
      torch.int32    |   87 (+-  2) us  |   47 (+-  1) us
      torch.int64    |   86 (+-  2) us  |   47 (+-  0) us

Times are in microseconds (us).

Note that these benchmarks don't pass inplace=True to the kernels; that configuration is harder to benchmark because the input data gets modified. Obviously, by also avoiding the clone we are even faster.
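
For reference, a rough sketch of how numbers like the above can be collected with torch.utils.benchmark (not the exact script used for these tables):

import torch
import torch.utils.benchmark as benchmark
from torchvision.prototype.transforms.functional._meta import _xyxy_to_cxcywh

results = []
for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]:
    # Box validity doesn't matter for timing purposes.
    xyxy = (torch.rand(1000, 4) * 200).to(dtype)
    for num_threads in [1, 6]:
        results.append(
            benchmark.Timer(
                stmt="_xyxy_to_cxcywh(xyxy, False)",
                globals={"_xyxy_to_cxcywh": _xyxy_to_cxcywh, "xyxy": xyxy},
                num_threads=num_threads,
                label="xyxy_to_cxcywh cpu",
                sub_label=str(dtype),
                description="new",
            ).blocked_autorange(min_run_time=1)
        )
benchmark.Compare(results).print()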

@datumbox (Contributor, Author) left a comment

Adding some comments to explain the implementations:


half_wh = cxcywh[..., 2:].div(-2, rounding_mode=None if cxcywh.is_floating_point() else "floor").abs_()
@datumbox (Contributor, Author) commented Oct 28, 2022

This is a trick to divide by 2 and ceil when dealing with integers. Unfortunately rounding_mode doesn't support ceiling, so I divide by a negative number, floor, and take the absolute value, which has the same effect. This trick is slower than a simple division by 2 with flooring, but the faster version leads to a 1-pixel misalignment with the existing behaviour in TorchVision. Since this kernel was never released, we could have just implemented the faster version, but because we already offer operators in torchvision.ops with the current behaviour, I opted to align with them. See the detailed tests above for examples.

For the record, the faster version is:

half_wh = cxcywh[..., 2:].div(2, rounding_mode=None if cxcywh.is_floating_point() else "floor")
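
A tiny standalone check contrasting the two rounding approaches on integers (values picked for illustration):

import torch

wh = torch.tensor([5, 6, 7], dtype=torch.int64)
# Divide by a negative, floor, then abs: equivalent to ceil(wh / 2) -> [3, 3, 4]
ceil_half = wh.div(-2, rounding_mode="floor").abs_()
# Plain floor division: floor(wh / 2) -> [2, 3, 3], off by one pixel for odd sizes
floor_half = wh.div(2, rounding_mode="floor")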

Comment on lines 139 to 140
cxcywh[..., :2].sub_(half_wh)
cxcywh[..., 2:].add_(cxcywh[..., :2])
@datumbox (Contributor, Author)

Another trick to achieve the in-place op with a minimum of writes.

On the first line, we subtract the half width/height in place from cx/cy. This gives us x1 and y1.
On the second line, we take the width/height and add x1/y1 to it in place. This gives us x2 and y2.
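
For illustration, a minimal standalone check of these two in-place lines on a single box (values chosen for the example):

import torch

cxcywh = torch.tensor([[50.0, 40.0, 20.0, 10.0]])  # cx, cy, w, h
half_wh = cxcywh[..., 2:].div(-2, rounding_mode=None if cxcywh.is_floating_point() else "floor").abs_()
cxcywh[..., :2].sub_(half_wh)          # cx - w/2, cy - h/2 -> x1, y1
cxcywh[..., 2:].add_(cxcywh[..., :2])  # w + x1, h + y1     -> x2, y2
torch.testing.assert_close(cxcywh, torch.tensor([[40.0, 35.0, 60.0, 45.0]]))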

Comment on lines 149 to 150
xyxy[..., 2:].sub_(xyxy[..., :2])
xyxy[..., :2].mul_(2).add_(xyxy[..., 2:]).div_(2, rounding_mode=None if xyxy.is_floating_point() else "floor")
@datumbox (Contributor, Author) commented Oct 28, 2022

This is the reverse of the trick above.

On the first line, we subtract x1/y1 in place from x2/y2. This gives us the width and height.
On the second line, we take x1/y1, multiply it by 2, add the width/height and then divide by 2. Effectively we compute:

(x1 * 2 + width) / 2 = x1 + width / 2 = x1 + (x2 - x1) / 2 = (x1 + x2) / 2 = cx

The division trick is applied again to ensure we handle integers in place without further casts.
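
And the corresponding standalone check for the reverse direction:

import torch

xyxy = torch.tensor([[40.0, 35.0, 60.0, 45.0]])  # x1, y1, x2, y2
xyxy[..., 2:].sub_(xyxy[..., :2])  # x2 - x1, y2 - y1 -> w, h
xyxy[..., :2].mul_(2).add_(xyxy[..., 2:]).div_(
    2, rounding_mode=None if xyxy.is_floating_point() else "floor"
)  # (x1 * 2 + w) / 2 -> cx, (y1 * 2 + h) / 2 -> cy
torch.testing.assert_close(xyxy, torch.tensor([[50.0, 40.0, 20.0, 10.0]]))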

@datumbox datumbox requested a review from pmeier October 28, 2022 17:28
@datumbox datumbox marked this pull request as ready for review October 28, 2022 17:28
@@ -815,18 +811,18 @@ def crop_bounding_box(
) -> Tuple[torch.Tensor, Tuple[int, int]]:
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
Contributor

Yes, we don't have every conversion. I remember when adding these that I discussed with Francisco that all conversions can simply go through an intermediate xyxy conversion, for the sake of fewer complications.
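
A rough sketch of that dispatch idea, using the two kernels shown elsewhere in this PR (illustrative only, limited to the XYXY and CXCYWH formats; not the exact prototype code):

from torchvision.prototype.transforms.functional._meta import _cxcywh_to_xyxy, _xyxy_to_cxcywh

def convert_via_xyxy(bboxes, old_format, new_format, inplace=False):
    # Every source format is first normalized to XYXY, then XYXY is converted to
    # the target format, so only *_to_xyxy / xyxy_to_* kernels are needed.
    if new_format == old_format:
        return bboxes if inplace else bboxes.clone()
    if old_format == "CXCYWH":
        bboxes = _cxcywh_to_xyxy(bboxes, inplace)
        inplace = True  # the intermediate is owned by us, so it is safe to reuse in place
    if new_format == "CXCYWH":
        bboxes = _xyxy_to_cxcywh(bboxes, inplace)
    return bboxes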

@datumbox (Contributor, Author) commented Oct 31, 2022

I've added comments on the implementation to explain the logic of the intermediate inplace operations.

Here is also a benchmark that compares the new implementation with inplace=False vs inplace=True. To avoid modifying the data, I compare _cxcywh_to_xyxy(_xyxy_to_cxcywh(x, False), False) vs _cxcywh_to_xyxy(_xyxy_to_cxcywh(x, True), True). Hence the following benchmark chains two kernels and measures the effect of two clone() calls versus none. As expected, avoiding the clone is 13% faster on CPU and 21% faster on GPU:

[--------------------- inplace cpu ---------------------]
                     |      False       |       True     
1 threads: ----------------------------------------------
      torch.float32  |   46 (+-  0) us  |   40 (+-  0) us
      torch.float64  |   46 (+-  0) us  |   41 (+-  0) us
      torch.int32    |   46 (+-  0) us  |   41 (+-  0) us
      torch.int64    |   40 (+-  0) us  |   34 (+-  0) us
6 threads: ----------------------------------------------
      torch.float32  |   46 (+-  1) us  |   41 (+-  0) us
      torch.float64  |   46 (+-  0) us  |   41 (+-  1) us
      torch.int32    |   46 (+-  0) us  |   41 (+-  0) us
      torch.int64    |   40 (+-  0) us  |   34 (+-  0) us

Times are in microseconds (us).

[--------------------- inplace cuda --------------------]
                     |      False       |       True     
1 threads: ----------------------------------------------
      torch.float32  |   96 (+-  0) us  |   76 (+-  1) us
      torch.float64  |   96 (+-  1) us  |   76 (+-  0) us
      torch.int32    |   97 (+-  0) us  |   76 (+-  0) us
      torch.int64    |   96 (+-  0) us  |   76 (+-  0) us
6 threads: ----------------------------------------------
      torch.float32  |   96 (+-  1) us  |   76 (+-  2) us
      torch.float64  |   96 (+-  2) us  |   76 (+-  2) us
      torch.int32    |   97 (+-  2) us  |   76 (+-  2) us
      torch.int64    |   96 (+-  2) us  |   76 (+-  1) us

Times are in microseconds (us).

@vfdev-5 (Collaborator) left a comment

LGTM, thanks !

@datumbox datumbox merged commit 70faba9 into pytorch:main Oct 31, 2022
@datumbox datumbox deleted the prototype/bbox_inplace branch October 31, 2022 11:30
facebook-github-bot pushed a commit that referenced this pull request Oct 31, 2022
…ng_box` (#6858)

Summary:
* Add support of inplace on `convert_format_bounding_box`

* Move `as_subclass` calls to `F` invocations

* Fix bug.

* Fix _cxcywh_to_xyxy.

* Fixing _xyxy_to_cxcywh.

* Adding comments.

Reviewed By: datumbox

Differential Revision: D40851019

fbshipit-source-id: 965ccdb6f1fc32bff74c1b67314b3d4224952ffa