[Torch] Restore class-aware NMS for detection models by graph rewrite #7154
Conversation
@masahi Thanks for the perf improvement. Could you provide the CPU numbers as well?
@zhiics Sure, updated the description. Unfortunately I cannot claim that this is a perf improvement. The regression is only 200 us on CPU, so it may just be measurement noise. I have no idea why I'm not getting a good speedup: the number of IOU tests, including the memory accesses to the boxes, should definitely be reduced. The only additional overhead I can think of is that the input to NMS is one column wider, due to storing the class ids. The performance win is not great, but I believe having access to class ids is still a good idea...
@masahi I think this is plausible as well, particularly since the change is only in the parser. @kevinthesun please help take a look as well. Thanks.
I should mention that this rewrite is not run by default, so there is no perf risk.
This is a bit of a shot in the dark: I wonder if we're memory-access limited, and that's why you don't see a performance improvement. When we do the nested loop, we always have to check whether the class id of instance k matches the id of instance j. Since the input shape is (batch_size, num_anchors, features), and features = 6 here, I wouldn't be surprised if checking the id of k ends up reading all of the features of k into registers, and that memory read is the expensive operation. Once the data is loaded, actually doing the IOU calculation is relatively cheap, so skipping it doesn't help that much.
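To make that hypothesis concrete, here is a minimal, self-contained numpy sketch of a class-aware triangular loop (the layout and names are assumed for illustration; this is not TVM's actual kernel). The class-id check already touches row k, so the per-pair memory traffic is paid whether or not the IOU test runs:

```python
import numpy as np

def iou(a, b):
    """IOU of two boxes in (x1, y1, x2, y2) form."""
    lt = np.maximum(a[:2], b[:2])
    rb = np.minimum(a[2:], b[2:])
    inter = np.prod(np.clip(rb - lt, 0.0, None))
    area = lambda box: np.prod(np.clip(box[2:] - box[:2], 0.0, None))
    union = area(a) + area(b) - inter
    return inter / union if union > 0 else 0.0

def class_aware_nms_row(data, iou_threshold=0.5):
    """data: (num_anchors, 6) rows of [class_id, score, x1, y1, x2, y2],
    assumed sorted by descending score. Returns a boolean keep mask."""
    num_anchors = data.shape[0]
    suppressed = np.zeros(num_anchors, dtype=bool)
    for j in range(num_anchors):
        if suppressed[j]:
            continue
        for k in range(j + 1, num_anchors):
            # This comparison loads row k (likely pulling in the whole
            # 6-float row), so skipping the IOU test saves little.
            if data[k, 0] != data[j, 0]:
                continue
            if iou(data[j, 2:6], data[k, 2:6]) > iou_threshold:
                suppressed[k] = True
    return ~suppressed
```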
That's highly possible. Looking at this if condition in the triangle inner loop (tvm/python/tvm/topi/cuda/nms.py, lines 535 to 540 at 9956b5b): previously, …

That brings me to one of my pain points: I believe our NMS API needs to be reworked. The current way of packing class ids and scores together with the bbox coordinates is a design mistake that we inherited from MXNet. To store class ids, I have to cast the ids to float32, update, and pass …
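The embedded snippet does not survive in this transcript; as a rough, paraphrased stand-in for the kind of guard being discussed (a hypothetical simplification, not the actual cited lines):

```python
# Hypothetical, simplified rendering of the class-id guard in the
# triangular inner loop (see python/tvm/topi/cuda/nms.py for the real kernel).
def may_suppress(row_j, row_k, id_index, force_suppress):
    """Whether box j is allowed to suppress box k.

    row_* are packed feature rows, e.g. [class_id, score, x1, y1, x2, y2];
    id_index < 0 means there is no class column (class-agnostic NMS).
    """
    # Even when the classes differ and the IOU test is skipped, both
    # class ids must still be read, so memory traffic per pair is similar.
    return force_suppress or id_index < 0 or row_j[id_index] == row_k[id_index]
```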
Sorry for the delay in responding; I wanted to look at the frameworks more closely. We currently have five importers that leverage NMS:

- MXNet does multibox_transform_loc and then NMS on the outputs. multibox_transform_loc converts a 3D array of scores with shape (batch_size, class_num, num_anchors) into a most-likely class and the score for that class, and also does some coordinate transforms on the boxes.
- ONNX takes a 3D tensor of (batch_size, class, num_anchors), does slicing/concatenating with the boxes, and then does a per-class get_valid_counts -> non_max.
- PyTorch takes a 1D tensor of scores and concats it with the boxes before performing get_valid_counts and NMS. As @masahi shows in this PR, there is preprocessing outside the op to embed all classes into that 1D tensor.
- TF takes a 1D tensor of scores and concats it to the boxes before performing get_valid_counts and NMS. I'm not sure whether the rest of the TF graph handles the loop over batch size and classes.
- TFLite takes a 3D score tensor of shape (batch_size, num_anchors, class_id), reorders it to (batch_size, class_id, num_anchors), performs multibox_transform_loc -> NMS, and, strangely, does get_valid_counts after NMS.

So we're doing pre-processing in every framework to reduce the amount of score information and convert it to the 5- or 6-feature form the NMS API wants (see the sketch below). None of the frameworks give us inputs in the packed form the API expects, and we jump through hoops in every importer to convert inputs into that form. Then, in at least TFLite and ONNX, we perform further splitting/slicing/concatenating to restore the separate class ids. I think I agree with @masahi: we are jumping through a lot of hoops in the importers to support a TVM NMS API that's out of line with the frameworks, and that might be hurting our overall performance.
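As a concrete illustration of that packing, here is a minimal numpy sketch of what the importers end up constructing (shapes and names assumed for illustration):

```python
import numpy as np

batch, num_anchors = 1, 4
boxes = np.random.rand(batch, num_anchors, 4).astype("float32")       # x1, y1, x2, y2
scores = np.random.rand(batch, num_anchors).astype("float32")         # per-anchor score
class_ids = np.random.randint(0, 3, (batch, num_anchors)).astype("float32")

# Frameworks hand us boxes / scores / ids separately; TVM's NMS wants one
# packed tensor of shape (batch, num_anchors, 6): [class_id, score, box...].
packed = np.concatenate([class_ids[..., None], scores[..., None], boxes], axis=-1)
assert packed.shape == (batch, num_anchors, 6)
```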
I highly agree with you guys. For class-aware NMS, the [batch, num_anchors, 6] format seems very inefficient: every pair of anchors has to be checked just to see whether the classes match. A [batch, num_classes, num_anchors, 5] format would give us a nicely defined slice of memory where the same-class anchors are located (see the sketch below).
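A small numpy sketch of why the per-class layout helps (the [score, x1, y1, x2, y2] ordering of the last axis is assumed for illustration):

```python
import numpy as np

batch, num_classes, num_anchors = 1, 3, 4

# Proposed class-aware layout: (batch, num_classes, num_anchors, 5),
# last axis assumed to hold [score, x1, y1, x2, y2].
data = np.random.rand(batch, num_classes, num_anchors, 5).astype("float32")

# All anchors of a given class form one contiguous block, so an NMS
# kernel can run per (batch, class) slice with no per-pair class check.
for b in range(batch):
    for c in range(num_classes):
        same_class = data[b, c]          # (num_anchors, 5), contiguous
        assert same_class.flags["C_CONTIGUOUS"]
```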
That's correct: TF's NMS handles only a single class and a single batch, so the TF graph loops over batches and classes. To do that, it uses tf.map_fn, so the execution of each NMS can actually still run in parallel (roughly as sketched below). However, this turns into a mess of control-flow operators and TensorArrays, so Relay isn't able to do the same parallelization. This PR's graph rewrite could actually benefit TF OD models as well, but the pattern is a lot more complicated for TF.
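For reference, a minimal sketch of that fan-out (illustrative only; real exported graphs are far messier, with the control-flow operators and TensorArrays mentioned above):

```python
import tensorflow as tf

def per_class_nms(boxes, scores_per_class, max_out=100, iou_thresh=0.5):
    """boxes: (num_anchors, 4); scores_per_class: (num_classes, num_anchors)."""
    def one_class(scores):
        keep = tf.image.non_max_suppression(boxes, scores, max_out, iou_thresh)
        # Pad to a fixed length so map_fn can stack the per-class results.
        return tf.pad(keep, [[0, max_out - tf.shape(keep)[0]]], constant_values=-1)
    # map_fn runs one single-class NMS per row of scores_per_class.
    return tf.map_fn(one_class, scores_per_class, fn_output_signature=tf.int32)
```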
@kevinthesun @zhiics @mbrookhart As shown in my new NMS PR #7257, this rewrite yields a better speedup with the improved memory layout. Can we merge this? I have new rewrites coming to further optimize PyTorch NMS and MaskRCNN / FasterRCNN.
LGTM
…apache#7154)

* add a pattern to rewrite nms to batched nms
* update object detection test to add rewrite
* updated tutorial
* add doc
* fixed coord_start
* test fixed by setting force_surpress=False
* revert tutorial change
* add some comment to explain the pattern
* update NMS pattern following frontend change
The NMS used by PyTorch detection models actually performs multiclass NMS in one go, by adding a different offset to the boxes of each class so that two boxes from different classes never overlap. See
https://github.com/pytorch/vision/blob/3d60f498e71ba63b428edb184c9ac38fa3737fa6/torchvision/ops/boxes.py#L80-L89
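The linked trick boils down to the following (lightly paraphrased from the code at that link):

```python
import torch
from torchvision.ops import nms

def batched_nms_via_offsets(boxes, scores, idxs, iou_threshold):
    """Shift each class's boxes into a disjoint coordinate range, so a
    single class-agnostic NMS behaves like per-class NMS."""
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + 1)   # one offset per class id
    boxes_for_nms = boxes + offsets[:, None]
    return nms(boxes_for_nms, scores, iou_threshold)
```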
But this means most of the O(N²) IOU tests we do in the NMS triangle loop are useless. The goal of this PR is to restore the class indices, which are one of the inputs to the batched_nms function above, and perform class-aware NMS for TVM-compiled detection models. I did this by pattern matching and rewriting after model import. Specifically, I pattern match against the subgraph corresponding to the PyTorch batched_nms used by MaskRCNN / FasterRCNN.

Unfortunately, this optimization didn't yield the speedup I hoped for: on GPU it is only about 70 ms faster, and on CPU it is actually slightly slower (?) for some reason. I haven't looked into why it isn't going much faster.
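For readers unfamiliar with the mechanism, the rewrite uses Relay's dataflow pattern language. A heavily simplified, hypothetical sketch of such a pass looks like this (the real pattern for the batched_nms subgraph is much larger):

```python
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, wildcard, rewrite

class NMSRewrite(DFPatternCallback):
    """Hypothetical sketch: match a vision.non_max_suppression call so it
    can be re-emitted in class-aware form (placeholder callback below)."""

    def __init__(self):
        super().__init__()
        self.data = wildcard()
        self.valid_count = wildcard()
        # Five inputs assumed: data, valid_count, indices,
        # max_output_size, iou_threshold.
        self.pattern = is_op("vision.non_max_suppression")(
            self.data, self.valid_count, wildcard(), wildcard(), wildcard()
        )

    def callback(self, pre, post, node_map):
        # The real pass re-packs the inputs to carry class ids and calls
        # NMS again with id_index pointing at the class column.
        return post  # placeholder: leave the matched subgraph unchanged

# Usage sketch: mod["main"] = rewrite(NMSRewrite(), mod["main"])
```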
nvprof output from running MaskRCNN on GPU (before and after): [collapsed profiler output]

On CPU, the output from the VM profiler (before and after): [collapsed profiler output]
So performance-wise this change doesn't matter much, but I hope it also serves as a non-trivial use of pattern matching and rewriting.
cc @kevinthesun @mbrookhart @zhiics @t-vi What do you think?